#include "socket/iocp.h"

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <format>
#include <list>
#include <memory>
#include <mutex>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "utils/thread_pool.h"

// NOTE(review): the container / mutex template arguments spelled out in this
// file (std::vector<char>, std::list<IOCPPASSINDATA*>, std::mutex, ...) must
// match the declarations in socket/iocp.h -- confirm against the header.

namespace Network {

namespace {

/// Returns `map[key]`, creating a default-constructed shared object on first
/// use. The caller is responsible for holding whatever lock guards `map`.
template <typename Map, typename Key>
typename Map::mapped_type getOrCreate(Map& map, const Key& key) {
  auto& slot = map[key];
  if (!slot) {
    slot = std::make_shared<typename Map::mapped_type::element_type>();
  }
  return slot;
}

#ifdef _WIN32
/// Associates `sock` with the completion port `port`, using the socket value
/// itself as the completion key.
/// @throws std::runtime_error when the association fails.
void associateWithPort(HANDLE port, SOCKET sock) {
  const HANDLE associated = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(sock), port, sock, 0);
  if (associated == nullptr) {
    throw std::runtime_error(
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError()));
  }
}
#endif

}  // namespace

/// Seeds the random generator used to jitter the watcher's wait timeout.
IOCP::IOCP() {
  gen_ = std::mt19937(rd_());
  jitterDist_ = std::uniform_int_distribution<int>(-10, 10);
}

IOCP::~IOCP() { destruct(); }

/// Stores the worker pool and session protocol and, on Windows, creates the
/// completion port, doubles the pool size and starts one watcher job per
/// original worker thread. Terminates the process if the port cannot be
/// created.
/// @param __IOCPThread thread pool that runs the watcher and sender jobs.
/// @param proto        protocol used to decide whether payloads pass through
///                     OpenSSL.
void IOCP::init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) {
  IOCPThread_ = __IOCPThread;
  proto_ = proto;
#ifdef _WIN32
  completionPort_ =
      ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
  if (completionPort_ == nullptr) {
    spdlog::critical("CreateIoCompletionPort() failed: {}", ::GetLastError());
    std::exit(EXIT_FAILURE);
  }

  const int tCount = __IOCPThread->threadCount;
  spdlog::info("Resizing threadpool size to: {}", tCount * 2);
  __IOCPThread->respawnWorker(tCount * 2);

  for (int i = 0; i < tCount; ++i) {
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
  }
#endif
}

/// Releases resources held by this object. Currently a no-op.
/// NOTE(review): the completion port created in init() is not closed here;
/// confirm whether that is intentional.
void IOCP::destruct() {
#ifdef __linux__
#endif
}

/// Associates a TCP socket with the completion port and posts the first
/// overlapped receive for it.
/// @param data supplies the socket and the receive buffer size; it is not
///             retained.
/// @throws std::runtime_error if the association or the initial WSARecv fails.
void IOCP::registerTCPSocket(IOCPPASSINDATA* data) {
#ifdef _WIN32
  associateWithPort(completionPort_, data->socket->sock);

  auto recv_data = std::make_unique<IOCPPASSINDATA>(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;

  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result =
      ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1, &recvbytes,
                &flags, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      throw std::runtime_error(std::format("WSARecv failed: {}", err));
    }
  }
  // The posted operation now owns the context; iocpWatcher_ frees it once the
  // completion is dequeued.
  recv_data.release();
#endif
}

