From 3f052a5f7f629c6f2039bfe0171418196b66c110 Mon Sep 17 00:00:00 2001 From: HappyTanuki Date: Mon, 2 Jun 2025 04:30:23 +0900 Subject: [PATCH] =?UTF-8?q?iocp=EA=B0=80=20udp,=20tcp,=20tls,=20quic=20?= =?UTF-8?q?=EC=A0=84=EB=B6=80=20=EC=A7=80=EC=9B=90=ED=95=98=EB=8F=84?= =?UTF-8?q?=EB=A1=9D=20=EC=88=98=EC=A0=95=20=EC=A4=91...?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- impl/socket/iocp.cpp | 151 ++++++++++++++++++++++++++++++------ include/socket/iocp.h | 114 ++++++++++++++++----------- include/utils/thread_pool.h | 4 +- 3 files changed, 198 insertions(+), 71 deletions(-) diff --git a/impl/socket/iocp.cpp b/impl/socket/iocp.cpp index 550b3de..52e912d 100644 --- a/impl/socket/iocp.cpp +++ b/impl/socket/iocp.cpp @@ -17,37 +17,138 @@ void IOCP::registerSocket(IOCPPASSINDATA* data) { HANDLE returnData = ::CreateIoCompletionPort( (HANDLE)data->socket->sock, completionPort_, data->socket->sock, 0); if (returnData == 0) completionPort_ = returnData; + + data->event = IOCPEVENT::READ; + DWORD recvbytes = 0, flags = 0; + ::WSARecv(data->socket->sock, &data->wsabuf, 1, &recvbytes, &flags, + &data->overlapped, NULL); #endif } -//int IOCP::recv(IOCPPASSINDATA* data, int bufferCount) { -// data->event = IOCPEVENT::READ; -//#ifdef _WIN32 -// DWORD recvbytes = 0, flags = 0; -// return ::WSARecv(data->socket->sock, &data->wsabuf, bufferCount, &recvbytes, -// &flags, &data->overlapped, NULL); -//#endif -//} -// -//int IOCP::send(IOCPPASSINDATA* data, int bufferCount, int __flags) { -// data->event = IOCPEVENT::WRITE; -//#ifdef _WIN32 -// DWORD sendbytes = 0; -// return ::WSASend(data->socket->sock, &data->wsabuf, bufferCount, &sendbytes, -// __flags, &data->overlapped, NULL); -//#endif -//} - int IOCP::recv(IOCPPASSINDATA& data) { - data.event = IOCPEVENT::READ; -#ifdef _WIN32 -#endif + SOCKET sock = data.socket->sock; + std::lock_guard lock(*GetRecvQueueMutex_(sock)); + auto queue = GetRecvQueue_(sock); + + std::uint32_t left_data = data.wsabuf.len; + std::uint32_t copied = 0; + + while (!queue->empty() && left_data != 0) { + auto front = queue->front(); + queue->pop_front(); + + std::uint32_t offset = front.second; + std::uint32_t available = front.first.size() - offset; + std::uint32_t to_copy = (left_data < available) ? left_data : available; + + ::memcpy(data.wsabuf.buf + copied, front.first.data() + offset, to_copy); + copied += to_copy; + left_data -= to_copy; + offset += to_copy; + + if (offset < front.first.size()) { + front.second = offset; + queue->push_front(front); + break; + } + } + + return copied; } -int IOCP::send(std::vector data) { - data.event = IOCPEVENT::WRITE; -#ifdef _WIN32 -#endif +int IOCP::send(SOCKET sock, std::vector& data) { + std::lock_guard lock(*send_queue_mutex_[sock]); + for (auto it : data) { + it.event = IOCPEVENT::WRITE; + send_queue_[sock]->push_back(it); + } + + IOCPThread_->enqueueJob( + [this, sock](utils::ThreadPool* th, std::uint8_t __) { + packet_sender_(sock); + }, + 0); + return 0; +} + +std::shared_ptr> IOCP::GetSendQueue_(SOCKET sock) { + std::lock_guard lock(socket_mod_mutex_); + if (send_queue_.find(sock) == send_queue_.end()) { + send_queue_[sock] = std::make_shared>( + std::list()); + } + return send_queue_[sock]; +} + +std::shared_ptr, std::uint32_t>>> +IOCP::GetRecvQueue_(SOCKET sock) { + std::lock_guard lock(socket_mod_mutex_); + if (recv_queue_.find(sock) == recv_queue_.end()) { + recv_queue_[sock] = std::make_shared< + std::list, std::uint32_t>>>( + std::list, std::uint32_t>>()); + } + return recv_queue_[sock]; +} + +std::shared_ptr IOCP::GetSendQueueMutex_(SOCKET sock) { + std::lock_guard lock(socket_mod_mutex_); + if (send_queue_mutex_.find(sock) == send_queue_mutex_.end()) { + send_queue_mutex_[sock] = std::make_shared(); + } + return send_queue_mutex_[sock]; +} + +std::shared_ptr IOCP::GetRecvQueueMutex_(SOCKET sock) { + std::lock_guard lock(socket_mod_mutex_); + if (recv_queue_mutex_.find(sock) == recv_queue_mutex_.end()) { + recv_queue_mutex_[sock] = std::make_shared(); + } + return recv_queue_mutex_[sock]; +} + +void IOCP::packet_sender_(SOCKET sock) { + auto queue = GetSendQueue_(sock); + std::unique_lock lock(*GetSendQueueMutex_(sock)); + + std::vector buf(16384); + WSABUF wsabuf; + DWORD sendbytes = 0; + + while (!queue->empty()) { + auto front = queue->front(); + queue->pop_front(); + front.event = IOCPEVENT::WRITE; + + int data_len = 0; + + if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) { + int ret = ::SSL_write(front.ssl, front.wsabuf.buf, front.wsabuf.len); + if (ret <= 0) { + int err = ::SSL_get_error(front.ssl, ret); + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + queue->push_front(front); + break; + } + std::unique_lock lk(socket_mod_mutex_); + send_queue_.erase(sock); + break; + } + + while ((data_len = ::BIO_read(wbio_, buf.data(), buf.size())) > 0) { + wsabuf.buf = buf.data(); + wsabuf.len = data_len; + + ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr); + } + + } else { + data_len = front.wsabuf.len; + wsabuf.buf = front.wsabuf.buf; + wsabuf.len = data_len; + ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr); + } + } } } // namespace Network diff --git a/include/socket/iocp.h b/include/socket/iocp.h index 5872231..c6d83ff 100644 --- a/include/socket/iocp.h +++ b/include/socket/iocp.h @@ -35,6 +35,8 @@ struct IOCPPASSINDATA { IOCPEVENT event; std::shared_ptr socket; SSL* ssl; + BIO* rbio; // bio는 ssl별로 달라야 하므로 분리해야 함.. + BIO* wbio; std::uint32_t transferredbytes; WSABUF wsabuf; std::uint32_t bufsize; @@ -52,6 +54,7 @@ struct IOCPPASSINDATA { IOCPInstance = nullptr; wsabuf.buf = new char[bufsize]; + wsabuf.len = bufsize; } IOCPPASSINDATA(const IOCPPASSINDATA& other) : event(other.event), @@ -65,12 +68,13 @@ struct IOCPPASSINDATA { #endif { std::memset(&overlapped, 0, sizeof(overlapped)); - wsabuf.buf = new char[bufsize]; + wsabuf.buf = new char[other.bufsize]; + wsabuf.len = other.bufsize; std::memcpy(wsabuf.buf, other.wsabuf.buf, other.wsabuf.len); } ~IOCPPASSINDATA() { - if (wsabuf.buf != nullptr) delete wsabuf.buf; + if (wsabuf.buf != nullptr) delete[] wsabuf.buf; } IOCPPASSINDATA& operator=(const IOCPPASSINDATA& other) { @@ -84,7 +88,9 @@ struct IOCPPASSINDATA { #ifdef __linux__ sendQueue = other.sendQueue; #endif - wsabuf.buf = new char[bufsize]; + if (wsabuf.buf != nullptr) delete[] wsabuf.buf; + wsabuf.buf = new char[other.bufsize]; + wsabuf.len = other.bufsize; std::memcpy(wsabuf.buf, other.wsabuf.buf, other.wsabuf.len); } return *this; @@ -95,11 +101,7 @@ class IOCP { public: ~IOCP(); - template - void init(utils::ThreadPool* __IOCPThread, SessionProtocol proto, - std::function - callback, - _args&&... args) { + void init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) { IOCPThread_ = __IOCPThread; proto_ = proto; @@ -121,7 +123,9 @@ class IOCP { __IOCPThread->respawnWorker(tCount * 2); for (int i = 0; i < tCount; i++) - __IOCPThread->enqueueJob(iocpWatcher_, IOCPThread_, callback, args...); + IOCPThread_->enqueueJob( + [this](utils::ThreadPool* th, std::uint8_t __) { iocpWatcher_(th); }, + 0); #endif } @@ -129,50 +133,68 @@ class IOCP { void registerSocket(IOCPPASSINDATA* data); - int send(std::vector data); + // data는 한 가지 소켓에 보내는 패킷만 담아야 합니다 + int send(SOCKET sock, std::vector& data); int recv(IOCPPASSINDATA& data); private: #ifdef _WIN32 - template - void iocpWatcher_( - utils::ThreadPool* IOCPThread, - std::function - callback, - _args&&... args) { + void iocpWatcher_(utils::ThreadPool* IOCPThread) { IOCPPASSINDATA* data; SOCKET sock; DWORD cbTransfrred; int retVal = GetQueuedCompletionStatus(completionPort_, &cbTransfrred, (PULONG_PTR)&sock, (LPOVERLAPPED*)&data, INFINITE); + if (retVal == 0 || cbTransfrred == 0) { data->event = IOCPEVENT::QUIT; spdlog::debug("Disconnected. [{}]", (std::string)(data->socket->remoteAddr)); + delete data; } else { - if (data->event == IOCPEVENT::READ && - (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC)) { - ::BIO_write(rbio_, data->wsabuf.buf, cbTransfrred); - ::SSL_read_ex(data->ssl, data->wsabuf.buf, data->bufsize); - } data->transferredbytes = cbTransfrred; } - IOCPThread->enqueueJob(callback, IOCPThread, data, args...); - IOCPThread->enqueueJob(iocpWatcher_, callback, args...); - }; -#elif __linux__ + std::vector buf(16384); // SSL_read최대 반환 크기 + int red_data = 0; + auto queue_list = GetRecvQueue_(data->socket->sock); + if (data->event == IOCPEVENT::READ) { + if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) { + ::BIO_write(rbio_, data->wsabuf.buf, cbTransfrred); -#endif - - void packet_sender_(utils::ThreadPool* IOCPThread) { -#ifdef _WIN32 - std::lock_guard lock(send_queue_mutex_); -#elif __linux__ - -#endif + while ((red_data = ::SSL_read(data->ssl, buf.data(), buf.size())) > 0) { + queue_list->emplace_back(std::make_pair( + std::vector(buf.begin(), buf.begin() + red_data), 0)); + } + } else { + ::memcpy(buf.data(), data->wsabuf.buf, data->transferredbytes); + queue_list->emplace_back(std::make_pair( + std::vector(buf.begin(), + buf.begin() + data->transferredbytes), + 0)); + } + DWORD recvbytes = 0, flags = 0; + ::WSARecv(data->socket->sock, &data->wsabuf, 1, &recvbytes, &flags, + &data->overlapped, NULL); + } else { // WRITE 시, 무시한다. + delete data; + } + IOCPThread->enqueueJob( + [this](utils::ThreadPool* th, std::uint8_t __) { iocpWatcher_(th); }, + 0); } +#elif __linux__ + +#endif + + std::shared_ptr> GetSendQueue_(SOCKET sock); + std::shared_ptr, std::uint32_t>>> + GetRecvQueue_(SOCKET sock); + std::shared_ptr GetSendQueueMutex_(SOCKET sock); + std::shared_ptr GetRecvQueueMutex_(SOCKET sock); + + void packet_sender_(SOCKET sock); struct WSAManager* wsaManager = WSAManager::GetInstance(); utils::ThreadPool* IOCPThread_; @@ -182,7 +204,8 @@ class IOCP { SessionProtocol proto_; - // 밑의 unordered_map들에 키를 추가/제거 하려는 스레드는 이 뮤텍스를 잡아야 함. + // 밑의 unordered_map들에 키를 추가/제거 하려는 스레드는 이 뮤텍스를 잡아야 + // 함. std::mutex socket_mod_mutex_; // 각 소켓별 뮤텍스. 다른 스레드가 읽는 중이라면 수신 순서 보장을 위해 다른 @@ -192,11 +215,12 @@ class IOCP { // 채워야 함. (항시 socket에 대한 큐에 대해 읽기 시도가 행해질 수 있어야 // 한다는 뜻임) EPOLLIN에 대해서는 ONESHOT으로 등록해 놓고 읽는 도중에 버퍼에 // 새 값이 채워질 수 있으므로 읽기가 끝나고 나서 재등록한다 - std::unordered_map recv_queue_mutex_; + std::unordered_map> recv_queue_mutex_; // 각 소켓별 패킷, int는 그 vector의 시작 인덱스(vector의 끝까지 다 읽었으면 // 그 vector는 list에서 삭제되어야 하며, 데이터는 평문으로 변환하여 저장한다) - std::unordered_map, std::uint32_t>>> + std::unordered_map< + SOCKET, + std::shared_ptr, std::uint32_t>>>> recv_queue_; // 각 소켓별 뮤텍스. 다른 스레드가 쓰는 중이라면 송신 순서 보장을 위해 다른 @@ -206,13 +230,15 @@ class IOCP { // 윈도우에서는 송신을 iocp가 대행하기 때문에 이 큐가 필요 없는 것처럼 느껴질 // 수 있으나, 송신 순서를 보장하기 위해 WSASend를 한 스레드가 연속해서 // 호출해야만 하는데 이는 한번 쓰기 호출 시에 송신 중 데이터를 추가하려는 - // 스레드가 데이터를 추가하고 미전송된 경우를 대비하여 최대 1개까지의 스레드가 - // 대기하도록 한다. 리눅스에서도 send_queue에 데이터를 쌓고 최대 1개까지의 - // 스레드가 대기하도록 한다. - std::unordered_map pending_try_empty_; - std::unordered_map sending_; - std::unordered_map send_queue_mutex_; - std::unordered_map> send_queue_; + // 스레드가 데이터를 추가하고 미전송된 경우를 대비하여 스레드가 대기하도록 + // 한다. 리눅스에서도 send_queue에 데이터를 쌓고 스레드가 대기하도록 한다. + std::unordered_map> send_queue_mutex_; + std::unordered_map>> + send_queue_; + // 쓰기 싫었지만 쓰기 큐를 직렬화 하는 것 밖에 좋은 수가 생각이 안 남.. + /*std::mutex send_queue_mutex_; + std::condition_variable cv_send_queue_; + std::list send_queue_;*/ #ifdef _WIN32 HANDLE completionPort_ = INVALID_HANDLE_VALUE; diff --git a/include/utils/thread_pool.h b/include/utils/thread_pool.h index 44e4bb0..37f2d7e 100644 --- a/include/utils/thread_pool.h +++ b/include/utils/thread_pool.h @@ -28,7 +28,7 @@ class ThreadPool { std::invoke_result_t<_Callable, _Args...>& retVal, _Args&&... __args) { if (terminate_) { - spdlog::error("Cannot run jobs on threads that terminating..."); + spdlog::warn("Cannot run jobs on threads that terminating..."); return -1; } @@ -48,7 +48,7 @@ class ThreadPool { std::invoke_result_t<_Callable, ThreadPool*, _Args...>> int enqueueJob(_Callable&& __job, _Args&&... __args) { if (terminate_) { - spdlog::error("Cannot run jobs on threads that terminating..."); + spdlog::warn("Cannot run jobs on threads that terminating..."); return -1; }