#include "socket/iocp.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <format>
#include <list>
#include <memory>
#include <mutex>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "utils/thread_pool.h"

namespace Network {

/// Seeds the generator used to jitter the GetQueuedCompletionStatus timeout
/// in iocpWatcher_ by [-10, 10] milliseconds.
IOCP::IOCP() {
  gen_ = std::mt19937(rd_());
  jitterDist_ = std::uniform_int_distribution(-10, 10);
}

IOCP::~IOCP() { destruct(); }

/// @brief Stores the thread pool and protocol and, on Windows, creates the
///        completion port and starts the completion watchers.
/// @param threadPool Pool that runs the watchers and the send jobs. On Windows
///        it is respawned with twice its current thread count.
/// @param proto Stored in proto_.
/// Terminates the process (std::exit) if the completion port cannot be created.
void IOCP::init(utils::ThreadPool* threadPool, SessionProtocol proto) {
  IOCPThread_ = threadPool;
  proto_ = proto;
#ifdef _WIN32
  completionPort_ =
      ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
  if (completionPort_ == nullptr) {
    spdlog::critical("CreateIoCompletionPort()");
    std::exit(EXIT_FAILURE);
  }
  const int tCount = threadPool->threadCount;
  spdlog::info("Resizing threadpool size to: {}", tCount * 2);
  threadPool->respawnWorker(tCount * 2);
  // One self-re-enqueueing watcher per original worker. Presumably this
  // leaves the other half of the resized pool for send jobs; confirm intent.
  for (int i = 0; i < tCount; ++i) {
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t /*unused*/) {
          iocpWatcher_(th);
        },
        0);
  }
#endif
}

/// Platform-specific teardown hook called from the destructor.
/// NOTE(review): this is currently a no-op on every platform. The completion
/// port created in init() is not closed here, and watcher jobs that captured
/// `this` keep re-enqueueing themselves in the thread pool.
void IOCP::destruct() {
#ifdef __linux__
#endif
}

/// @brief Associates a TCP socket with the completion port and posts its
///        first overlapped receive.
/// @param sock Socket to register. It is copied into a shared_ptr owned by the
///        pending receive.
/// @param bufsize Size of the receive buffer. Follow-up receives posted by
///        iocpWatcher_ reuse this size.
/// @throws std::runtime_error if the association or the initial WSARecv fails.
void IOCP::registerTCPSocket(Socket& sock, std::uint32_t bufsize) {
#ifdef _WIN32
  // The socket value doubles as the completion key, so iocpWatcher_ can find
  // the per-socket receive queue. A NULL return means the association failed.
  // completionPort_ itself must stay untouched.
  if (::CreateIoCompletionPort(reinterpret_cast<HANDLE>(sock.sock),
                               completionPort_,
                               static_cast<ULONG_PTR>(sock.sock), 0) ==
      nullptr) {
    throw std::runtime_error(
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError()));
  }

  // Owned locally until the receive is successfully posted, so a throw below
  // does not leak it.
  auto recv_data = std::make_unique<IOCPPASSINDATA>(bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = std::make_shared<Socket>(sock);
  recv_data->IOCPInstance = this;

  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result =
      ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1, &recvbytes,
                &flags, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      auto err_msg = std::format("WSARecv failed: {}", err);
      throw std::runtime_error(err_msg);
    }
  }
  // The receive is in flight; iocpWatcher_ frees it when it completes.
  static_cast<void>(recv_data.release());
#endif
}

/// @brief Associates a UDP socket with the completion port and posts an
///        overlapped WSARecvFrom on it.
/// @param data Supplies the socket and the buffer size. It is neither stored
///        nor freed here.
/// @param recv_addr Unused. The sender address is deliberately not captured:
///        WSARecvFrom writes it when the operation completes, after this
///        by-value parameter is gone, and iocpWatcher_ never reads it.
/// @throws std::runtime_error if the association or the WSARecvFrom fails.
void IOCP::registerUDPSocket(IOCPPASSINDATA* data,
                             [[maybe_unused]] Address recv_addr) {
#ifdef _WIN32
  const auto sock = data->socket->sock;
  if (::CreateIoCompletionPort(reinterpret_cast<HANDLE>(sock), completionPort_,
                               static_cast<ULONG_PTR>(sock), 0) == nullptr) {
    throw std::runtime_error(
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError()));
  }

  auto recv_data = std::make_unique<IOCPPASSINDATA>(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;
  recv_data->IOCPInstance = this;

  DWORD recvbytes = 0;
  DWORD flags = 0;
  // lpFrom/lpFromlen are nullptr: any buffer passed here must outlive the
  // overlapped operation, and nothing consumes the sender address.
  const int result =
      ::WSARecvFrom(sock, &recv_data->wsabuf, 1, &recvbytes, &flags, nullptr,
                    nullptr, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      auto err_msg = std::format("WSARecvFrom failed: {}", err);
      throw std::runtime_error(err_msg);
    }
  }
  static_cast<void>(recv_data.release());
