Files
Np_Term/impl/socket/iocp.cpp

172 lines
4.8 KiB
C++

#include "socket/iocp.h"
#include "utils/thread_pool.h"
namespace Network {
IOCP::IOCP() {
  // Seed the Mersenne Twister from the random_device member, then configure
  // the jitter distribution to draw integers uniformly from [-10, 10].
  const auto seed = rd_();
  gen_ = std::mt19937(seed);
  jitterDist_ = std::uniform_int_distribution<int>{-10, 10};
}
IOCP::~IOCP() {
  // All teardown logic is centralized in destruct().
  this->destruct();
}
void IOCP::destruct() {
  // Currently a no-op on every platform.
  // NOTE(review): completionPort_ is not closed here — confirm whether this
  // class owns that handle and should release it.
#if defined(__linux__)
  // Reserved for Linux-specific teardown; nothing is implemented yet.
#endif
}
/// Associates data->socket with the completion port and posts the first
/// overlapped WSARecv for it (Windows only; a no-op elsewhere).
///
/// If completionPort_ is still null, CreateIoCompletionPort creates a new
/// port and that handle is stored; otherwise the existing port is reused.
/// On failure the current completionPort_ is left untouched and no receive
/// is posted.
///
/// @param data  source of the socket to register and of the receive buffer
///              size (data->bufsize); only read here. A separate
///              IOCPPASSINDATA is heap-allocated for the pending receive.
void IOCP::registerSocket(IOCPPASSINDATA* data) {
#ifdef _WIN32
  HANDLE port = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(data->socket->sock), completionPort_,
      data->socket->sock, 0);
  if (port == nullptr) {
    // Do not clobber an existing, valid port handle with NULL.
    spdlog::error("CreateIoCompletionPort failed: {}", ::GetLastError());
    return;
  }
  // Either the existing port (unchanged) or a freshly created one that must
  // be kept so later registrations and the worker can use it.
  completionPort_ = port;

  IOCPPASSINDATA* recv_data = new IOCPPASSINDATA(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;
  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result =
      ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1, &recvbytes,
                &flags, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      spdlog::error("WSARecv failed: {}", err);
      // Any error other than WSA_IO_PENDING means the overlapped receive was
      // never started, so no completion will hand recv_data back to be
      // freed; release it here instead of leaking it.
      delete recv_data;
      return;
    }
  }
