292 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			292 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#pragma once

#include "Utils/ThreadPool.hpp"
#include "Socket/WSAManager.hpp"
#include "Socket/TCPSocket.hpp"
#include "Socket/Log.hpp"
#include "Packet/Packet.hpp"

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

#include "precomp.hpp"
 | |
| 
 | |
#ifdef __linux__

// Linux stand-ins for the Win32 OVERLAPPED and WSABUF types, so that
// IOCPPASSINDATA can be declared identically on both platforms.

// Placeholder with a single unused byte; nothing on Linux reads it.
struct _OVERLAPPED {
    char dummy;
};
using OVERLAPPED = _OVERLAPPED;

// Buffer descriptor: `buf` points at the data region, `len` is its length in bytes.
struct __WSABUF {
    std::uint32_t len;
    char *buf;
};
using WSABUF = __WSABUF;

#endif
 | |
| 
 | |
| namespace Chattr {
 | |
| 
 | |
class IOCP;

/// Kind of operation carried by an IOCPPASSINDATA.
/// Enumerator values are spelled out; they match the implicit 0, 1, 2.
enum class IOCPEVENT : int {
    ERROR_ = 0,  // trailing underscore presumably avoids the Win32 `ERROR` macro
    READ = 1,
    WRITE = 2,
};
 | |
| 
 | |
| struct IOCPPASSINDATA {
 | |
|     OVERLAPPED overlapped;
 | |
|     IOCPEVENT event;
 | |
|     std::shared_ptr<TCPSocket> socket;
 | |
|     char buf[1501];
 | |
|     std::uint32_t recvbytes;
 | |
|     std::uint32_t sendbytes;
 | |
|     std::uint32_t transferredbytes;
 | |
|     WSABUF wsabuf;
 | |
|     IOCP* IOCPInstance;
 | |
| #ifdef __linux__
 | |
|     std::shared_ptr<std::queue<IOCPPASSINDATA*>> sendQueue;
 | |
| #endif
 | |
| };
 | |
| 
 | |
| class IOCP {
 | |
| public:
 | |
| #ifdef _WIN32
 | |
|     static void iocpWather(ThreadPool* threadPool, HANDLE completionPort_, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         DWORD tid = GetCurrentThreadId();
 | |
|         spdlog::trace("Waiting IO to complete on TID: {}.", tid);
 | |
|         IOCPPASSINDATA* data;
 | |
|         SOCKET sock;
 | |
|         DWORD cbTransfrred;
 | |
|         int retVal = GetQueuedCompletionStatus(completionPort_, &cbTransfrred, (PULONG_PTR)&sock, (LPOVERLAPPED*)&data, INFINITE);
 | |
|         if (retVal == 0 || cbTransfrred == 0) {
 | |
|             spdlog::debug("Client disconnected. [{}]", (std::string)(data->socket->remoteAddr));
 | |
|             delete data;
 | |
|             threadPool->enqueueJob(iocpWather, completionPort_, callback);
 | |
|             return;
 | |
|         }
 | |
|         data->transferredbytes = cbTransfrred;
 | |
|         threadPool->enqueueJob(callback, data);
 | |
|         threadPool->enqueueJob(iocpWather, completionPort_, callback);
 | |
|     };
 | |
| #elif __linux__
 | |
|     static void socketReader(ThreadPool* threadPool, epoll_event event, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         if (event.data.ptr == nullptr) {
 | |
|             spdlog::error("invalid call on {}", tid);
 | |
|             return;
 | |
|         }
 | |
|         
 | |
|         IOCPPASSINDATA* rootIocpData = (IOCPPASSINDATA*)event.data.ptr;
 | |
| 
 | |
|         std::lock_guard<std::mutex> lock(rootIocpData->socket->readMutex);
 | |
|         while (true) {
 | |
|             char peekBuffer[1];
 | |
|             int rc = rootIocpData->socket->recv(peekBuffer, 1, MSG_PEEK);
 | |
|             if (rc > 0);
 | |
|             else if (rc == 0) {
 | |
|                 spdlog::debug("Client disconnected. [{}]", (std::string)(rootIocpData->socket->remoteAddr));
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, rootIocpData->socket->sock, NULL);
 | |
|                 delete rootIocpData;
 | |
|                 return;
 | |
|             }
 | |
|             else {
 | |
|                 if (errno == EAGAIN || errno == EWOULDBLOCK) {
 | |
|                     spdlog::trace("No data to read on {}", tid);
 | |
|                     return;
 | |
|                 }
 | |
|                 else {
 | |
|                     spdlog::error("recv() [{}]", strerror(errno));
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, rootIocpData->socket->sock, NULL);
 | |
|                     delete rootIocpData;
 | |
|                     return;
 | |
|                 }
 | |
|             }
 | |
|             Chattr::IOCPPASSINDATA* ptr = new Chattr::IOCPPASSINDATA;
 | |
|             ::memcpy(ptr, rootIocpData, sizeof(IOCPPASSINDATA));
 | |
|             ::memset(&ptr->overlapped, 0, sizeof(OVERLAPPED));
 | |
|             ptr->recvbytes = ptr->sendbytes = 0;
 | |
|             ptr->wsabuf.buf = ptr->buf;
 | |
|             ptr->wsabuf.len = 1500;
 | |
| 
 | |
|             int redSize = 0;
 | |
|             int headerSize = 8;
 | |
|             int totalRedSize = 0;
 | |
| 
 | |
|             while (totalRedSize < headerSize) {
 | |
|                 redSize = ptr->socket->recv(ptr->buf + totalRedSize, headerSize - totalRedSize, 0);
 | |
|                 
 | |
|                 if (redSize == SOCKET_ERROR) {
 | |
|                     spdlog::error("recv() [{}]", strerror(errno));
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, ptr->socket->sock, NULL);
 | |
|                     delete ptr;
 | |
|                     return;
 | |
|                 }
 | |
|                 else if (redSize == 0) {
 | |
|                     spdlog::debug("Client disconnected. [{}]", (std::string)ptr->socket->remoteAddr);
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, ptr->socket->sock, NULL);
 | |
|                     delete ptr;
 | |
|                     return;
 | |
|                 }
 | |
|                 totalRedSize += redSize;
 | |
|             }
 | |
