#include "socket/iocp.h"
#include "utils/thread_pool.h"

namespace Network {

// Seeds the per-instance Mersenne Twister from rd_ and sets up the
// [-10, 10] integer distribution stored in jitterDist_.
IOCP::IOCP() {
  gen_ = std::mt19937(rd_());
  jitterDist_ = std::uniform_int_distribution(-10, 10);
}

IOCP::~IOCP() { destruct(); }

// Platform teardown hook. Currently empty on every platform; the Linux
// branch is a placeholder with no statements.
void IOCP::destruct() {
#ifdef __linux__
#endif
}

/**
 * Associates data->socket with the completion port and posts the first
 * overlapped read on it. Only does anything on Windows.
 *
 * @param data per-socket I/O context. Its `overlapped` and `wsabuf` members
 *             are handed to WSARecv, so presumably it must outlive the
 *             pending read — NOTE(review): confirm ownership with callers.
 */
void IOCP::registerSocket(IOCPPASSINDATA* data) {
#ifdef _WIN32
  const HANDLE port = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(data->socket->sock), completionPort_,
      data->socket->sock, 0);
  // CreateIoCompletionPort returns NULL on failure and the (possibly newly
  // created) port handle on success. The previous check was inverted: it
  // clobbered completionPort_ with NULL on failure and never stored a new
  // port on success. On failure, keep the existing port and do not post a
  // read whose completion could not be delivered through the port.
  if (port == NULL) return;
  completionPort_ = port;

  data->event = IOCPEVENT::READ;
  DWORD recvbytes = 0;
  DWORD flags = 0;
  ::WSARecv(data->socket->sock, &data->wsabuf, 1, &recvbytes, &flags,
            &data->overlapped, NULL);
#endif
}

/**
 * Copies buffered, already-received bytes for data->socket into
 * data->wsabuf.buf.
 *
 * Queued chunks are consumed front to back until the queue is empty or
 * wsabuf.len bytes have been copied. A chunk that is only partially consumed
 * stays at the front of the queue with its read offset (pair.second)
 * advanced.
 *
 * @return number of bytes written into data->wsabuf.buf.
 */
int IOCP::recv(IOCPPASSINDATA* data) {
  const SOCKET sock = data->socket->sock;
  std::lock_guard lock(*GetRecvQueueMutex_(sock));
  auto queue = GetRecvQueue_(sock);

  std::uint32_t left_data = data->wsabuf.len;
  std::uint32_t copied = 0;
  while (!queue->empty() && left_data != 0) {
    // Operate on the front chunk in place rather than copying its whole
    // buffer out of the list (and back in when partially consumed).
    auto& front = queue->front();
    const auto& chunk = front.first;
    std::uint32_t offset = front.second;
    const std::uint32_t available =
        static_cast<std::uint32_t>(chunk.size()) - offset;
    const std::uint32_t to_copy =
        (left_data < available) ? left_data : available;

    ::memcpy(data->wsabuf.buf + copied, chunk.data() + offset, to_copy);
    copied += to_copy;
    left_data -= to_copy;
    offset += to_copy;

    if (offset < chunk.size()) {
      // Caller's buffer is full; remember how far into this chunk we got.
      front.second = offset;
      break;
    }
    queue->pop_front();  // chunk fully consumed
  }
  return static_cast<int>(copied);
}

/**
 * Appends every item in *data to sock's send queue (tagging each as a WRITE)
 * and schedules packet_sender_ on the IOCP thread pool to flush the queue.
 *
 * @return always 0.
 */
