271 lines
8.3 KiB
C++
271 lines
8.3 KiB
C++
#include "socket/iocp.h"
|
|
|
|
#include "utils/thread_pool.h"
|
|
|
|
namespace Network {
|
|
|
|
/// Builds an IOCP wrapper with no thread pool attached yet (init() supplies
/// one) and TCP as the session protocol, and prepares the random engine and
/// distribution used to jitter watcher timeouts.
IOCP::IOCP() : IOCPThread_(nullptr), proto_(SessionProtocol::TCP) {
  // Reseed the engine from the random_device member.
  gen_.seed(rd_());

  // Jitter offsets are drawn uniformly from the closed range [-10, 10].
  jitterDist_ = decltype(jitterDist_)(-10, 10);
}
|
|
|
|
/// Destructor: defers all teardown to destruct() so it lives in one place.
IOCP::~IOCP() {
  this->destruct();
}
|
|
|
|
void IOCP::init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) {
|
|
IOCPThread_ = __IOCPThread;
|
|
proto_ = proto;
|
|
|
|
#ifdef _WIN32
|
|
completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
|
|
if (completionPort_ == NULL) {
|
|
spdlog::critical("CreateIoCompletionPort()");
|
|
std::exit(EXIT_FAILURE);
|
|
}
|
|
int tCount = __IOCPThread->threadCount;
|
|
|
|
spdlog::info("Resizing threadpool size to: {}", tCount * 2);
|
|
|
|
__IOCPThread->respawnWorker(tCount * 2);
|
|
|
|
for (int i = 0; i < tCount; i++)
|
|
IOCPThread_->enqueueJob(
|
|
[this](utils::ThreadPool* th, std::uint8_t __) { iocpWatcher_(th); },
|
|
0);
|
|
#endif
|
|
}
|
|
|
|
/// Platform-specific teardown hook, called from the destructor.
///
/// It performs no work on any platform at present.
/// NOTE(review): completionPort_ opened in init() is not closed here. Confirm
/// whether it should be released (CloseHandle) once watcher jobs are stopped.
void IOCP::destruct() {}
|
|
|
|
/// Associates @p sock with this instance's completion port, using the socket
/// value as the completion key. Completions for overlapped operations on the
/// socket are then delivered to iocpWatcher_. This is a no-op on non-Windows builds.
///
/// On failure the error is logged and completionPort_ is left untouched. The
/// previous code overwrote the port handle with NULL here, which broke every
/// later operation on this instance.
///
/// @param sock Socket to register; must stay open while registered.
void IOCP::registerSocket(std::shared_ptr<Socket> sock) {
#ifdef _WIN32
  // When associating an existing port, CreateIoCompletionPort returns that
  // port's handle on success and NULL on failure.
  HANDLE associated = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(sock->sock), completionPort_, sock->sock, 0);
  if (associated == NULL) {
    spdlog::error("CreateIoCompletionPort() failed to register socket: {}",
                  ::GetLastError());
  }