#endif
}

/// @brief Moves buffered received bytes for `sock` into `data`.
/// Copies at most data.size() bytes and consumes them from the per-socket
/// receive queue. A partially consumed chunk stays at the front of the queue
/// with its read offset advanced.
/// @return Number of bytes copied (0 if nothing is buffered).
int IOCP::recv(Socket& sock, std::vector<char>& data) {
  std::lock_guard lock(*GetRecvQueueMutex_(sock.sock));
  auto queue = GetRecvQueue_(sock.sock);

  std::size_t copied = 0;
  while (!queue->empty() && copied < data.size()) {
    // Edit the front chunk in place instead of copying it out and back.
    auto& [chunk, offset] = queue->front();
    const std::size_t available = chunk.size() - offset;
    const std::size_t to_copy = std::min(data.size() - copied, available);
    std::memcpy(data.data() + copied, chunk.data() + offset, to_copy);
    copied += to_copy;
    offset += static_cast<std::uint32_t>(to_copy);
    if (offset < chunk.size()) {
      break;  // caller's buffer is full; keep the remainder for next time
    }
    queue->pop_front();
  }
  return static_cast<int>(copied);
}

/// @brief Queues a copy of `data` for sending on `sock`.
/// The bytes are copied into a heap packet and appended to the socket's send
/// queue. A packet_sender_ job is then scheduled on the thread pool to drain
/// that queue.
/// @return Always 0. Send failures are only logged, by packet_sender_.
int IOCP::send(Socket& sock, std::vector<char>& data) {
  auto packet = std::make_unique<Network::IOCPPASSINDATA>(data.size());
  packet->event = IOCPEVENT::WRITE;
  packet->socket = std::make_shared<Socket>(sock);
  packet->IOCPInstance = this;
  if (!data.empty()) {
    std::memcpy(packet->wsabuf.buf, data.data(), data.size());
  }
  packet->wsabuf.len = data.size();

  {
    auto lk = GetSendQueueMutex_(sock.sock);
    auto queue = GetSendQueue_(sock.sock);
    std::lock_guard lock(*lk);
    queue->push_back(packet.get());
    // The queue owns the packet now; packet_sender_ frees it after sending.
    static_cast<void>(packet.release());
  }

  IOCPThread_->enqueueJob(
      [this, sock = sock.sock](utils::ThreadPool* /*unused*/,
                               std::uint8_t /*unused*/) {
        packet_sender_(sock);
      },
      0);
  return 0;
}

/// @brief Returns the number of buffered, not-yet-read bytes for `sock`.
int IOCP::GetRecvedBytes(SOCKET sock) {
  // Use the same per-socket mutex that recv() and iocpWatcher_ hold while
  // mutating this queue. socket_mod_mutex_ does not protect queue contents.
  std::lock_guard lock(*GetRecvQueueMutex_(sock));
  auto queue = GetRecvQueue_(sock);
  int bytes = 0;
  for (const auto& [chunk, offset] : *queue) {
    bytes += static_cast<int>(chunk.size() - offset);
  }
  return bytes;
}

