// Np_Term/impl/socket/iocp.cpp
// (Source-viewer listing header: 347 lines, 11 KiB, C++.)
#include "socket/iocp.h"
#include "utils/thread_pool.h"
namespace Network {
/// Prepares the random source whose draws jitter the watcher's
/// GetQueuedCompletionStatus timeout.
IOCP::IOCP() {
  // Offsets in [-10, 10] are added to the watcher's 1000 ms wait.
  jitterDist_ = std::uniform_int_distribution<int>{-10, 10};
  // Seed the engine once from rd_.
  gen_ = std::mt19937{rd_()};
}
/// Delegates all teardown to destruct().
IOCP::~IOCP() {
  destruct();
}
/// Binds the IOCP to a worker pool and session protocol and, on Windows,
/// creates the completion port and starts the watcher jobs.
///
/// The pool is resized to twice its current thread count, and one
/// iocpWatcher_ job is enqueued per original thread. Presumably the other
/// half of the workers is left for jobs such as send()'s packet senders.
/// NOTE(review): confirm that this split is intentional.
///
/// @param __IOCPThread Pool that runs the watchers. It is stored and must
///        outlive this object.
/// @param proto Session protocol used to decide whether data goes through
///        OpenSSL.
/// Terminates the process (EXIT_FAILURE) if the completion port cannot be
/// created.
void IOCP::init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) {
  proto_ = proto;
  IOCPThread_ = __IOCPThread;
#ifdef _WIN32
  completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
  if (!completionPort_) {
    spdlog::critical("CreateIoCompletionPort()");
    std::exit(EXIT_FAILURE);
  }

  const int watcherCount = IOCPThread_->threadCount;
  const int poolSize = watcherCount * 2;
  spdlog::info("Resizing threadpool size to: {}", poolSize);
  IOCPThread_->respawnWorker(poolSize);

  for (int n = 0; n < watcherCount; ++n) {
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* pool, std::uint8_t /*unused*/) {
          iocpWatcher_(pool);
        },
        0);
  }
#endif
}
/// Teardown hook invoked by the destructor. It currently does nothing on any
/// platform.
/// NOTE(review): on Windows the completion port created in init() is never
/// closed here. Closing it safely would first require stopping the watcher
/// jobs, which keep re-enqueueing themselves on the pool.
void IOCP::destruct() {}
/// Associates a TCP socket with the completion port and posts its first
/// overlapped read.
///
/// The socket handle is used as the completion key. The posted
/// IOCPPASSINDATA is owned by the pending operation and is released by
/// iocpWatcher_ when the read completes.
///
/// @param sock Connected socket. A copy is kept in the read request.
/// @param bufsize Size of the receive buffer for each overlapped read.
/// @throws std::runtime_error if the association fails or if WSARecv fails
///         with anything other than WSA_IO_PENDING.
void IOCP::registerTCPSocket(Socket& sock, std::uint32_t bufsize) {
#ifdef _WIN32
  HANDLE returnData = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(sock.sock), completionPort_, sock.sock, 0);
  if (returnData == NULL) {
    // Previously this overwrote completionPort_ with NULL, breaking every
    // other socket. Report the failure instead.
    auto err_msg =
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError());
    throw std::runtime_error(err_msg);
  }

  auto recv_data = std::make_unique<IOCPPASSINDATA>(bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = std::make_shared<Socket>(sock);
  recv_data->IOCPInstance = this;

  DWORD recvbytes = 0, flags = 0;
  int result = ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1,
                         &recvbytes, &flags, &recv_data->overlapped, NULL);
  if (result == SOCKET_ERROR) {
    int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // The read was never queued, so no completion will free recv_data;
      // the unique_ptr releases it as the exception propagates.
      auto err_msg = std::format("WSARecv failed: {}", err);
      throw std::runtime_error(err_msg);
    }
  }
  // The read is queued (pending or completed immediately), so ownership
  // passes to the completion handled by iocpWatcher_.
  static_cast<void>(recv_data.release());
