262 lines
7.8 KiB
C++
262 lines
7.8 KiB
C++
#include "socket/iocp.h"
|
|
|
|
#include "utils/thread_pool.h"
|
|
|
|
namespace Network {
|
|
|
|
/// Sets up the jitter source used by the completion-port watchers.
///
/// The engine is seeded once from rd_, and the distribution yields an
/// integer offset in [-10, 10] (milliseconds, as added to the wait timeout in
/// iocpWatcher_).
/// NOTE(review): gen_/jitterDist_ are shared by every watcher job; they are
/// not thread-safe on their own.
IOCP::IOCP() {
  const auto seed = rd_();
  gen_ = std::mt19937{seed};
  jitterDist_ = std::uniform_int_distribution<int>{-10, 10};
}
|
|
|
|
/// Destructor: routes all teardown through destruct() so that explicit and
/// implicit cleanup share a single code path.
IOCP::~IOCP() {
  this->destruct();
}
|
|
|
|
/// Binds this IOCP to a worker pool and a session protocol.
///
/// On Windows it also creates the completion port and starts the watchers.
/// It doubles the pool's thread count and enqueues one iocpWatcher_ job per
/// original thread. Each watcher re-enqueues itself after every wait, so
/// these jobs keep running on the pool.
///
/// @param __IOCPThread pool that runs watcher and send jobs. The pointer is
///                     stored, not deleted here; ownership is presumably
///                     the caller's (confirm).
/// @param proto        session protocol. TLS/QUIC routes traffic through
///                     OpenSSL in the send and receive paths.
///
/// Terminates the process with EXIT_FAILURE if the completion port cannot
/// be created.
void IOCP::init(utils::ThreadPool* __IOCPThread, SessionProtocol proto) {
  proto_ = proto;
  IOCPThread_ = __IOCPThread;

#ifdef _WIN32
  completionPort_ =
      ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
  if (!completionPort_) {
    // No socket can be serviced without a completion port.
    spdlog::critical("CreateIoCompletionPort()");
    std::exit(EXIT_FAILURE);
  }

  // Sample the size before resizing: half of the new pool runs watchers and
  // the other half stays free for send jobs.
  const int watcherCount = IOCPThread_->threadCount;
  const int poolSize = watcherCount * 2;

  spdlog::info("Resizing threadpool size to: {}", poolSize);
  IOCPThread_->respawnWorker(poolSize);

  for (int spawned = 0; spawned < watcherCount; ++spawned) {
    IOCPThread_->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t /*unused*/) {
          iocpWatcher_(th);
        },
        0);
  }
#endif
}
|
|
|
|
/// Platform-specific teardown hook. It is also called from ~IOCP().
///
/// It currently releases nothing on any platform; the Linux section is an
/// empty placeholder.
/// NOTE(review): on Windows, completionPort_ is not closed here.
void IOCP::destruct() {
#ifdef __linux__
  // Reserved for Linux-specific cleanup.
#endif
}
|
|
|
|
/// Associates a connected socket with the completion port and posts its
/// first overlapped receive.
///
/// The socket handle is used as the completion key. A fresh READ
/// IOCPPASSINDATA sized by data->bufsize is allocated for the receive. It is
/// handed to the kernel and later freed by iocpWatcher_ when the completion
/// arrives. `data` itself is only read here and is not deleted.
///
/// On failure the error is logged and the socket is left unregistered.
/// Failure means either the association failed or WSARecv failed with
/// anything other than WSA_IO_PENDING.
void IOCP::registerSocket(IOCPPASSINDATA* data) {
#ifdef _WIN32
  HANDLE associated = ::CreateIoCompletionPort(
      reinterpret_cast<HANDLE>(data->socket->sock), completionPort_,
      static_cast<ULONG_PTR>(data->socket->sock), 0);
  if (associated == nullptr) {
    // Do NOT touch completionPort_ here. Overwriting it with the NULL result
    // used to disable the port for every other registered socket.
    spdlog::error("CreateIoCompletionPort(socket) failed: {}",
                  ::GetLastError());
    return;
  }

  IOCPPASSINDATA* recv_data = new IOCPPASSINDATA(data->bufsize);
  recv_data->event = IOCPEVENT::READ;
  recv_data->socket = data->socket;
  // NOTE(review): recv_data->ssl is not copied from data, although the
  // TLS/QUIC receive path reads it. Confirm it is wired up elsewhere.

  DWORD recvbytes = 0;
  DWORD flags = 0;
  int result = ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1,
                         &recvbytes, &flags, &recv_data->overlapped, nullptr);
  if (result == SOCKET_ERROR) {
    int err = ::WSAGetLastError();
    if (err != WSA_IO_PENDING) {
      // No completion will ever be queued for this request, so nobody else
      // will free recv_data. Release it here instead of leaking it.
      spdlog::error("WSARecv failed: {}", err);
      delete recv_data;
      return;
    }
  }