#endif
}
/// Copies buffered receive data for data->socket into data->wsabuf.
///
/// Runs under the socket's receive-queue mutex and copies at most
/// data->wsabuf.len bytes. Each queue entry is a (chunk, offset) pair, where
/// offset counts bytes of the chunk already consumed. Fully consumed chunks
/// are popped; a partially consumed chunk stays at the front with its offset
/// advanced.
///
/// @return number of bytes copied (0 when nothing is queued).
int IOCP::recv(IOCPPASSINDATA* data) {
  const SOCKET sock = data->socket->sock;
  // Hold our own reference to the mutex for the whole critical section.
  const auto mtx = GetRecvQueueMutex_(sock);
  std::lock_guard lock(*mtx);
  const auto queue = GetRecvQueue_(sock);

  std::uint32_t left_data = data->wsabuf.len;
  std::uint32_t copied = 0;
  while (!queue->empty() && left_data != 0) {
    // Work on the front entry in place. Copying it (as before) duplicated the
    // whole chunk buffer on every iteration, and again when re-queued.
    auto& [chunk, offset] = queue->front();
    const std::uint32_t available =
        static_cast<std::uint32_t>(chunk.size()) - offset;
    const std::uint32_t to_copy =
        (left_data < available) ? left_data : available;
    ::memcpy(data->wsabuf.buf + copied, chunk.data() + offset, to_copy);
    copied += to_copy;
    left_data -= to_copy;
    offset += to_copy;
    if (offset < chunk.size()) {
      // Caller's buffer is full; the rest of this chunk waits for next call.
      break;
    }
    queue->pop_front();
  }
  return static_cast<int>(copied);
}
/// Queues a batch of outgoing buffers for `sock` and schedules a flush.
///
/// Every entry is tagged IOCPEVENT::WRITE and appended, in order, to the
/// socket's send queue under its send-queue mutex. A job that runs
/// packet_sender_(sock) is then posted to IOCPThread_.
///
/// @param sock  destination socket.
/// @param data  buffers to send; the pointers are stored in the queue, the
///              vector itself is not retained.
/// @return always 0.
int IOCP::send(SOCKET sock, std::vector<IOCPPASSINDATA*>* data) {
  {
    const auto mtx = GetSendQueueMutex_(sock);
    const auto queue = GetSendQueue_(sock);
    std::lock_guard lock(*mtx);
    for (IOCPPASSINDATA* item : *data) {
      item->event = IOCPEVENT::WRITE;
      queue->push_back(item);
    }
  }
  // Scheduled after releasing the send-queue mutex. packet_sender_ takes the
  // same (non-recursive) mutex, so holding it here would only stall the
  // worker, or deadlock if the pool ever ran the job inline.
  IOCPThread_->enqueueJob(
      [this, sock](utils::ThreadPool* /*pool*/, std::uint8_t /*unused*/) {
        packet_sender_(sock);
      },
      0);
  return 0;
}
/// Returns the send queue for `sock`, creating an empty one on first use.
/// Access to the map is serialized by socket_mod_mutex_.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  // One lookup, replacing find + two operator[] calls. A slot is inserted
  // only when `sock` is new.
  auto [it, inserted] = send_queue_.try_emplace(sock);
  if (inserted) {
    it->second = std::make_shared<std::list<IOCPPASSINDATA*>>();
  }
  return it->second;
}
/// Returns the receive queue for `sock`, creating an empty one on first use.
/// Each entry is a (chunk, consumed-offset) pair. Access to the map is
/// serialized by socket_mod_mutex_.
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  // One lookup, replacing find + two operator[] calls. A slot is inserted
  // only when `sock` is new.
  auto [it, inserted] = recv_queue_.try_emplace(sock);
  if (inserted) {
    it->second = std::make_shared<
        std::list<std::pair<std::vector<char>, std::uint32_t>>>();
  }
  return it->second;
}
/// Returns the mutex guarding `sock`'s send queue, creating it on first use.
/// Access to the map is serialized by socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  // One lookup, replacing find + two operator[] calls.
  auto [it, inserted] = send_queue_mutex_.try_emplace(sock);
  if (inserted) {
    it->second = std::make_shared<std::mutex>();
  }
  return it->second;
}
/// Returns the mutex guarding `sock`'s receive queue, creating it on first
/// use. Access to the map is serialized by socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::lock_guard lock(socket_mod_mutex_);
  // One lookup, replacing find + two operator[] calls.
  auto [it, inserted] = recv_queue_mutex_.try_emplace(sock);
  if (inserted) {
    it->second = std::make_shared<std::mutex>();
  }
  return it->second;
}
/// Drains the send queue for `sock`, writing each entry with WSASend.
///
/// Holds the socket's send-queue mutex for the whole drain.
/// - TLS/QUIC sessions: each entry is passed to SSL_write. The resulting
///   bytes are then read out of the SSL write BIO in chunks of up to 16384
///   bytes and sent.
/// - Other protocols: the entry's buffer is sent directly.
///
/// WSASend is called with a null OVERLAPPED and no completion routine, i.e.
/// as a non-overlapped send.
///
/// If SSL_write reports WANT_READ/WANT_WRITE, the entry is put back at the
/// front and draining stops. Any other SSL_write failure erases the socket's
/// entry from send_queue_ and stops.
void IOCP::packet_sender_(SOCKET sock) {
  const auto queue = GetSendQueue_(sock);
  // Hold our own reference to the mutex for the whole drain.
  const auto mtx = GetSendQueueMutex_(sock);
  std::unique_lock lock(*mtx);

  std::vector<char> buf(16384);
  DWORD sendbytes = 0;
  // Sends one contiguous buffer. Failures used to be ignored silently; they
  // are now at least logged.
  auto send_chunk = [sock, &sendbytes](char* bytes, auto len) {
    WSABUF wsabuf;
    wsabuf.buf = bytes;
    wsabuf.len = len;
    if (::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr) != 0) {
      spdlog::error("WSASend failed on socket {}", sock);
    }
  };

  while (!queue->empty()) {
    IOCPPASSINDATA* front = queue->front();
    queue->pop_front();
    front->event = IOCPEVENT::WRITE;
    if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
      auto* ssl = front->ssl.get();
      const int ret = ::SSL_write(ssl, front->wsabuf.buf, front->wsabuf.len);
      if (ret <= 0) {
        const int err = ::SSL_get_error(ssl, ret);
        if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
          // Not fatal: retry this entry first on the next drain.
          queue->push_front(front);
          break;
        }
        spdlog::error("SSL_write failed on socket {}: {}", sock, err);
        // NOTE(review): `front` and any entries still in `queue` are
        // discarded unsent here. This function never deletes queued entries,
        // so confirm that their owner releases them.
        std::unique_lock lk(socket_mod_mutex_);
        send_queue_.erase(sock);
        break;
      }
      int data_len = 0;
      while ((data_len = ::BIO_read(::SSL_get_wbio(ssl), buf.data(),
                                    static_cast<int>(buf.size()))) > 0) {
        send_chunk(buf.data(), data_len);
      }
    } else {
      send_chunk(front->wsabuf.buf, front->wsabuf.len);
    }
  }
}
} // namespace Network