| 
 | |
|             Packet packet;
 | |
|             ::memcpy(packet.serialized, ptr->buf, headerSize);
 | |
| 
 | |
|             redSize = 0;
 | |
|             int dataLength = ntohs(packet.__data.packetLength);
 | |
| 
 | |
|             while (totalRedSize < dataLength + headerSize) {
 | |
|                 redSize = ptr->socket->recv(ptr->buf + totalRedSize, dataLength + headerSize - totalRedSize, 0);
 | |
|                 
 | |
|                 if (redSize == SOCKET_ERROR) {
 | |
|                     if (errno == EAGAIN || errno == EWOULDBLOCK) {
 | |
|                         spdlog::trace("No data to read on {}", tid);
 | |
|                         return;
 | |
|                     }
 | |
|                     spdlog::error("recv() [{}]", strerror(errno));
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, ptr->socket->sock, NULL);
 | |
|                     delete ptr;
 | |
|                     return;
 | |
|                 }
 | |
|                 else if (redSize == 0) {
 | |
|                     spdlog::debug("Client disconnected. [{}]", (std::string)ptr->socket->remoteAddr);
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, ptr->socket->sock, NULL);
 | |
|                     delete ptr;
 | |
|                     return;
 | |
|                 }
 | |
|                 totalRedSize += redSize;
 | |
|             }
 | |
|             ptr->transferredbytes = totalRedSize;
 | |
|             threadPool->enqueueJob(callback, ptr);
 | |
|         }
 | |
|     };
 | |
|     static void socketWriter(ThreadPool* threadPool, epoll_event event, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         if (event.data.ptr == nullptr) {
 | |
|             spdlog::error("invalid call on {}", tid);
 | |
|             return;
 | |
|         }
 | |
|         
 | |
|         IOCPPASSINDATA* rootIocpData = (IOCPPASSINDATA*)event.data.ptr;
 | |
| 
 | |
|         std::lock_guard<std::mutex> lock(rootIocpData->socket->writeMutex);
 | |
|         while (!rootIocpData->sendQueue->empty()) {
 | |
|             IOCPPASSINDATA* data = rootIocpData->sendQueue->front();
 | |
|             rootIocpData->sendQueue->pop();
 | |
| 
 | |
|             if (data == nullptr) {
 | |
|                 spdlog::error("invalid call on {}", tid);
 | |
|                 break;
 | |
|             }
 | |
| 
 | |
|             int packetSize = data->wsabuf.len;
 | |
|             int totalSentSize = 0;
 | |
|             int sentSize = 0;
 | |
| 
 | |
|             spdlog::trace("Sending to: [{}]", (std::string)data->socket->remoteAddr);
 | |
| 
 | |
|             while (totalSentSize < packetSize) {
 | |
|                 sentSize = data->socket->send(data->buf + totalSentSize, packetSize - totalSentSize, 0);
 | |
| 
 | |
|                 if (sentSize == SOCKET_ERROR) {
 | |
|                     if (errno == EAGAIN || errno == EWOULDBLOCK) {
 | |
|                         spdlog::warn("buffer full");
 | |
|                         continue;
 | |
|                     }
 | |
|                     spdlog::error("send() [{}]", strerror(errno));
 | |
|                     ::epoll_ctl(epollfd, EPOLL_CTL_DEL, data->socket->sock, NULL);
 | |
|                     delete data;
 | |
|                     return;
 | |
|                 }
 | |
|                 totalSentSize += sentSize;
 | |
|             }
 | |
|             data->transferredbytes = totalSentSize;
 | |
|             threadPool->enqueueJob(callback, data);
 | |
|         }
 | |
|     };
 | |
