Files
Np_Term/impl/socket/iocp.cpp

271 lines
8.3 KiB
C++

#include "socket/iocp.h"

#include <exception>
#include <memory>

#include "utils/thread_pool.h"
namespace Network {
/// Builds an unattached IOCP object: no worker pool yet, protocol preset to
/// TCP, and the jitter generator seeded from rd_.
IOCP::IOCP() : IOCPThread_(nullptr), proto_(SessionProtocol::TCP) {
  // Draw a single seed from rd_ and use it to (re)construct the engine.
  const auto seed = rd_();
  gen_ = std::mt19937(seed);

  // iocpWatcher_ adds a value from this [-10, 10] range to its wait timeout.
  jitterDist_ = std::uniform_int_distribution<int>(-10, 10);
}
/// Destructor: all teardown is delegated to destruct().
IOCP::~IOCP() {
  destruct();
}
/// Attaches the worker pool and session protocol to this object.
///
/// On Windows this additionally creates the completion port, doubles the
/// pool's worker count through respawnWorker(), and queues one iocpWatcher_
/// job per worker the pool had before the resize. If the port cannot be
/// created the process is terminated with EXIT_FAILURE.
///
/// @param __IOCPThread  Pool that runs the watcher, callback and sender jobs.
/// @param proto         Protocol stored in proto_.
void IOCP::init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) {
  proto_ = proto;
  IOCPThread_ = __IOCPThread;
#ifdef _WIN32
  completionPort_ =
      ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
  if (completionPort_ == nullptr) {
    spdlog::critical("CreateIoCompletionPort()");
    std::exit(EXIT_FAILURE);
  }

  // Read the worker count before resizing: the number of watchers is based on
  // the original size, the pool itself grows to twice that.
  const int watcherCount = __IOCPThread->threadCount;
  const int resizedCount = watcherCount * 2;
  spdlog::info("Resizing threadpool size to: {}", resizedCount);
  __IOCPThread->respawnWorker(resizedCount);

  int queued = 0;
  while (queued < watcherCount) {
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* pool, std::uint8_t) { iocpWatcher_(pool); },
        0);
    ++queued;
  }
#endif
}
/// Teardown hook invoked by the destructor. Currently a no-op on every
/// platform.
///
/// NOTE(review): nothing in this file closes the completion port created by
/// init(); confirm whether that is released elsewhere.
void IOCP::destruct() {
  // Intentionally empty.
}
/// Associates a socket with this object's completion port (Windows only; a
/// no-op elsewhere).
///
/// The socket handle is also used as the completion key, so completions for
/// it report the SOCKET value back through GetQueuedCompletionStatus.
///
/// A failed association is logged and leaves completionPort_ untouched.
/// Previously a failure overwrote completionPort_ with the NULL return value,
/// which destroyed the only handle to the port and broke every other socket
/// already registered on it.
///
/// @param sock  Socket whose `sock` handle is attached to the port.
void IOCP::registerSocket(std::shared_ptr<Socket> sock) {
#ifdef _WIN32
  const HANDLE associated = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(sock->sock), completionPort_, sock->sock, 0);
  if (associated == nullptr) {
    spdlog::error("CreateIoCompletionPort() failed to associate socket: {}",
                  ::GetLastError());
  }