#endif
}
|
|
|
|
std::future<std::vector<char>> IOCP::recvFull(std::shared_ptr<Socket> sock,
|
|
std::uint32_t bufsize) {
|
|
auto promise = std::make_shared<std::promise<std::vector<char>>>();
|
|
auto future = promise->get_future();
|
|
|
|
auto buffer = std::make_shared<std::vector<char>>();
|
|
buffer->reserve(bufsize);
|
|
|
|
std::function<void(std::uint32_t)> recvChunk;
|
|
recvChunk = [=](std::uint32_t remaining) mutable {
|
|
this->recv(sock, remaining,
|
|
[=](utils::ThreadPool* th, IOCPPASSINDATA* data) {
|
|
buffer->insert(buffer->end(), data->wsabuf.buf,
|
|
data->wsabuf.buf + data->transferredbytes);
|
|
|
|
std::uint32_t still_left =
|
|
bufsize - static_cast<std::uint32_t>(buffer->size());
|
|
if (still_left > 0) {
|
|
recvChunk(still_left);
|
|
} else {
|
|
promise->set_value(std::move(*buffer));
|
|
}
|
|
|
|
return std::list<char>();
|
|
});
|
|
};
|
|
|
|
recvChunk(bufsize);
|
|
|
|
return future;
|
|
}
|
|
|
|
/// Default recv() callback that keeps reading until the whole requested buffer
/// (data->wsabuf.len bytes) has arrived.
///
/// @param th   Pool running this callback (unused).
/// @param data Completed receive: its wsabuf and transferredbytes.
/// @return The bytes from this completion, followed by the bytes from any
///         follow-up reads, in arrival order. It may be shorter than
///         wsabuf.len if the connection is reported closed.
///
/// NOTE(review): this blocks the calling pool worker in future.get() while the
/// follow-up recv completes, and that completion must be dispatched by a
/// different worker. Confirm the pool always has a free worker for it.
std::list<char> DEFAULT_RECVALL_CALLBACK(utils::ThreadPool* th,
                                         IOCPPASSINDATA* data) {
  std::list<char> return_value(data->wsabuf.buf,
                               data->wsabuf.buf + data->transferredbytes);

  // QUIT means iocpWatcher_ saw the connection drop and will destruct the
  // socket right after this callback. A zero-byte completion makes no
  // progress. In both cases, posting another recv would target a dead
  // connection, so return what we have.
  if (data->event == IOCPEVENT::QUIT || data->transferredbytes == 0) {
    return return_value;
  }

  if (data->transferredbytes < data->wsabuf.len) {
    auto future = data->IOCPInstance->recv(
        data->socket, data->wsabuf.len - data->transferredbytes,
        DEFAULT_RECVALL_CALLBACK);
    auto result = future.get();
    // splice relinks the nodes instead of copying every byte.
    return_value.splice(return_value.end(), result);
  }

  return return_value;
}
|
|
|
|
/// Posts one overlapped WSARecv of up to @p bufsize bytes on @p sock. When it
/// completes, iocpWatcher_ runs @p callback on the thread pool with the
/// completed IOCPPASSINDATA.
///
/// @param sock     Socket to read from; must be registered with this port.
/// @param bufsize  Size of the receive buffer.
/// @param callback Completion handler; may be empty.
/// @return Future for @p callback's result, or a default-constructed
///         (invalid) future when @p callback is empty.
/// @throws std::runtime_error if WSARecv fails with anything other than
///         WSA_IO_PENDING. The request object is freed before throwing.
std::future<std::list<char>> IOCP::recv(
    std::shared_ptr<Socket> sock, std::uint32_t bufsize,
    std::function<std::list<char>(utils::ThreadPool*, IOCPPASSINDATA*)>
        callback) {
  std::lock_guard lock(*GetRecvQueueMutex(sock->sock));
  // NOTE(review): the returned queue is not used here. The call is kept
  // because it creates this socket's recv_queue_ entry as a side effect.
  GetRecvQueue(sock->sock);

  std::future<std::list<char>> future;
  std::unique_ptr<Network::IOCPPASSINDATA> data;
  if (callback) {
    std::packaged_task<std::list<char>(utils::ThreadPool*, IOCPPASSINDATA*)>
        task(std::move(callback));
    future = task.get_future();
    data = std::make_unique<Network::IOCPPASSINDATA>(sock, bufsize, this,
                                                     std::move(task));
  } else {
    data = std::make_unique<Network::IOCPPASSINDATA>(sock, bufsize, this);
  }

  DWORD recvbytes = 0;
  DWORD flags = 0;
  const int result = ::WSARecv(sock->sock, &data->wsabuf, 1, &recvbytes,
                               &flags, &data->overlapped, NULL);
  if (result == SOCKET_ERROR) {
    const int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // A synchronous failure queues no completion packet, so iocpWatcher_
      // will never free `data`. The unique_ptr releases it on this throw.
      throw std::runtime_error(std::format("WSARecv failed: {}", err));
    }
  }

  // The operation is in flight. iocpWatcher_ now owns `data` and deletes it
  // after dispatching the completion, so the unique_ptr must let go.
  static_cast<void>(data.release());

  return future;
}
|
|
|
|
int IOCP::send(std::shared_ptr<Socket> sock, std::vector<char>& data) {
|
|
auto lk = GetSendQueueMutex(sock->sock);
|
|
auto queue = GetSendQueue(sock->sock);
|
|
std::lock_guard lock(*lk);
|
|
|
|
Network::IOCPPASSINDATA* packet = new Network::IOCPPASSINDATA(sock, data.size(), this);
|
|
packet->event = IOCPEVENT::WRITE;
|
|
::memcpy(packet->wsabuf.buf, data.data(), data.size());
|
|
packet->wsabuf.len = data.size();
|
|
queue->push_back(packet);
|
|
|
|
IOCPThread_->enqueueJob(
|
|
[this, sock = sock->sock](utils::ThreadPool* th, std::uint8_t __) {
|
|
packet_sender_(sock);
|
|
},
|
|
0);
|
|
return 0;
|
|
}
|
|
|
|
/// Sums, over every entry in @p sock's receive queue, the entry's buffer size
/// minus its `.second` counter.
///
/// NOTE(review): `.second` looks like the number of bytes already consumed
/// from `.first`, which would make the result the count of unread bytes.
/// Confirm against the code that fills the queue.
///
/// @param sock Socket whose receive queue is inspected.
/// @return The sum, converted to int.
int IOCP::GetRecvedBytes(SOCKET sock) {
  auto queue = GetRecvQueue(sock);
  // NOTE(review): recv() guards this queue with GetRecvQueueMutex(sock), not
  // socket_mod_mutex_. Confirm that writers of the queue hold
  // socket_mod_mutex_, or this read can race with them.
  std::lock_guard lock(socket_mod_mutex_);

  // Bind by const reference: the old by-value loop copied every buffer.
  std::size_t bytes = 0;
  for (const auto& [chunk, consumed] : *queue) {
    bytes += chunk.size() - consumed;
  }

  return static_cast<int>(bytes);
}
|
|
|
|
/// One polling step of a completion-port watcher job. It:
///   - waits roughly one second (1000 ms plus jitter) for a completion packet;
///   - dispatches that packet's callback to the thread pool;
///   - re-enqueues itself so the pool keeps polling.
///
/// A failed I/O, or a completion with zero bytes, is treated as a disconnect.
/// The packet is marked IOCPEVENT::QUIT, its callback runs, the socket is
/// destructed, and the request object is deleted. If the port itself fails
/// (an error other than a timeout, with no packet dequeued), the failure is
/// logged and this watcher stops instead of spinning on the same error.
///
/// @param IOCPThread Pool that runs dispatched callbacks and the next step.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  // Schedules the next polling step of this watcher.
  const auto rearm = [this, IOCPThread]() {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t) { iocpWatcher_(th); }, 0);
  };

  // init() starts several watcher jobs that run concurrently, so the shared
  // gen_ engine must not be advanced from here (std::mt19937 is not
  // thread-safe). Each thread uses its own engine; the jitter range still
  // comes from jitterDist_.
  thread_local std::mt19937 jitterEngine{std::random_device{}()};
  std::uniform_int_distribution<int> jitterDist(jitterDist_.param());
  const int jitter = jitterDist(jitterEngine);

  IOCPPASSINDATA* data = nullptr;
  ULONG_PTR completionKey = 0;
  DWORD cbTransferred = 0;
  const BOOL ok = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransferred, &completionKey,
      reinterpret_cast<LPOVERLAPPED*>(&data),
      static_cast<DWORD>(1000 + jitter));

  if (data == nullptr) {
    // No I/O packet was dequeued, so there is nothing to dispatch.
    if (!ok) {
      const DWORD lastError = ::GetLastError();
      if (lastError != WAIT_TIMEOUT) {
        spdlog::error("GetQueuedCompletionStatus() failed: {}", lastError);
        return;
      }
    }
    rearm();
    return;
  }

  data->transferredbytes = cbTransferred;

  if (!ok || cbTransferred == 0) {
    data->event = IOCPEVENT::QUIT;
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(data->socket->remoteAddr));
    IOCPThread->enqueueJob(
        [data](utils::ThreadPool* th, std::uint8_t) {
          if (data->callback.valid()) {
            data->callback(th, data);
          }
          data->socket->destruct();
          delete data;
        },
        0);
    rearm();
    return;
  }

  IOCPThread->enqueueJob(
      [data](utils::ThreadPool* th, std::uint8_t) {
        if (data->callback.valid()) {
          data->callback(th, data);
        }
        delete data;
      },
      0);
  rearm();
}
|
|
|
|
/// Returns @p sock's send queue, creating an empty one on first use.
/// The lookup and insertion happen under socket_mod_mutex_.
///
/// @param sock Socket whose queue is requested.
/// @return Shared handle to the socket's queue of pending send packets.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);

  auto slot = send_queue_.find(sock);
  if (slot == send_queue_.end()) {
    slot = send_queue_
               .emplace(sock, std::make_shared<std::list<IOCPPASSINDATA*>>())
               .first;
  }
  return slot->second;
}
|
|
|
|
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
|
|
IOCP::GetRecvQueue(SOCKET sock) {
|
|
std::lock_guard lock(socket_mod_mutex_);
|
|
if (recv_queue_.find(sock) == recv_queue_.end()) {
|
|
recv_queue_[sock] = std::make_shared<
|
|
std::list<std::pair<std::vector<char>, std::uint32_t>>>(
|
|
std::list<std::pair<std::vector<char>, std::uint32_t>>());
|
|
}
|
|
return recv_queue_[sock];
|
|
}
|
|
|
|
/// Returns the mutex guarding @p sock's send queue, creating it on first use.
/// The lookup and insertion happen under socket_mod_mutex_.
///
/// @param sock Socket whose mutex is requested.
/// @return Shared handle to the per-socket send mutex.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);

  auto slot = send_queue_mutex_.find(sock);
  if (slot == send_queue_mutex_.end()) {
    slot = send_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return slot->second;
}
|
|
|
|
/// Returns the mutex guarding @p sock's receive queue, creating it on first
/// use. The lookup and insertion happen under socket_mod_mutex_.
///
/// @param sock Socket whose mutex is requested.
/// @return Shared handle to the per-socket receive mutex.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex(SOCKET sock) {
  std::lock_guard guard(socket_mod_mutex_);

  auto slot = recv_queue_mutex_.find(sock);
  if (slot == recv_queue_mutex_.end()) {
    slot = recv_queue_mutex_.emplace(sock, std::make_shared<std::mutex>()).first;
  }
  return slot->second;
}
|
|
|
|
/// Drains @p sock's send queue in FIFO order. Each queued packet is sent with
/// a synchronous WSASend while the per-socket send mutex is held, so sends on
/// one socket never interleave.
///
/// Each packet is freed after it is sent; previously every packet allocated by
/// send() was leaked. A failed WSASend is logged, and draining continues.
///
/// NOTE(review): a partial send is not retried; confirm the socket is used in
/// blocking mode, where WSASend transmits the whole buffer.
///
/// @param sock Socket whose queue is flushed.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue(sock);
  std::unique_lock lock(*GetSendQueueMutex(sock));

  while (!queue->empty()) {
    // The queue holds raw pointers created with `new` in send(). Adopt the
    // front one so it is deleted when this iteration ends.
    std::unique_ptr<IOCPPASSINDATA> packet(queue->front());
    queue->pop_front();

    // With no OVERLAPPED and no completion routine, WSASend runs
    // synchronously and queues no completion packet. Nothing references
    // `packet` after it returns.
    DWORD sendbytes = 0;
    if (::WSASend(sock, &packet->wsabuf, 1, &sendbytes, 0, nullptr, nullptr) ==
        SOCKET_ERROR) {
      spdlog::error("WSASend failed: {}", ::WSAGetLastError());
    }
  }
}
|
|
|
|
} // namespace Network
|