|     static void iocpWatcher(ThreadPool* threadPool, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         struct epoll_event events[FD_SETSIZE];
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         int nready = ::epoll_wait(epollfd, events, FD_SETSIZE, -1);
 | |
| 
 | |
|         for (int i=0; i<nready; i++) {
 | |
|             struct epoll_event current_event = events[i];
 | |
| 
 | |
|             if (current_event.events & EPOLLIN) {
 | |
|                 std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(callback);
 | |
|                 threadPool->enqueueJob(socketReader, current_event, epollfd, task);
 | |
|             }
 | |
|             else if (current_event.events & EPOLLOUT) {
 | |
|                 std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(callback);
 | |
|                 threadPool->enqueueJob(socketWriter, current_event, epollfd, task);
 | |
|             }
 | |
|             if (--nready <= 0)
 | |
|                 break;
 | |
|         }
 | |
|         threadPool->enqueueJob(iocpWatcher, epollfd, callback);
 | |
|     };
 | |
| #endif
 | |
| 
 | |
|     IOCP();
 | |
|     ~IOCP();
 | |
| 
 | |
|     template<typename _Callable>
 | |
|     void init(ThreadPool* __IOCPThread, _Callable&& callback) {
 | |
|         IOCPThread_ = __IOCPThread;
 | |
| #ifdef _WIN32
 | |
|         completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
 | |
|         if (completionPort_ == NULL)
 | |
|             log::critical("CreateIoCompletionPort()");
 | |
| #elif __linux__
 | |
|         epollfd_ = ::epoll_create(1);
 | |
|         epollDetroyerFd_ = ::eventfd(0, EFD_NONBLOCK);
 | |
|         ::epoll_ctl(epollfd_, EPOLL_CTL_ADD, epollDetroyerFd_, NULL);
 | |
| #endif
 | |
|         auto boundFunc = [callback = std::move(callback)](ThreadPool* __IOCPThread, IOCPPASSINDATA* data) mutable {
 | |
|             callback(__IOCPThread, data);
 | |
|         };
 | |
| 
 | |
| #ifdef _WIN32
 | |
|         int tCount = __IOCPThread->threadCount;
 | |
| 
 | |
|         spdlog::info("Resizing threadpool size to: {}", tCount * 2);
 | |
| 
 | |
|         __IOCPThread->respawnWorker(tCount * 2);
 | |
| 
 | |
|         for (int i = 0; i < tCount; i++) {
 | |
|             std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(boundFunc);
 | |
|             __IOCPThread->enqueueJob(iocpWather, completionPort_, task);
 | |
|         }
 | |
| #elif __linux__
 | |
|         __IOCPThread->respawnWorker(__IOCPThread->threadCount + 1);
 | |
|         spdlog::info("Spawning 1 Epoll Waiter...");
 | |
|         __IOCPThread->enqueueJob(iocpWatcher, epollfd_, boundFunc);
 | |
| #endif
 | |
|     }
 | |
| 
 | |
|     void destruct();
 | |
| 
 | |
|     void registerSocket(IOCPPASSINDATA* data);
 | |
| 
 | |
|     int recv(IOCPPASSINDATA* data, int bufferCount);
 | |
|     int send(IOCPPASSINDATA* data, int bufferCount, int __flags, bool client = false);
 | |
| 
 | |
| private:
 | |
|     struct WSAManager wsaManager;
 | |
|     ThreadPool* IOCPThread_;
 | |
| 
 | |
| #ifdef _WIN32
 | |
|     HANDLE completionPort_ = INVALID_HANDLE_VALUE;
 | |
| #elif __linux__
 | |
|     int epollfd_ = -1;
 | |
|     int epollDetroyerFd_ = -1;
 | |
|     std::unordered_map<std::shared_ptr<TCPSocket>, std::mutex> writeMutex;
 | |
|     std::unordered_map<std::shared_ptr<TCPSocket>, std::queue<IOCPPASSINDATA*>> writeBuffer;
 | |
| #endif
 | |
| };
 | |
| 
 | |
| } |