#endif
}
/// Associates a UDP socket with the completion port and posts its first
/// overlapped read.
///
/// @param data Template request that supplies the socket and buffer size.
///        It is not modified or freed here.
/// @param recv_addr Currently unused (see the note below).
/// @throws std::runtime_error if the association fails or if WSARecvFrom
///         fails with anything other than WSA_IO_PENDING.
void IOCP::registerUDPSocket(IOCPPASSINDATA* data,
                             [[maybe_unused]] Address recv_addr) {
#ifdef _WIN32
  HANDLE returnData = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(data->socket->sock), completionPort_,
      data->socket->sock, 0);
  if (returnData == NULL) {
    // Previously this overwrote completionPort_ with NULL. Report the
    // failure instead.
    auto err_msg =
        std::format("CreateIoCompletionPort failed: {}", ::GetLastError());
    throw std::runtime_error(err_msg);
  }

  auto recv_data = std::make_unique<IOCPPASSINDATA>(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;
  recv_data->IOCPInstance = this;  // consistent with registerTCPSocket

  DWORD recvbytes = 0, flags = 0;
  // recv_addr is a by-value parameter. Passing its storage to an overlapped
  // WSARecvFrom would let the kernel write the sender address into this
  // stack frame after it has returned. The address was never observable by
  // anyone, so it is not requested.
  // NOTE(review): if the sender address is needed, store it in memory that
  // lives as long as the overlapped operation.
  int result = ::WSARecvFrom(recv_data->socket->sock, &recv_data->wsabuf, 1,
                             &recvbytes, &flags, nullptr, nullptr,
                             &recv_data->overlapped, NULL);
  if (result == SOCKET_ERROR) {
    int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // Never queued, so the unique_ptr frees recv_data while unwinding.
      auto err_msg = std::format("WSARecvFrom failed: {}", err);
      throw std::runtime_error(err_msg);
    }
  }
  // Ownership passes to the queued overlapped read.
  static_cast<void>(recv_data.release());
