#include "socket/iocp.h"

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <format>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

#include "utils/thread_pool.h"

namespace Network {

namespace {

/// Returns `map[key]`, first default-constructing the pointee when the slot is
/// missing or holds a null pointer. Intended for maps whose mapped_type is a
/// std::shared_ptr. The caller must already hold the lock that guards `map`.
template <typename Map, typename Key>
typename Map::mapped_type& getOrCreateShared(Map& map, const Key& key) {
  auto& slot = map[key];
  if (!slot) {
    slot = std::make_shared<typename Map::mapped_type::element_type>();
  }
  return slot;
}

}  // namespace

/// Creates an IOCP wrapper with no worker pool attached yet (see init()) and
/// seeds the generator used to jitter the completion-wait timeout.
IOCP::IOCP() : IOCPThread_(nullptr), proto_(SessionProtocol::TCP) {
  gen_ = std::mt19937(rd_());
  jitterDist_ = std::uniform_int_distribution(-10, 10);
}

IOCP::~IOCP() { destruct(); }

/// Attaches the worker pool and, on Windows, creates the completion port.
///
/// The pool is respawned with twice its current thread count, and one
/// iocpWatcher_ job is queued per original thread. Each watcher re-queues
/// itself after every wait. The process exits if the port cannot be created.
///
/// @param threadPool pool that runs the watchers and the completion callbacks.
/// @param proto      session protocol stored for this instance.
void IOCP::init(utils::ThreadPool* threadPool, SessionProtocol proto) {
  IOCPThread_ = threadPool;
  proto_ = proto;
#ifdef _WIN32
  completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
  if (completionPort_ == NULL) {
    spdlog::critical("CreateIoCompletionPort()");
    std::exit(EXIT_FAILURE);
  }
  int tCount = threadPool->threadCount;
  spdlog::info("Resizing threadpool size to: {}", tCount * 2);
  // Presumably doubled so that the tCount watchers, each blocking in
  // GetQueuedCompletionStatus for ~1s at a time, leave threads free for
  // completion callbacks. NOTE(review): confirm the intent.
  threadPool->respawnWorker(tCount * 2);
  for (int i = 0; i < tCount; i++)
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
#endif
}

/// Platform-specific teardown hook, called from the destructor. It is
/// currently empty on every platform.
/// NOTE(review): the completion port created in init() is not closed here.
void IOCP::destruct() {
#ifdef __linux__
#endif
}

/// Associates `sock` with this instance's completion port, using the socket
/// handle as the completion key.
///
/// @throws std::runtime_error if the association fails. Previously a failure
///         silently replaced completionPort_ with NULL.
void IOCP::registerSocket(std::shared_ptr sock) {
#ifdef _WIN32
  // When associating with an existing port, success returns that same port
  // handle; NULL signals failure.
  HANDLE returnData = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(sock->sock), completionPort_, sock->sock, 0);
  if (returnData == NULL) {
    throw std::runtime_error(
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError()));
  }
#endif
}

/// Reads exactly `bufsize` bytes from `sock`, issuing as many recv() calls as
/// needed.
///
/// @return a future that receives the bytes. If the peer disconnects before
///         all bytes arrive, or a follow-up recv() fails, the future holds an
///         exception instead.
/// @throws std::runtime_error if the first recv() fails synchronously.
std::future> IOCP::recvFull(std::shared_ptr sock, std::uint32_t bufsize) {
  auto promise = std::make_shared>>();
  auto future = promise->get_future();
  auto buffer = std::make_shared>();

  // A zero-length WSARecv completes with 0 bytes, which iocpWatcher_ treats
  // as a disconnect; there is nothing to read, so finish immediately.
  if (bufsize == 0) {
    promise->set_value(std::move(*buffer));
    return future;
  }
  buffer->reserve(bufsize);

  // The chunk reader must be able to call itself from its own completion
  // callback. It holds only a weak reference to itself (no cycle). Each
  // pending completion holds a strong one, which keeps the reader alive
  // exactly as long as a read is outstanding.
  using ChunkReader = std::function<void(std::uint32_t)>;
  auto recvChunk = std::make_shared<ChunkReader>();
  std::weak_ptr<ChunkReader> weakRecvChunk = recvChunk;

  *recvChunk = [this, sock, bufsize, buffer, promise,
                weakRecvChunk](std::uint32_t remaining) {
    auto self = weakRecvChunk.lock();
    this->recv(sock, remaining,
               [self, bufsize, buffer, promise](utils::ThreadPool*,
                                                IOCPPASSINDATA* data) {
                 if (data->event == IOCPEVENT::QUIT) {
                   // The watcher destructs the socket after this callback;
                   // report failure instead of waiting forever.
                   promise->set_exception(std::make_exception_ptr(
                       std::runtime_error("recvFull: connection closed before "
                                          "all bytes arrived")));
                 } else {
                   buffer->insert(buffer->end(), data->wsabuf.buf,
                                  data->wsabuf.buf + data->transferredbytes);
                   const std::uint32_t still_left =
                       bufsize - static_cast<std::uint32_t>(buffer->size());
                   if (still_left == 0) {
                     promise->set_value(std::move(*buffer));
                   } else {
                     try {
                       (*self)(still_left);
                     } catch (...) {
                       // Nobody reads this callback's own future, so surface
                       // the failure through recvFull's future instead.
                       promise->set_exception(std::current_exception());
                     }
                   }
                 }
                 return std::list();
               });
  };

  (*recvChunk)(bufsize);
  return future;
}