#endif
}
|
|
|
|
/// Copies buffered received bytes for data->socket into data->wsabuf.
///
/// Chunks are consumed in FIFO order until either the queue is empty or
/// data->wsabuf.len bytes have been copied. A partially consumed chunk stays
/// at the front of the queue with its read offset advanced.
///
/// The result never exceeds data->wsabuf.len. This is why a read can appear
/// to "always return 100": that is simply the size of the caller's buffer
/// whenever enough data is queued.
///
/// @return number of bytes copied; 0 if nothing is buffered.
int IOCP::recv(IOCPPASSINDATA* data) {
  SOCKET sock = data->socket->sock;

  // Keep the shared_ptr alive for the whole lock scope, so the mutex cannot
  // be destroyed while it is held.
  auto queueMutex = GetRecvQueueMutex_(sock);
  std::lock_guard lock(*queueMutex);
  auto queue = GetRecvQueue_(sock);

  std::uint32_t left_data = data->wsabuf.len;
  std::uint32_t copied = 0;

  while (!queue->empty() && left_data != 0) {
    // Work on the front chunk in place. The old code copied the whole vector
    // out of the list, and copied it back again on a partial read.
    auto& [chunk, offset] = queue->front();

    std::uint32_t available = static_cast<std::uint32_t>(chunk.size()) - offset;
    std::uint32_t to_copy = (left_data < available) ? left_data : available;

    ::memcpy(data->wsabuf.buf + copied, chunk.data() + offset, to_copy);
    copied += to_copy;
    left_data -= to_copy;
    offset += to_copy;

    if (offset < chunk.size()) {
      // The caller's buffer is full. Keep the remainder for the next call.
      break;
    }
    queue->pop_front();
  }

  return static_cast<int>(copied);
}
|
|
|
|
/// Queues outgoing buffers for @p sock and schedules packet_sender_ on the
/// pool to drain them.
///
/// Each item is tagged IOCPEVENT::WRITE and appended, in order, to the
/// socket's send list while the socket's send mutex is held. The items'
/// ownership after the call is not established here; see packet_sender_.
///
/// @return always 0.
int IOCP::send(SOCKET sock, std::vector<IOCPPASSINDATA*>* data) {
  auto sendMutex = GetSendQueueMutex_(sock);
  auto pending = GetSendQueue_(sock);
  std::lock_guard guard(*sendMutex);

  for (IOCPPASSINDATA* packet : *data) {
    packet->event = IOCPEVENT::WRITE;
  }
  pending->insert(pending->end(), data->begin(), data->end());

  IOCPThread_->enqueueJob(
      [this, sock](utils::ThreadPool* /*pool*/, std::uint8_t /*unused*/) {
        packet_sender_(sock);
      },
      0);

  return 0;
}
|
|
|
|
/// Returns how many received chunks are buffered for @p sock.
///
/// A chunk is one completed read in plain mode, or one successful SSL_read
/// in TLS/QUIC mode. It is not an application-level packet.
///
/// The per-socket receive mutex is taken, the same lock recv() and
/// iocpWatcher_ use, so size() cannot race with a concurrent append. The old
/// version locked socket_mod_mutex_ instead. That left the list unprotected,
/// and GetRecvQueue_ re-locks socket_mod_mutex_ internally, which
/// self-deadlocks if it is a plain std::mutex.
int IOCP::GetRecvedPacketCount(SOCKET sock) {
  auto queueMutex = GetRecvQueueMutex_(sock);
  std::lock_guard lock(*queueMutex);
  auto queue = GetRecvQueue_(sock);
  return static_cast<int>(queue->size());
}
|
|
|
|
/// Body of one completion-port watcher job.
///
/// It waits for a single completion packet on completionPort_ for about
/// 1000 ms plus jitter, handles the packet, and re-enqueues itself on
/// @p IOCPThread. init() starts several of these concurrently.
///
/// - READ completion: the received bytes are appended to the socket's
///   receive queue under the socket's receive mutex. In TLS/QUIC mode they
///   are decrypted with SSL_read first. The consumed IOCPPASSINDATA is freed
///   and a fresh overlapped WSARecv is posted.
/// - Any other completion is simply freed.
/// - A failed or zero-byte completion is treated as a disconnect.
/// - If the port itself fails (an error other than a timeout, with no packet
///   dequeued), the watcher logs the error and stops.
void IOCP::iocpWatcher_(utils::ThreadPool* IOCPThread) {
  // Every path that keeps this watcher alive must call rearm() exactly once.
  // Otherwise the pool silently loses one completion-port watcher.
  auto rearm = [this, IOCPThread]() {
    IOCPThread->enqueueJob(
        [this](utils::ThreadPool* th, std::uint8_t /*unused*/) {
          iocpWatcher_(th);
        },
        0);
  };

  // Several watchers run at the same time. The random engine and the
  // distribution are not thread-safe, so serialize access to them.
  int jitter = 0;
  {
    static std::mutex jitterMutex;
    std::lock_guard jitterLock(jitterMutex);
    jitter = jitterDist_(gen_);
  }

  IOCPPASSINDATA* data = nullptr;
  ULONG_PTR completionKey = 0;
  DWORD cbTransfrred = 0;
  // NOTE(review): the LPOVERLAPPED* cast assumes `overlapped` is the first
  // member of IOCPPASSINDATA. Confirm against the struct definition.
  BOOL retVal = ::GetQueuedCompletionStatus(
      completionPort_, &cbTransfrred, &completionKey,
      reinterpret_cast<LPOVERLAPPED*>(&data), 1000 + jitter);

  if (data == nullptr) {
    // No I/O packet was dequeued. This happens on a timeout or a port
    // failure, or for a posted packet with a NULL OVERLAPPED. The old code
    // dereferenced NULL here on any error other than a timeout.
    DWORD lasterror = ::GetLastError();
    if (retVal == 0 && lasterror != WAIT_TIMEOUT) {
      // The port itself is unusable (for example, it was closed). Stop this
      // watcher instead of spinning on a dead handle.
      spdlog::error("GetQueuedCompletionStatus failed: {}", lasterror);
      return;
    }
    rearm();
    return;
  }

  if (retVal == 0 || cbTransfrred == 0) {
    // The I/O failed, or the peer closed the connection. Free the operation
    // and keep watching. Previously execution fell through at this point and
    // used, then deleted, `data` a second time.
    data->event = IOCPEVENT::QUIT;
    spdlog::debug("Disconnected. [{}]",
                  static_cast<std::string>(data->socket->remoteAddr));
    delete data;
    rearm();
    return;
  }
  data->transferredbytes = cbTransfrred;

  if (data->event == IOCPEVENT::READ) {
    SOCKET dataSock = data->socket->sock;
    {
      auto queueMutex = GetRecvQueueMutex_(dataSock);
      std::lock_guard lock(*queueMutex);
      auto queue_list = GetRecvQueue_(dataSock);

      if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
        std::vector<char> buf(16384);  // maximum size SSL_read returns at once
        ::BIO_write(::SSL_get_rbio(data->ssl.get()), data->wsabuf.buf,
                    static_cast<int>(cbTransfrred));

        int red_data = 0;
        while ((red_data = ::SSL_read(data->ssl.get(), buf.data(),
                                      static_cast<int>(buf.size()))) > 0) {
          queue_list->emplace_back(
              std::vector<char>(buf.begin(), buf.begin() + red_data), 0);
        }
      } else {
        // Copy directly from the receive buffer. The old detour through a
        // fixed 16 KiB scratch buffer overflowed it when bufsize > 16384.
        queue_list->emplace_back(
            std::vector<char>(data->wsabuf.buf,
                              data->wsabuf.buf + data->transferredbytes),
            0);
      }
    }

    IOCPPASSINDATA* recv_data = new IOCPPASSINDATA(data->bufsize);
    recv_data->event = IOCPEVENT::READ;
    recv_data->socket = data->socket;
    // NOTE(review): recv_data->ssl is not carried over from data, yet the
    // TLS branch above reads data->ssl. Confirm it is set elsewhere.
    delete data;

    DWORD recvbytes = 0;
    DWORD flags = 0;
    int result = ::WSARecv(recv_data->socket->sock, &recv_data->wsabuf, 1,
                           &recvbytes, &flags, &recv_data->overlapped, nullptr);
    if (result == SOCKET_ERROR) {
      int err = ::WSAGetLastError();
      if (err != WSA_IO_PENDING) {
        // No completion will be queued for this request, so free it here.
        spdlog::error("WSARecv failed: {}", err);
        delete recv_data;
      }
    }
  } else {
    // WRITE (or any other) completion: there is nothing to process.
    delete data;
  }

  rearm();
}
|
|
|
|
/// Returns the pending-send list for @p sock, creating an empty one on
/// first use.
///
/// Only the map lookup/insert is guarded by socket_mod_mutex_. The callers
/// in this file lock GetSendQueueMutex_(sock) before touching the list.
std::shared_ptr<std::list<IOCPPASSINDATA*>> IOCP::GetSendQueue_(SOCKET sock) {
  std::scoped_lock guard(socket_mod_mutex_);
  auto entry = send_queue_.find(sock);
  if (entry == send_queue_.end()) {
    entry = send_queue_
                .emplace(sock, std::make_shared<std::list<IOCPPASSINDATA*>>())
                .first;
  }
  return entry->second;
}
|
|
|
|
std::shared_ptr<std::list<std::pair<std::vector<char>, std::uint32_t>>>
|
|
IOCP::GetRecvQueue_(SOCKET sock) {
|
|
std::lock_guard lock(socket_mod_mutex_);
|
|
if (recv_queue_.find(sock) == recv_queue_.end()) {
|
|
recv_queue_[sock] = std::make_shared<
|
|
std::list<std::pair<std::vector<char>, std::uint32_t>>>(
|
|
std::list<std::pair<std::vector<char>, std::uint32_t>>());
|
|
}
|
|
return recv_queue_[sock];
|
|
}
|
|
|
|
/// Returns the mutex that serializes access to @p sock's send list,
/// creating it on first use. Map access is guarded by socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetSendQueueMutex_(SOCKET sock) {
  std::scoped_lock guard(socket_mod_mutex_);
  auto entry = send_queue_mutex_.find(sock);
  if (entry == send_queue_mutex_.end()) {
    entry = send_queue_mutex_.emplace(sock, std::make_shared<std::mutex>())
                .first;
  }
  return entry->second;
}
|
|
|
|
/// Returns the mutex that serializes access to @p sock's receive list,
/// creating it on first use. Map access is guarded by socket_mod_mutex_.
std::shared_ptr<std::mutex> IOCP::GetRecvQueueMutex_(SOCKET sock) {
  std::scoped_lock guard(socket_mod_mutex_);
  auto entry = recv_queue_mutex_.find(sock);
  if (entry == recv_queue_mutex_.end()) {
    entry = recv_queue_mutex_.emplace(sock, std::make_shared<std::mutex>())
                .first;
  }
  return entry->second;
}
|
|
|
|
/// Drains the pending-send list of @p sock onto the wire.
///
/// Runs as a pool job scheduled by send(). It holds the socket's send mutex
/// for the whole drain, so concurrent jobs for the same socket are
/// serialized.
///
/// - TLS/QUIC: each payload is encrypted with SSL_write. The resulting
///   records are pulled from the write BIO in chunks of up to 16 KiB and
///   sent.
/// - Plain mode: the payload buffer is sent as-is.
///
/// Error handling:
/// - SSL_WANT_READ / SSL_WANT_WRITE: the payload is put back at the front
///   and draining stops.
/// - Any other SSL_write error: the socket's send list is dropped from the
///   map.
/// - Failed WSASend calls are logged; previously their result was ignored.
void IOCP::packet_sender_(SOCKET sock) {
  auto queue = GetSendQueue_(sock);
  auto queueMutex = GetSendQueueMutex_(sock);
  std::unique_lock lock(*queueMutex);

  // Synchronous send of one buffer (no OVERLAPPED).
  auto sendRaw = [sock](char* bytes, ULONG len) {
    WSABUF wsabuf;
    wsabuf.buf = bytes;
    wsabuf.len = len;
    DWORD sendbytes = 0;
    if (::WSASend(sock, &wsabuf, 1, &sendbytes, 0, nullptr, nullptr) ==
        SOCKET_ERROR) {
      spdlog::error("WSASend failed: {}", ::WSAGetLastError());
    }
  };

  std::vector<char> buf(16384);

  while (!queue->empty()) {
    IOCPPASSINDATA* front = queue->front();
    queue->pop_front();
    front->event = IOCPEVENT::WRITE;
    // NOTE(review): `front` is never freed after it is sent, and no
    // completion packet reaches iocpWatcher_ for a send made without an
    // OVERLAPPED. If send() transfers ownership, this leaks. Confirm
    // ownership with the callers.

    if (proto_ == SessionProtocol::TLS || proto_ == SessionProtocol::QUIC) {
      int ret = ::SSL_write(front->ssl.get(), front->wsabuf.buf,
                            static_cast<int>(front->wsabuf.len));
      if (ret <= 0) {
        int err = ::SSL_get_error(front->ssl.get(), ret);
        if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
          // Retry this payload first on the next drain.
          queue->push_front(front);
          break;
        }
        spdlog::error("SSL_write failed: {}", err);
        std::unique_lock lk(socket_mod_mutex_);
        send_queue_.erase(sock);
        break;
      }

      int data_len = 0;
      while ((data_len = ::BIO_read(::SSL_get_wbio(front->ssl.get()),
                                    buf.data(),
                                    static_cast<int>(buf.size()))) > 0) {
        sendRaw(buf.data(), static_cast<ULONG>(data_len));
      }
    } else {
      sendRaw(front->wsabuf.buf, front->wsabuf.len);
    }
  }
}
|
|
|
|
} // namespace Network
|