#endif
}
/// Copies up to data.size() buffered bytes received on @p sock into @p data.
///
/// Chunks are consumed from the front of the socket's receive queue in
/// order. When a chunk does not fit entirely, it stays at the front with its
/// read offset advanced, so the next call resumes where this one stopped.
///
/// @param sock Socket whose receive queue is drained.
/// @param data Destination. Bytes are written from index 0, and its size is
///        the maximum copied.
/// @return Number of bytes written into @p data (0 if nothing was buffered).
int IOCP::recv(Socket& sock, std::vector<char>& data) {
  std::lock_guard lock(*GetRecvQueueMutex_(sock.sock));
  auto queue = GetRecvQueue_(sock.sock);
  const std::size_t capacity = data.size();
  std::size_t copied = 0;
  while (!queue->empty() && copied < capacity) {
    // Work on the front chunk in place instead of copying it out and
    // pushing it back.
    auto& [chunk, offset] = queue->front();
    const std::size_t available = chunk.size() - offset;
    const std::size_t room = capacity - copied;
    const std::size_t to_copy = (room < available) ? room : available;
    ::memcpy(data.data() + copied, chunk.data() + offset, to_copy);
    copied += to_copy;
    offset += static_cast<std::uint32_t>(to_copy);
    if (offset < chunk.size()) {
      break;  // caller's buffer is full; keep the remainder queued
    }
    queue->pop_front();
  }
  return static_cast<int>(copied);
}
/// Queues a copy of @p data for transmission on @p sock and schedules a
/// packet_sender_ job on the pool to drain that socket's send queue.
///
/// @param sock Destination socket. A copy is stored in the queued packet.
/// @param data Payload to send. It is copied into the packet buffer.
/// @return Always 0. The actual send happens asynchronously.
int IOCP::send(Socket& sock, std::vector<char>& data) {
  const SOCKET handle = sock.sock;
  auto sendMutex = GetSendQueueMutex_(handle);
  auto pending = GetSendQueue_(handle);
  std::lock_guard guard(*sendMutex);

  auto* packet = new Network::IOCPPASSINDATA(data.size());
  packet->IOCPInstance = this;
  packet->socket = std::make_shared<Network::Socket>(sock);
  packet->event = IOCPEVENT::WRITE;
  packet->wsabuf.len = data.size();
  ::memcpy(packet->wsabuf.buf, data.data(), data.size());
  pending->push_back(packet);

  IOCPThread_->enqueueJob(
      [this, handle](utils::ThreadPool* /*pool*/, std::uint8_t /*unused*/) {
        packet_sender_(handle);
      },
      0);
  return 0;
}
/// Returns how many received bytes are buffered and not yet consumed by
/// recv() for @p sock.
///
/// The queue is read under the same per-socket mutex that recv() and
/// iocpWatcher_ hold while mutating it. The old code locked
/// socket_mod_mutex_, which does not guard the list contents.
int IOCP::GetRecvedBytes(SOCKET sock) {
  auto queue = GetRecvQueue_(sock);
  std::lock_guard lock(*GetRecvQueueMutex_(sock));
  int bytes = 0;
  for (const auto& [chunk, offset] : *queue) {  // by reference: no chunk copies
    bytes += static_cast<int>(chunk.size() - offset);
  }
  return bytes;
}
/// One iteration of a completion-port watcher job.
///
/// Waits about one second (1000 ms plus a jitter in [-10, 10]) for a single
/// completion packet, handles it, and re-enqueues itself on @p IOCPThread so
/// polling continues.
///  - READ completion: the bytes are appended to the socket's receive queue,
///    passed through SSL_read first for TLS/QUIC. A new overlapped WSARecv
///    is then posted for the same socket.
///  - WRITE completion: discarded.
///  - Failed or zero-byte completion: logged as a disconnect, and the socket
///    is not re-armed.
///  - Port failure without a packet (e.g. the port was closed): this watcher
///    retires instead of re-enqueueing.
///
/// @param IOCPThread Pool running this watcher. The next iteration is queued
///        on it.
/// @throws std::runtime_error if SSL_read reports a fatal TLS error. The
///         completion is freed and the watcher re-enqueued before throwing.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  // Schedules the next iteration of this watcher.
  auto requeue = [this, IOCPThread] {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t /*unused*/) {
          iocpWatcher_(th);
        },
        0);
  };

  int jitter = 0;
  {
    // gen_ and jitterDist_ are shared by every watcher thread, and advancing
    // a std::mt19937 concurrently is a data race.
    static std::mutex jitter_mutex;
    std::lock_guard jitter_lock(jitter_mutex);
    jitter = jitterDist_(gen_);
  }

  IOCPPASSINDATA* data = nullptr;
  SOCKET sock = 0;
  DWORD cbTransfrred = 0;
  int retVal = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransfrred, reinterpret_cast<PULONG_PTR>(&sock),
      reinterpret_cast<LPOVERLAPPED*>(&data), 1000 + jitter);

  if (data == nullptr) {
    // No I/O completion was dequeued, so sock and cbTransfrred are not valid.
    const DWORD lasterror = retVal ? 0 : ::GetLastError();
    if (lasterror == 0 || lasterror == WAIT_TIMEOUT) {
      requeue();
    } else {
      // The port itself is unusable (e.g. closed). Retire this watcher
      // instead of dereferencing a null completion or spinning.
      spdlog::error("GetQueuedCompletionStatus failed: {}", lasterror);
    }
    return;
  }

  // This watcher now owns the completed request; it is freed on every path.
  std::unique_ptr<IOCPPASSINDATA> completed(data);

  if (retVal == 0 || cbTransfrred == 0) {
    // The I/O failed, or a zero-byte read signalled that the peer closed.
    completed->event = IOCPEVENT::QUIT;
    spdlog::debug("Disconnected. [{}]",
                  (std::string)(completed->socket->remoteAddr));
    requeue();
    return;
  }
  completed->transferredbytes = cbTransfrred;

  if (completed->event != IOCPEVENT::READ) {
    // WRITE completions carry nothing to deliver.
    requeue();
    return;
  }

  std::lock_guard lock(*GetRecvQueueMutex_(sock));
  auto queue_list = GetRecvQueue_(sock);
  if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
    // DEPRECATED: the OpenSSL integration is unfinished and should be
    // completed before this path is relied on.
    std::vector<char> buf(16384);  // maximum size a single SSL_read returns
    // DEBUG: print (and thereby clear) OpenSSL errors left over from earlier
    // calls, so SSL_get_error below reflects only this completion.
    ERR_print_errors_fp(stderr);
    fprintf(stderr, "--- Before BIO_write ---\n");
    ::BIO_write(::SSL_get_rbio(completed->ssl.get()), completed->wsabuf.buf,
                cbTransfrred);
    // DEBUG: BIO_write itself can also push errors onto the stack.
    ERR_print_errors_fp(stderr);
    fprintf(stderr, "--- After BIO_write, cbTransfrred: %lu ---\n",
            cbTransfrred);

    int red_data = 0;
    while ((red_data = ::SSL_read(completed->ssl.get(), buf.data(),
                                  static_cast<int>(buf.size()))) > 0) {
      queue_list->emplace_back(
          std::vector<char>(buf.begin(), buf.begin() + red_data), 0u);
    }
    if (red_data < 0) {
      const int ssl_error_code = SSL_get_error(completed->ssl.get(), red_data);
      // WANT_READ / WANT_WRITE only mean the memory BIO has been drained.
      // More ciphertext arrives with a later completion, so this is not an
      // error.
      if (ssl_error_code != SSL_ERROR_WANT_READ &&
          ssl_error_code != SSL_ERROR_WANT_WRITE) {
        auto err_msg = std::format("SSL_read failed with SSL_get_error: {}",
                                   ssl_error_code);
        fprintf(stderr, "%s\n", err_msg.c_str());
        if (ssl_error_code == SSL_ERROR_SSL) {
          // Pop and print every entry on the OpenSSL error stack.
          fprintf(stderr, "Detailed SSL_ERROR_SSL trace:\n");
          unsigned long err_peek;
          while ((err_peek = ERR_get_error()) != 0) {
            char err_str[256];
            ERR_error_string_n(err_peek, err_str, sizeof(err_str));
            fprintf(stderr, "OpenSSL stack error: %s\n", err_str);
          }
        } else {
          // SYSCALL and other errors: the generic dump may still help.
          ERR_print_errors_fp(stderr);
        }
        requeue();  // keep the watcher alive even though this job throws
        throw std::runtime_error(err_msg);
      }
    }
  } else {
    // Copy straight from the receive buffer. The old code first memcpy'd
    // into a fixed 16 KiB vector, which overflows when bufsize is larger.
    const char* first = completed->wsabuf.buf;
    queue_list->emplace_back(
        std::vector<char>(first, first + completed->transferredbytes), 0u);
  }

  // Post the next overlapped read for this socket.
  auto recv_data = std::make_unique<IOCPPASSINDATA>(completed->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = completed->socket;
  // The TLS session must follow the connection, or the next READ on a
  // TLS/QUIC socket would use an empty ssl handle.
  recv_data->ssl = std::move(completed->ssl);
  recv_data->IOCPInstance = this;
  completed.reset();

  DWORD recvbytes = 0, flags = 0;
  int result = ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1,
                         &recvbytes, &flags, &recv_data->overlapped, NULL);
  if (result == SOCKET_ERROR) {
    int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // Never queued, so no completion will free recv_data; the unique_ptr
      // does it here.
      spdlog::debug("WSARecv re-arm failed: {}", err);
      requeue();
      return;
    }
  }
  // Ownership passes to the queued overlapped read.
  static_cast<void>(recv_data.release());
  requeue();
}
/// Returns the send queue for @p sock, creating an empty one on first use.
/// The map lookup and insertion run under socket_mod_mutex_. The returned
/// list is guarded separately by GetSendQueueMutex_(sock).
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);
  auto found = send_queue_.find(sock);
  if (found == send_queue_.end()) {
    found = send_queue_
                .emplace(sock, std::make_shared<std::list<IOCPPASSINDATA*>>())
                .first;
  }
  return found->second;
}
/// Returns the receive queue for @p sock, creating an empty one on first use.
/// Each entry is a received chunk plus the offset of its first unread byte.
/// The map access runs under socket_mod_mutex_. The list itself is guarded by
/// GetRecvQueueMutex_(sock).
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue_(SOCKET sock) {
  using ChunkList = std::list<std::pair<std::vector<char>, std::uint32_t>>;
  std::lock_guard guard(socket_mod_mutex_);
  auto found = recv_queue_.find(sock);
  if (found == recv_queue_.end()) {
    found = recv_queue_.emplace(sock, std::make_shared<ChunkList>()).first;
  }
  return found->second;
}
/// Returns the mutex guarding @p sock's send queue, creating it on first use
/// under socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);
  auto found = send_queue_mutex_.find(sock);
  if (found == send_queue_mutex_.end()) {
    found = send_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return found->second;
}
/// Returns the mutex guarding @p sock's receive queue, creating it on first
/// use under socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);
  auto found = recv_queue_mutex_.find(sock);
  if (found == recv_queue_mutex_.end()) {
    found = recv_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return found->second;
}
/// Drains @p sock's send queue, transmitting each packet synchronously with
/// WSASend (no OVERLAPPED, so no completion packet is posted).
///
/// For TLS/QUIC, each payload is passed to SSL_write and the resulting
/// ciphertext is read from the write BIO and sent in chunks of up to 16 KiB.
/// On SSL_ERROR_WANT_READ/WANT_WRITE, the packet is put back at the front and
/// draining stops. On any other SSL_write failure, the remaining queued
/// packets are discarded and the socket's send queue is removed from the map.
///
/// Every packet is freed once sent or discarded. Previously, sent packets
/// were never deleted.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue_(sock);
  std::unique_lock lock(*GetSendQueueMutex_(sock));
  std::vector<char> buf(16384);
  WSABUF wsabuf;
  DWORD sendbytes = 0;
  while (!queue->empty()) {
    // This loop owns the packet once it leaves the queue.
    std::unique_ptr<IOCPPASSINDATA> front(queue->front());
    queue->pop_front();
    front->event = IOCPEVENT::WRITE;
    if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
      int ret =
          ::SSL_write(front->ssl.get(), front->wsabuf.buf, front->wsabuf.len);
      if (ret <= 0) {
        int err = ::SSL_get_error(front->ssl.get(), ret);
        if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
          // Retry this packet on a later run; hand it back to the queue.
          queue->push_front(front.release());
          break;
        }
        // Fatal TLS error: nothing queued for this socket can be sent, so
        // free it all instead of leaking it with the erased queue.
        for (IOCPPASSINDATA* pending : *queue) {
          delete pending;
        }
        queue->clear();
        std::unique_lock lk(socket_mod_mutex_);
        send_queue_.erase(sock);
        break;  // `front` is freed by its unique_ptr
      }
      int data_len = 0;
      while ((data_len = ::BIO_read(::SSL_get_wbio(front->ssl.get()),
                                    buf.data(), buf.size())) > 0) {
        wsabuf.buf = buf.data();
        wsabuf.len = data_len;
        ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr);
      }
    } else {
      wsabuf.buf = front->wsabuf.buf;
      wsabuf.len = front->wsabuf.len;
      ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr);
    }
    // The send was synchronous, so the buffer is no longer needed; `front`
    // is freed here.
  }
}
} // namespace Network