/// @brief Waits for one completion packet, handles it, then re-enqueues itself
///        on `IOCPThread`.
/// - Timeout: just re-enqueues.
/// - GetQueuedCompletionStatus failure with no packet dequeued: the watcher
///   logs the error and stops, instead of spinning on a dead port.
/// - Failed I/O or zero-byte completion: logged as a disconnect.
/// - READ: appends the received bytes to the socket's receive queue and posts
///   the next WSARecv.
/// - Anything else: logged as a completed write.
/// Every dequeued IOCPPASSINDATA is freed here.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  const auto rescheduleSelf = [this, IOCPThread] {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t /*unused*/) {
          iocpWatcher_(th);
        },
        0);
  };

  int jitter = 0;
  {
    // gen_ and jitterDist_ are shared by all watcher threads and are not
    // thread-safe on their own.
    std::lock_guard lock(socket_mod_mutex_);
    jitter = jitterDist_(gen_);
  }

  DWORD cbTransferred = 0;
  ULONG_PTR completionKey = 0;
  LPOVERLAPPED overlapped = nullptr;
  const BOOL succeeded = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransferred, &completionKey, &overlapped,
      static_cast<DWORD>(1000 + jitter));

  if (overlapped == nullptr) {
    // No I/O packet was dequeued. GetLastError is only meaningful on failure;
    // after a success it may still hold a stale WAIT_TIMEOUT.
    if (!succeeded) {
      const DWORD lastError = ::GetLastError();
      if (lastError != WAIT_TIMEOUT) {
        spdlog::error("GetQueuedCompletionStatus failed: {}", lastError);
        return;
      }
    }
    rescheduleSelf();
    return;
  }

  // The OVERLAPPED is embedded in the IOCPPASSINDATA that posted the
  // operation. Owning it here frees it on every path below.
  std::unique_ptr<IOCPPASSINDATA> completed(
      CONTAINING_RECORD(overlapped, IOCPPASSINDATA, overlapped));
  const auto sock = static_cast<SOCKET>(completionKey);

  if (!succeeded || cbTransferred == 0) {
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(completed->socket->remoteAddr));
    rescheduleSelf();
    return;
  }
  completed->transferredbytes = cbTransferred;

  if (completed->event == IOCPEVENT::READ) {
    {
      std::lock_guard lock(*GetRecvQueueMutex_(sock));
      auto queue_list = GetRecvQueue_(sock);
      // Copy straight out of the receive buffer; there is no fixed-size
      // staging buffer that a large bufsize could overflow.
      const char* begin = completed->wsabuf.buf;
      queue_list->emplace_back(
          std::vector<char>(begin, begin + completed->transferredbytes), 0u);
    }

    // Post the next receive. This always uses WSARecv, also for sockets that
    // were registered through registerUDPSocket.
    auto next = std::make_unique<IOCPPASSINDATA>(completed->bufsize);
    next->event = IOCPEVENT::READ;
    next->socket = completed->socket;
    next->IOCPInstance = this;
    DWORD recvbytes = 0;
    DWORD flags = 0;
    const int result = ::WSARecv(next->socket->sock, &next->wsabuf, 1,
                                 &recvbytes, &flags, &next->overlapped, nullptr);
    const int err = (result == SOCKET_ERROR) ? ::WSAGetLastError() : 0;
    if (result == SOCKET_ERROR && err != WSA_IO_PENDING) {
      // Nothing was posted, so no completion will free `next`; let it go now.
      spdlog::error("WSARecv failed: {}", err);
    } else {
      static_cast<void>(next.release());
    }
  } else {
    // Write completions need no further work.
    spdlog::debug("writed {} bytes to {}", cbTransferred,
                  static_cast<std::string>(completed->socket->remoteAddr));
  }

  rescheduleSelf();
}

/// Returns the send queue for `sock`, creating it on first use.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& queue = send_queue_[sock];
  if (!queue) {
    queue = std::make_shared<std::list<IOCPPASSINDATA*>>();
  }
  return queue;
}

/// Returns the receive queue for `sock`, creating it on first use. Each entry
/// is (received bytes, read offset into those bytes).
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& queue = recv_queue_[sock];
  if (!queue) {
    queue = std::make_shared<
        std::list<std::pair<std::vector<char>, std::uint32_t>>>();
  }
  return queue;
}

/// Returns the mutex guarding the send queue of `sock`, creating it on first use.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& mtx = send_queue_mutex_[sock];
  if (!mtx) {
    mtx = std::make_shared<std::mutex>();
  }
  return mtx;
}

/// Returns the mutex guarding the receive queue of `sock`, creating it on
/// first use.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  auto& mtx = recv_queue_mutex_[sock];
  if (!mtx) {
    mtx = std::make_shared<std::mutex>();
  }
  return mtx;
}

/// @brief Drains the send queue for `sock`, sending each packet in order and
///        freeing it afterwards.
/// Runs as a thread-pool job scheduled by send(). The per-socket send mutex is
/// held for the whole drain, so concurrent jobs for the same socket cannot
/// interleave their bytes. WSASend is called without an OVERLAPPED, i.e. as a
/// non-overlapped call. Partial sends are continued; errors are logged and the
/// rest of that packet is dropped.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue_(sock);
  std::lock_guard lock(*GetSendQueueMutex_(sock));

  while (!queue->empty()) {
    std::unique_ptr<IOCPPASSINDATA> packet(queue->front());
    queue->pop_front();

    char* cursor = packet->wsabuf.buf;
    auto remaining = packet->wsabuf.len;
    while (remaining > 0) {
      WSABUF chunk;
      chunk.buf = cursor;
      chunk.len = remaining;
      DWORD sent = 0;
      if (::WSASend(sock, &chunk, 1, &sent, 0, nullptr, nullptr) ==
          SOCKET_ERROR) {
        spdlog::error("WSASend failed: {}", ::WSAGetLastError());
        break;
      }
      if (sent == 0) {
        break;  // no progress; avoid spinning forever
      }
      cursor += sent;
      remaining -= sent;
    }
  }
}

}  // namespace Network