/// Default recv() callback. It returns the bytes delivered for `data` and,
/// while fewer than wsabuf.len bytes arrived, synchronously receives the
/// remainder.
/// NOTE(review): it blocks a pool thread on future.get() while the nested
/// recv completes.
std::list DEFAULT_RECVALL_CALLBACK(utils::ThreadPool* th, IOCPPASSINDATA* data) {
  std::list return_value;
  // NOTE(review): on the disconnect path iocpWatcher_ does not set
  // transferredbytes. This relies on IOCPPASSINDATA initializing it
  // (presumably to 0) — confirm.
  return_value.insert(return_value.end(), data->wsabuf.buf,
                      data->wsabuf.buf + data->transferredbytes);
  // After QUIT the watcher destructs the socket, so do not re-issue a recv.
  const bool peerClosed = data->event == IOCPEVENT::QUIT;
  if (!peerClosed && data->transferredbytes < data->wsabuf.len) {
    auto future = data->IOCPInstance->recv(
        data->socket, data->wsabuf.len - data->transferredbytes,
        DEFAULT_RECVALL_CALLBACK);
    auto result = future.get();
    return_value.insert(return_value.end(), result.begin(), result.end());
  }
  return return_value;
}

/// Posts an overlapped WSARecv of up to `bufsize` bytes on `sock`.
///
/// @param callback run by iocpWatcher_ on a pool thread when the receive
///                 completes (or the peer disconnects). May be null.
/// @return the callback's future; default-constructed (not valid()) when
///         `callback` is null.
/// @throws std::runtime_error if WSARecv fails with anything other than
///         WSA_IO_PENDING. The pending-I/O record is freed first.
std::future> IOCP::recv(
    std::shared_ptr sock, std::uint32_t bufsize,
    std::function(utils::ThreadPool*, IOCPPASSINDATA*)> callback) {
  std::lock_guard lock(*GetRecvQueueMutex(sock->sock));

  Network::IOCPPASSINDATA* data = nullptr;
  std::future> future;
  if (callback != nullptr) {
    std::packaged_task(utils::ThreadPool*, IOCPPASSINDATA*)> task(callback);
    future = task.get_future();
    data = new Network::IOCPPASSINDATA(sock, bufsize, this, std::move(task));
  } else {
    data = new Network::IOCPPASSINDATA(sock, bufsize, this);
  }

  DWORD recvbytes = 0, flags = 0;
  const int result = ::WSARecv(sock->sock, &data->wsabuf, 1, &recvbytes,
                               &flags, &data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // No completion packet will be queued for this I/O, so iocpWatcher_
      // will never delete `data`; free it here before reporting.
      delete data;
      throw std::runtime_error(std::format("WSARecv failed: {}", err));
    }
  }
  return future;
}

/// Copies `data` into a new write packet, appends it to the socket's send
/// queue, and schedules packet_sender_ on the pool to flush the queue.
/// @return always 0.
int IOCP::send(std::shared_ptr sock, std::vector& data) {
  auto lk = GetSendQueueMutex(sock->sock);
  auto queue = GetSendQueue(sock->sock);
  std::lock_guard lock(*lk);

  auto* packet = new Network::IOCPPASSINDATA(sock, data.size(), this);
  packet->event = IOCPEVENT::WRITE;
  // memcpy with a null source is undefined even for 0 bytes.
  if (!data.empty()) {
    ::memcpy(packet->wsabuf.buf, data.data(), data.size());
  }
  packet->wsabuf.len = static_cast<ULONG>(data.size());
  queue->push_back(packet);

  IOCPThread_->enqueueJob(
      [this, sock = sock->sock](utils::ThreadPool*, std::uint8_t) {
        packet_sender_(sock);
      },
      0);
  return 0;
}