int IOCP::send(SOCKET sock, std::vector<IOCPPASSINDATA*>* data) {
  if (data == nullptr) return 0;

  auto queue_mutex = GetSendQueueMutex_(sock);
  auto queue = GetSendQueue_(sock);
  {
    std::lock_guard lock(*queue_mutex);
    for (auto& item : *data) {
      item->event = IOCPEVENT::WRITE;
      queue->push_back(item);
    }
  }

  // NOTE(review): the job captures `this`; this assumes the IOCP object
  // outlives any job still queued on IOCPThread_ — confirm shutdown order.
  IOCPThread_->enqueueJob(
      [this, sock](utils::ThreadPool*, std::uint8_t) { packet_sender_(sock); },
      0);
  return 0;
}

// Returns sock's send queue, creating an empty one on first use.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  // Single lookup: operator[] inserts an empty pointer for an unseen socket.
  auto& queue = send_queue_[sock];
  if (!queue) queue = std::make_shared<std::list<IOCPPASSINDATA*>>();
  return queue;
}

// Returns sock's receive queue of (bytes, read offset) chunks, creating an
// empty one on first use.
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& queue = recv_queue_[sock];
  if (!queue) {
    queue = std::make_shared<
        std::list<std::pair<std::vector<char>, std::uint32_t>>>();
  }
  return queue;
}

// Returns the mutex guarding sock's send queue, creating it on first use.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& mtx = send_queue_mutex_[sock];
  if (!mtx) mtx = std::make_shared<std::mutex>();
  return mtx;
}

// Returns the mutex guarding sock's receive queue, creating it on first use.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& mtx = recv_queue_mutex_[sock];
  if (!mtx) mtx = std::make_shared<std::mutex>();
  return mtx;
}

/**
 * Flushes sock's send queue while holding that queue's mutex. Scheduled on
 * the IOCP thread pool by send().
 *
 * For TLS/QUIC sessions each item is passed through SSL_write and whatever
 * the item's write BIO then holds is pushed to the socket; otherwise the
 * item's buffer is sent as-is. WSASend is called without an OVERLAPPED, so
 * these sends are not completion-port operations.
 */
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue_(sock);
  std::unique_lock lock(*GetSendQueueMutex_(sock));

  const bool encrypted =
      proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC;
  // Scratch buffer for draining the write BIO. The drain loop reads until
  // the BIO is empty, so the size only affects how many WSASend calls occur.
  constexpr std::size_t kBioChunk = 16384;
  std::vector<char> buf(encrypted ? kBioChunk : 0);
  WSABUF wsabuf;
  DWORD sendbytes = 0;

  while (!queue->empty()) {
    auto front = queue->front();
    queue->pop_front();
    front->event = IOCPEVENT::WRITE;

    // Nothing to send. The SSL_write manual warns against num == 0, and a
    // failure below would tear down this socket's entire send queue.
    if (front->wsabuf.len == 0) continue;

    if (!encrypted) {
      wsabuf.buf = front->wsabuf.buf;
      wsabuf.len = front->wsabuf.len;
      // NOTE(review): the WSASend result is ignored; a failed send silently
      // drops this item.
      ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr);
      continue;
    }

    const int ret =
        ::SSL_write(front->ssl, front->wsabuf.buf, front->wsabuf.len);
    if (ret <= 0) {
      const int err = ::SSL_get_error(front->ssl, ret);
      if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
        // Retry this item first on the next flush. NOTE(review): nothing
        // re-schedules the flush here; it waits for the next send() call.
        queue->push_front(front);
        break;
      }
      // Fatal TLS error: detach this socket's queue from the map.
      // NOTE(review): items still in `queue` (and `front`) are dropped, and
      // a concurrent send() holding the old queue pointer appends to the
      // detached list — confirm IOCPPASSINDATA ownership/cleanup.
      std::lock_guard lk(socket_mod_mutex_);
      send_queue_.erase(sock);
      break;
    }

    int data_len = 0;
    while ((data_len = ::BIO_read(front->wbio, buf.data(),
                                  static_cast<int>(buf.size()))) > 0) {
      wsabuf.buf = buf.data();
      wsabuf.len = static_cast<ULONG>(data_len);
      ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr);
    }
  }
}

}  // namespace Network