/// Associates a UDP socket with the completion port and posts the first
/// overlapped receive-from for it.
/// @param data      supplies the socket and the receive buffer size; it is
///                  not retained.
/// @param recv_addr storage that receives the sender address.
/// @throws std::runtime_error if the association or the initial WSARecvFrom
///         fails.
void IOCP::registerUDPSocket(IOCPPASSINDATA* data, Address recv_addr) {
#ifdef _WIN32
  associateWithPort(completionPort_, data->socket->sock);

  auto recv_data = std::make_unique<IOCPPASSINDATA>(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;

  DWORD recvbytes = 0;
  DWORD flags = 0;
  // NOTE(review): recv_addr is a by-value parameter, so the address/length
  // pointers handed to the overlapped call stop being valid when this
  // function returns, while the operation may still be pending. Fixing that
  // needs storage that outlives the call (a signature or context change).
  const int result = ::WSARecvFrom(
      recv_data->socket->sock, &recv_data->wsabuf, 1, &recvbytes, &flags,
      &recv_addr.addr, &recv_addr.length, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      throw std::runtime_error(std::format("WSARecvFrom failed: {}", err));
    }
  }
  // Ownership passes to the pending operation (see registerTCPSocket).
  recv_data.release();
#endif
}

/// Copies already-received bytes for `data->socket` into `data->wsabuf`.
/// @param data destination buffer (`wsabuf.buf`, capacity `wsabuf.len`) and
///             the socket whose receive queue is drained.
/// @return number of bytes copied; 0 when the queue is empty.
int IOCP::recv(IOCPPASSINDATA* data) {
  // TODO (translated from the original author's note): the number of bytes
  // read here always appeared to be 100 -- cause not yet understood.
  const SOCKET sock = data->socket->sock;
  auto queue_mutex = GetRecvQueueMutex_(sock);
  auto queue = GetRecvQueue_(sock);
  std::lock_guard lock(*queue_mutex);

  std::uint32_t left_data = data->wsabuf.len;
  std::uint32_t copied = 0;
  while (!queue->empty() && left_data != 0) {
    // Each entry is (payload, read offset); consume it in place.
    auto& [bytes, offset] = queue->front();
    const auto available = static_cast<std::uint32_t>(bytes.size() - offset);
    const std::uint32_t to_copy =
        (left_data < available) ? left_data : available;

    ::memcpy(data->wsabuf.buf + copied, bytes.data() + offset, to_copy);
    copied += to_copy;
    left_data -= to_copy;
    offset += to_copy;

    if (offset < bytes.size()) {
      break;  // destination is full; the remainder stays at the front
    }
    queue->pop_front();
  }
  return static_cast<int>(copied);
}

/// Queues packets for `sock` and schedules a sender job on the thread pool.
/// @param sock destination socket.
/// @param data packets to send; each is tagged as a WRITE event and appended
///             to the socket's send queue.
/// @return always 0.
int IOCP::send(SOCKET sock, std::vector<IOCPPASSINDATA*>* data) {
  if (data != nullptr) {
    auto lk = GetSendQueueMutex_(sock);
    auto queue = GetSendQueue_(sock);
    std::lock_guard lock(*lk);
    for (auto& it : *data) {
      it->event = IOCPEVENT::WRITE;
      queue->push_back(it);
    }
  }
  IOCPThread_->enqueueJob(
      [this, sock](utils::ThreadPool*, std::uint8_t) { packet_sender_(sock); },
      0);
  return 0;
}

/// @return number of received chunks currently queued for `sock`.
int IOCP::GetRecvedPacketCount(SOCKET sock) {
  // Use the per-socket receive mutex: it is the lock the queue's writers
  // hold, and GetRecvQueue_ itself acquires socket_mod_mutex_, so holding
  // that one here would lock it twice on the same thread.
  auto queue_mutex = GetRecvQueueMutex_(sock);
  auto queue = GetRecvQueue_(sock);
  std::lock_guard lock(*queue_mutex);
  return static_cast<int>(queue->size());
}