/// Sums (chunk size - second field) over every entry in the socket's recv
/// queue while holding socket_mod_mutex_.
int IOCP::GetRecvedBytes(SOCKET sock) {
  auto queue = GetRecvQueue(sock);
  std::lock_guard lock(socket_mod_mutex_);
  int bytes = 0;
  for (const auto& [chunk, offset] : *queue) {
    bytes += chunk.size() - offset;
  }
  return bytes;
}

/// One iteration of a completion-port watcher. It waits ~1s (+/-10ms of
/// jitter) for a completion, dispatches it to the pool, and re-queues itself.
///
/// Dequeued packets are handled in one of two ways:
/// - A failed I/O, or one that transferred 0 bytes, is treated as a
///   disconnect. The event is set to QUIT, the callback runs, and then the
///   socket is destructed.
/// - Any other packet gets its byte count recorded before its callback runs.
///
/// The watcher stops (does not re-queue) once the port is reported closed.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  const auto rearm = [this, IOCPThread] {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
  };

  int jitter = 0;
  {
    // init() starts several watchers that run concurrently on the pool, and
    // std::mt19937 / the distribution are not safe for concurrent use.
    static std::mutex jitterMutex;
    std::lock_guard jitterLock(jitterMutex);
    jitter = jitterDist_(gen_);
  }

  DWORD cbTransferred = 0;
  ULONG_PTR completionKey = 0;
  LPOVERLAPPED overlapped = nullptr;
  const BOOL retVal = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransferred, &completionKey, &overlapped,
      1000 + jitter);

  if (overlapped == nullptr) {
    // No packet was dequeued: timeout, closed port, or some other wait error.
    const DWORD lastError = retVal ? ERROR_SUCCESS : ::GetLastError();
    if (lastError == ERROR_ABANDONED_WAIT_0 ||
        lastError == ERROR_INVALID_HANDLE) {
      spdlog::debug("IOCP watcher stopping: completion port closed ({})",
                    lastError);
      return;
    }
    if (lastError != WAIT_TIMEOUT && lastError != ERROR_SUCCESS) {
      spdlog::error("GetQueuedCompletionStatus failed: {}", lastError);
    }
    rearm();
    return;
  }

  // Same layout assumption as before: the OVERLAPPED that recv() hands to
  // WSARecv sits at the start of IOCPPASSINDATA.
  auto* data = reinterpret_cast<IOCPPASSINDATA*>(overlapped);

  if (!retVal || cbTransferred == 0) {
    data->event = IOCPEVENT::QUIT;
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(data->socket->remoteAddr));
    IOCPThread->enqueueJob(
        [data](utils::ThreadPool* th, std::uint8_t) {
          if (data->callback.valid()) {
            data->callback(th, data);
          }
          data->socket->destruct();
          delete data;
        },
        0);
    rearm();
    return;
  }

  data->transferredbytes = cbTransferred;
  IOCPThread->enqueueJob(
      [data](utils::ThreadPool* th, std::uint8_t) {
        if (data->callback.valid()) data->callback(th, data);
        delete data;
      },
      0);
  rearm();
}

/// Returns the send queue for `sock`, creating an empty one on first use.
std::shared_ptr> IOCP::GetSendQueue(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreateShared(send_queue_, sock);
}

/// Returns the recv queue for `sock`, creating an empty one on first use.
std::shared_ptr, std::uint32_t>>> IOCP::GetRecvQueue(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreateShared(recv_queue_, sock);
}

/// Returns the mutex guarding `sock`'s send queue, creating it on first use.
std::shared_ptr IOCP::GetSendQueueMutex(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreateShared(send_queue_mutex_, sock);
}

/// Returns the per-socket recv mutex, creating it on first use.
std::shared_ptr IOCP::GetRecvQueueMutex(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreateShared(recv_queue_mutex_, sock);
}

/// Drains `sock`'s send queue in FIFO order with synchronous WSASend calls,
/// holding the per-socket send mutex for the whole drain so packets are not
/// interleaved. Each packet is freed once it has been handed to WSASend.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue(sock);
  std::lock_guard lock(*GetSendQueueMutex(sock));
  while (!queue->empty()) {
    // send() allocated this packet with new, and nothing else frees it.
    std::unique_ptr<IOCPPASSINDATA> packet(queue->front());
    queue->pop_front();
    packet->event = IOCPEVENT::WRITE;

    DWORD sendbytes = 0;
    // Null OVERLAPPED: a blocking send that queues no completion packet.
    if (::WSASend(sock, &packet->wsabuf, 1, &sendbytes, 0, nullptr,
                  nullptr) == SOCKET_ERROR) {
      spdlog::error("WSASend failed: {}", ::WSAGetLastError());
    }
  }
}

}  // namespace Network