#endif
}
std::future<std::vector<char>> IOCP::recvFull(std::shared_ptr<Socket> sock,
std::uint32_t bufsize) {
auto promise = std::make_shared<std::promise<std::vector<char>>>();
auto future = promise->get_future();
auto buffer = std::make_shared<std::vector<char>>();
buffer->reserve(bufsize);
std::function<void(std::uint32_t)> recvChunk;
recvChunk = [=](std::uint32_t remaining) mutable {
this->recv(sock, remaining,
[=](utils::ThreadPool* th, IOCPPASSINDATA* data) {
buffer->insert(buffer->end(), data->wsabuf.buf,
data->wsabuf.buf + data->transferredbytes);
std::uint32_t still_left =
bufsize - static_cast<std::uint32_t>(buffer->size());
if (still_left > 0) {
recvChunk(still_left);
} else {
promise->set_value(std::move(*buffer));
}
return std::list<char>();
});
};
recvChunk(bufsize);
return future;
}
/// Completion callback that keeps receiving until the packet's buffer length
/// has been covered.
///
/// Copies the bytes transferred for this completion; if that is fewer than
/// wsabuf.len, it posts another receive for the difference (with this same
/// function as callback) and blocks on its future before appending the result.
///
/// @param th    Pool invoking the callback (unused here).
/// @param data  Completed receive packet.
/// @return All bytes gathered by this completion and its follow-up receives.
std::list<char> DEFAULT_RECVALL_CALLBACK(utils::ThreadPool* th,
                                         IOCPPASSINDATA* data) {
  auto* const first = data->wsabuf.buf;
  std::list<char> collected(first, first + data->transferredbytes);

  // Buffer fully covered: nothing more to wait for.
  if (!(data->transferredbytes < data->wsabuf.len)) {
    return collected;
  }

  auto pending = data->IOCPInstance->recv(
      data->socket, data->wsabuf.len - data->transferredbytes,
      DEFAULT_RECVALL_CALLBACK);
  // Blocks the calling thread until the follow-up receive has completed.
  auto tail = pending.get();
  collected.splice(collected.end(), tail);
  return collected;
}
/// Posts one overlapped receive of up to `bufsize` bytes on `sock`.
///
/// The packet allocated here is handed to the completion port; iocpWatcher_
/// deletes it after running the callback. If WSARecv fails outright, no
/// completion will ever be queued, so the packet is released here before
/// throwing (it used to leak).
///
/// @param sock      Socket to read from.
/// @param bufsize   Size passed to the IOCPPASSINDATA constructor.
/// @param callback  Invoked with the completed packet; may be null.
/// @return Future for the callback's result. When `callback` is null the
///         returned future is default-constructed (not valid()).
/// @throws std::runtime_error if WSARecv fails with anything other than
///         WSA_IO_PENDING.
std::future<std::list<char>> IOCP::recv(
    std::shared_ptr<Socket> sock, std::uint32_t bufsize,
    std::function<std::list<char>(utils::ThreadPool*, IOCPPASSINDATA*)>
        callback) {
  using RecvTask =
      std::packaged_task<std::list<char>(utils::ThreadPool*, IOCPPASSINDATA*)>;

  // Keep our own reference to the mutex for as long as it is locked.
  const auto recvMutex = GetRecvQueueMutex(sock->sock);
  std::lock_guard lock(*recvMutex);
  // Result unused; the call makes sure the per-socket receive queue exists.
  (void)GetRecvQueue(sock->sock);

  std::future<std::list<char>> future;
  std::unique_ptr<Network::IOCPPASSINDATA> data;
  if (callback != nullptr) {
    RecvTask task(callback);
    future = task.get_future();
    data.reset(
        new Network::IOCPPASSINDATA(sock, bufsize, this, std::move(task)));
  } else {
    data.reset(new Network::IOCPPASSINDATA(sock, bufsize, this));
  }

  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result = ::WSARecv(sock->sock, &data->wsabuf, 1, &recvbytes,
                               &flags, &data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // `data` is still owned by the unique_ptr and is freed during unwinding.
      throw std::runtime_error(std::format("WSARecv failed: {}", err));
    }
  }

  // The operation is in flight (or already completed and queued): ownership
  // now belongs to the completion handler. The packet must not be touched
  // after this point because a watcher may delete it at any time.
  data.release();
  return future;
}
int IOCP::send(std::shared_ptr<Socket> sock, std::vector<char>& data) {
auto lk = GetSendQueueMutex(sock->sock);
auto queue = GetSendQueue(sock->sock);
std::lock_guard lock(*lk);
Network::IOCPPASSINDATA* packet = new Network::IOCPPASSINDATA(sock, data.size(), this);
packet->event = IOCPEVENT::WRITE;
::memcpy(packet->wsabuf.buf, data.data(), data.size());
packet->wsabuf.len = data.size();
queue->push_back(packet);
IOCPThread_->enqueueJob(
[this, sock = sock->sock](utils::ThreadPool* th, std::uint8_t __) {
packet_sender_(sock);
},
0);
return 0;
}
/// Sums, over every entry in the socket's receive queue, the entry's buffer
/// size minus its stored offset.
///
/// Entries are visited by const reference; the previous by-value loop copied
/// each buffered vector just to read its size.
///
/// @param sock  Socket whose receive queue is inspected.
/// @return The accumulated byte count.
int IOCP::GetRecvedBytes(SOCKET sock) {
  auto queue = GetRecvQueue(sock);
  std::lock_guard lock(socket_mod_mutex_);
  int bytes = 0;
  // `offset` presumably counts bytes already consumed from `chunk` - TODO
  // confirm against the code that fills the queue.
  for (const auto& [chunk, offset] : *queue) {
    bytes += chunk.size() - offset;
  }
  return bytes;
}
/// Waits (about one second, plus jitter) for one completion packet, hands it
/// to the pool for processing, and re-queues itself.
///
/// Outcomes:
///  - Timeout: nothing was dequeued; the watcher is simply re-queued.
///  - Nothing dequeued for any other reason: the port is unusable (or an
///    unexpected null packet arrived). This is logged and the watcher stops
///    instead of dereferencing a null packet.
///  - Failed completion or zero bytes transferred: treated as a disconnect.
///    The packet is marked QUIT, its callback runs, the socket is destructed
///    and the packet is deleted.
///  - Otherwise: transferredbytes is recorded, the callback runs and the
///    packet is deleted.
///
/// GetLastError() is consulted only when the wait itself failed. Reading it
/// after a successful zero-byte completion could see a stale WAIT_TIMEOUT and
/// misclassify a disconnect as a timeout, leaking the packet.
///
/// NOTE(review): gen_ and jitterDist_ are shared by every watcher job; if the
/// pool runs watchers concurrently they need synchronisation.
///
/// @param IOCPThread  Pool on which follow-up jobs are queued.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  const auto rearm = [this, IOCPThread]() {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
  };

  IOCPPASSINDATA* data = nullptr;
  SOCKET sock = INVALID_SOCKET;
  DWORD cbTransfrred = 0;
  const int jitter = jitterDist_(gen_);
  // The OVERLAPPED* written by the call is reinterpreted as the packet itself;
  // this relies on `overlapped` being at offset 0 of IOCPPASSINDATA - TODO
  // confirm.
  const BOOL dequeued = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransfrred, reinterpret_cast<PULONG_PTR>(&sock),
      reinterpret_cast<LPOVERLAPPED*>(&data), 1000 + jitter);

  if (data == nullptr) {
    const DWORD lasterror = dequeued ? ERROR_SUCCESS : ::GetLastError();
    if (!dequeued && lasterror == WAIT_TIMEOUT) {
      rearm();
      return;
    }
    spdlog::error(
        "GetQueuedCompletionStatus returned no packet (error {}); watcher "
        "stopping",
        lasterror);
    return;
  }

  const bool disconnected = !dequeued || cbTransfrred == 0;
  if (disconnected) {
    data->event = IOCPEVENT::QUIT;
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(data->socket->remoteAddr));
    IOCPThread->enqueueJob(
        [data](utils::ThreadPool* th, std::uint8_t) {
          if (data->callback.valid()) {
            data->callback(th, data);
          }
          data->socket->destruct();
          delete data;
        },
        0);
  } else {
    data->transferredbytes = cbTransfrred;
    IOCPThread->enqueueJob(
        [data](utils::ThreadPool* th, std::uint8_t) {
          if (data->callback.valid()) {
            data->callback(th, data);
          }
          delete data;
        },
        0);
  }
  rearm();
}
/// Returns the send queue registered for `sock`, creating an empty one on
/// first use. The lookup and insertion happen under socket_mod_mutex_.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue(SOCKET sock) {
  using Queue = std::list<IOCPPASSINDATA*>;
  std::lock_guard guard(socket_mod_mutex_);
  auto slot = send_queue_.find(sock);
  if (slot == send_queue_.end()) {
    slot = send_queue_.emplace(sock, std::make_shared<Queue>()).first;
  }
  return slot->second;
}
/// Returns the receive queue registered for `sock`, creating an empty one on
/// first use. The lookup and insertion happen under socket_mod_mutex_.
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
IOCP::GetRecvQueue(SOCKET sock) {
  using Queue = std::list<std::pair<std::vector<char>, std::uint32_t>>;
  std::lock_guard guard(socket_mod_mutex_);
  auto slot = recv_queue_.find(sock);
  if (slot == recv_queue_.end()) {
    slot = recv_queue_.emplace(sock, std::make_shared<Queue>()).first;
  }
  return slot->second;
}
/// Returns the mutex guarding `sock`'s send queue, creating it on first use.
/// The lookup and insertion happen under socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);
  auto slot = send_queue_mutex_.find(sock);
  if (slot == send_queue_mutex_.end()) {
    slot = send_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return slot->second;
}
/// Returns the mutex guarding `sock`'s receive path, creating it on first use.
/// The lookup and insertion happen under socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);
  auto slot = recv_queue_mutex_.find(sock);
  if (slot == recv_queue_mutex_.end()) {
    slot = recv_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return slot->second;
}
/// Drains the send queue of `sock`, transmitting each queued packet with a
/// synchronous (non-overlapped) WSASend.
///
/// The queue mutex is held for the whole drain so packets from concurrent
/// sender jobs cannot interleave. Each packet is owned by a unique_ptr once it
/// leaves the queue and is freed after its send attempt; previously the
/// packets allocated by send() were never deleted. Short sends are continued
/// until the packet is fully written, and a failing WSASend is logged and
/// abandons that packet.
///
/// @param sock  Socket whose queued packets are sent.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue(sock);
  // Keep our own reference to the mutex for as long as it is locked.
  const auto sendMutex = GetSendQueueMutex(sock);
  std::unique_lock lock(*sendMutex);

  while (!queue->empty()) {
    std::unique_ptr<IOCPPASSINDATA> packet(queue->front());
    queue->pop_front();
    packet->event = IOCPEVENT::WRITE;

    // Work on a copy so the packet's own buffer pointer stays intact for its
    // destructor while we advance past bytes already sent.
    WSABUF wsabuf = packet->wsabuf;
    // do/while: a zero-length packet is still handed to WSASend once.
    do {
      DWORD sendbytes = 0;
      const int rc =
          ::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr);
      if (rc == SOCKET_ERROR) {
        spdlog::error("WSASend failed: {}", ::WSAGetLastError());
        break;
      }
      if (sendbytes == 0) {
        break;  // No progress; avoid spinning.
      }
      wsabuf.buf += sendbytes;
      wsabuf.len -= sendbytes;
    } while (wsabuf.len > 0);
  }
}
} // namespace Network