/// One iteration of the completion-port watcher. Waits roughly one second for
/// a completion, stores received payload (decrypted for TLS/QUIC) in the
/// socket's receive queue, posts the next overlapped read, and re-enqueues
/// itself on `IOCPThread`.
/// @throws std::runtime_error on a non-retryable SSL_read failure.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  const auto requeue = [this, IOCPThread] {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
  };

  // init() enqueues several watcher jobs that share gen_ / jitterDist_, and
  // neither object may be used from two threads at once.
  int jitter = 0;
  {
    static std::mutex jitter_mutex;
    std::lock_guard guard(jitter_mutex);
    jitter = jitterDist_(gen_);
  }

  IOCPPASSINDATA* completed = nullptr;
  ULONG_PTR key = 0;  // completion key (the socket); the context is used instead
  DWORD cbTransfrred = 0;
  // NOTE(review): the cast assumes `overlapped` sits at offset 0 of
  // IOCPPASSINDATA -- confirm against the struct definition.
  const BOOL ok = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransfrred, &key,
      reinterpret_cast<LPOVERLAPPED*>(&completed), 1000 + jitter);

  if (completed == nullptr) {
    // Nothing was dequeued: either the wait timed out or the port failed.
    if (!ok) {
      const DWORD lasterror = ::GetLastError();
      if (lasterror != WAIT_TIMEOUT) {
        spdlog::error("GetQueuedCompletionStatus failed: {}", lasterror);
        return;  // the port is unusable; stop instead of spinning on it
      }
    }
    requeue();
    return;
  }

  // From here on the dequeued context is owned (and freed) by this function.
  std::unique_ptr<IOCPPASSINDATA> data(completed);

  if (!ok || cbTransfrred == 0) {
    // Failed operation or zero-byte completion: treat as a closed connection.
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(data->socket->remoteAddr));
    data.reset();
    requeue();
    return;
  }
  data->transferredbytes = cbTransfrred;

  if (data->event != IOCPEVENT::READ) {
    // Write completions carry nothing to process.
    data.reset();
    requeue();
    return;
  }

  const SOCKET sock = data->socket->sock;
  std::string fatal;  // non-empty when TLS decoding failed irrecoverably
  {
    auto queue_mutex = GetRecvQueueMutex_(sock);
    auto queue_list = GetRecvQueue_(sock);
    std::lock_guard lock(*queue_mutex);

    if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
      auto* const ssl = data->ssl.get();

      // Print and thereby clear stale entries on the OpenSSL error queue so
      // the SSL_get_error() call below reflects this operation only.
      ERR_print_errors_fp(stderr);
      const int fed = ::BIO_write(::SSL_get_rbio(ssl), data->wsabuf.buf,
                                  static_cast<int>(cbTransfrred));
      ERR_print_errors_fp(stderr);  // BIO_write itself may have queued errors
      spdlog::debug("BIO_write accepted {} of {} bytes", fed, cbTransfrred);

      std::vector<char> plain(16384);  // largest chunk SSL_read returns
      int red_data = 0;
      while ((red_data = ::SSL_read(ssl, plain.data(),
                                    static_cast<int>(plain.size()))) > 0) {
        queue_list->emplace_back(
            std::vector<char>(plain.begin(), plain.begin() + red_data), 0u);
      }

      if (red_data < 0) {
        const int ssl_error_code = ::SSL_get_error(ssl, red_data);
        // WANT_READ / WANT_WRITE only mean the buffered input is used up.
        const bool retryable = ssl_error_code == SSL_ERROR_WANT_READ ||
                               ssl_error_code == SSL_ERROR_WANT_WRITE;
        if (!retryable) {
          fatal = std::format("SSL_read failed with SSL_get_error: {}",
                              ssl_error_code);
          fprintf(stderr, "%s\n", fatal.c_str());
          if (ssl_error_code == SSL_ERROR_SSL) {
            // Pop and print every entry of the OpenSSL error stack.
            fprintf(stderr, "Detailed SSL_ERROR_SSL trace:\n");
            unsigned long err_peek = 0;
            while ((err_peek = ERR_get_error()) != 0) {
              char err_str[256];
              ERR_error_string_n(err_peek, err_str, sizeof(err_str));
              fprintf(stderr, "OpenSSL stack error: %s\n", err_str);
            }
          } else {
            // Other error classes (e.g. SYSCALL) may still have left entries.
            ERR_print_errors_fp(stderr);
          }
        }
      }
    } else {
      // Copy straight from the receive buffer; its size is data->bufsize and
      // may exceed any fixed-size scratch buffer.
      queue_list->emplace_back(
          std::vector<char>(data->wsabuf.buf, data->wsabuf.buf + cbTransfrred),
          0u);
    }
  }

  if (!fatal.empty()) {
    // Keep the watcher alive, then report the failure as before. No further
    // read is posted for this socket.
    data.reset();
    requeue();
    throw std::runtime_error(fatal);
  }

  // Post the next overlapped read with a fresh context.
  // NOTE(review): data->ssl is not carried over to the new context here;
  // confirm where TLS sessions get attached to receive contexts.
  auto recv_data = std::make_unique<IOCPPASSINDATA>(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;
  data.reset();

  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result =
      ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1, &recvbytes,
                &flags, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR && ::WSAGetLastError() != WSA_IO_PENDING) {
    // No completion will arrive for this context; let unique_ptr free it.
    spdlog::error("WSARecv failed: {}", ::WSAGetLastError());
  } else {
    recv_data.release();  // owned by the pending operation until dequeued
  }

  requeue();
}

/// @return the send queue of `sock`, created on first use.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreate(send_queue_, sock);
}

/// @return the receive queue of `sock` (payload, read-offset pairs), created
///         on first use.
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreate(recv_queue_, sock);
}

/// @return the mutex guarding the send queue of `sock`, created on first use.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreate(send_queue_mutex_, sock);
}

/// @return the mutex guarding the receive queue of `sock`, created on first
///         use.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  return getOrCreate(recv_queue_mutex_, sock);
}

/// Drains the send queue of `sock`, encrypting through OpenSSL for TLS/QUIC
/// sessions, and writes the bytes with WSASend.
/// Stops early when SSL_write asks to be retried (the packet is put back at
/// the front) or fails (the socket's send queue is removed from the map).
/// NOTE(review): sent packets are not freed here; confirm who owns them.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue_(sock);
  auto queue_mutex = GetSendQueueMutex_(sock);
  std::unique_lock lock(*queue_mutex);

  // Sends one buffer; failures are logged and the queue keeps draining.
  const auto transmit = [sock](char* bytes, ULONG length) {
    WSABUF wsabuf;
    wsabuf.buf = bytes;
    wsabuf.len = length;
    DWORD sendbytes = 0;
    if (::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr) ==
        SOCKET_ERROR) {
      spdlog::error("WSASend failed: {}", ::WSAGetLastError());
    }
  };

  const bool encrypted =
      proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC;
  std::vector<char> cipher(16384);

  while (!queue->empty()) {
    auto front = queue->front();
    queue->pop_front();
    front->event = IOCPEVENT::WRITE;

    if (!encrypted) {
      transmit(front->wsabuf.buf, front->wsabuf.len);
      continue;
    }

    auto* const ssl = front->ssl.get();
    const int ret = ::SSL_write(ssl, front->wsabuf.buf, front->wsabuf.len);
    if (ret <= 0) {
      const int err = ::SSL_get_error(ssl, ret);
      if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
        queue->push_front(front);  // retry on the next sender run
        break;
      }
      std::unique_lock lk(socket_mod_mutex_);
      send_queue_.erase(sock);
      break;
    }

    // Forward everything SSL_write produced in the write BIO.
    int data_len = 0;
    while ((data_len = ::BIO_read(::SSL_get_wbio(ssl), cipher.data(),
                                  static_cast<int>(cipher.size()))) > 0) {
      transmit(cipher.data(), static_cast<ULONG>(data_len));
    }
  }
}

}  // namespace Network