244 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			244 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #pragma once
 | |
| #include "Utils/ThreadPool.hpp"
 | |
| #include "Socket/WSAManager.hpp"
 | |
| #include "Socket/TCPSocket.hpp"
 | |
| #include "Socket/Log.hpp"
 | |
| #include "Packet/Packet.hpp"
 | |
| #include <functional>
 | |
| #include <vector>
 | |
| 
 | |
| #include "precomp.hpp"
 | |
| 
 | |
#ifdef __linux__

// Minimal Linux stand-ins for the Win32 OVERLAPPED / WSABUF types so that
// IOCPPASSINDATA can declare the same members on every platform.
struct _OVERLAPPED {
    char dummy; // placeholder; carries no data on Linux
};
using OVERLAPPED = _OVERLAPPED;

struct __WSABUF {
    std::uint32_t len; // number of bytes described by this buffer
    char *buf;         // start of the buffer
};
using WSABUF = __WSABUF;

#endif
 | |
| 
 | |
| namespace Chattr {
 | |
| 
 | |
// IOCPPASSINDATA holds an IOCP* (IOCPInstance), so IOCP is declared first.
class IOCP;

// Kind of operation carried by an IOCPPASSINDATA (stored in its `event` field).
// The values spelled out below are the same as the implicit ones (0, 1, 2).
enum class IOCPEVENT {
    ERROR_ = 0,
    READ = 1,
    WRITE = 2,
};
 | |
| 
 | |
| struct IOCPPASSINDATA {
 | |
|     OVERLAPPED overlapped;
 | |
|     IOCPEVENT event;
 | |
|     std::shared_ptr<TCPSocket> socket;
 | |
|     char buf[1501];
 | |
|     std::uint32_t recvbytes;
 | |
|     std::uint32_t sendbytes;
 | |
|     std::uint32_t transferredbytes;
 | |
|     WSABUF wsabuf;
 | |
|     IOCP* IOCPInstance;
 | |
| };
 | |
| 
 | |
| class IOCP {
 | |
| public:
 | |
| #ifdef _WIN32
 | |
|     static void iocpWather(ThreadPool* threadPool, HANDLE completionPort_, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         DWORD tid = GetCurrentThreadId();
 | |
|         spdlog::trace("Waiting IO to complete on TID: {}.", tid);
 | |
|         IOCPPASSINDATA* data;
 | |
|         SOCKET sock;
 | |
|         DWORD cbTransfrred;
 | |
|         int retVal = GetQueuedCompletionStatus(completionPort_, &cbTransfrred, (PULONG_PTR)&sock, (LPOVERLAPPED*)&data, INFINITE);
 | |
|         if (retVal == 0 || cbTransfrred == 0) {
 | |
|             spdlog::debug("Client disconnected. [{}]", (std::string)(data->socket->remoteAddr));
 | |
|             delete data;
 | |
|             threadPool->enqueueJob(iocpWather, completionPort_, callback);
 | |
|             return;
 | |
|         }
 | |
|         data->transferredbytes = cbTransfrred;
 | |
|         threadPool->enqueueJob(callback, data);
 | |
|         threadPool->enqueueJob(iocpWather, completionPort_, callback);
 | |
|     };
 | |
| #elif __linux__
 | |
|     static void socketReader(ThreadPool* threadPool, epoll_event event, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         IOCPPASSINDATA* iocpData = (IOCPPASSINDATA*)event.data.ptr;
 | |
| 
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         if (iocpData == nullptr) {
 | |
|             spdlog::error("invalid call on {}", tid);
 | |
|             return;
 | |
|         }
 | |
| 
 | |
|         std::lock_guard<std::mutex> lock(iocpData->socket->readMutex);
 | |
| 
 | |
|         spdlog::trace("reading on tid: {} [{}]", tid, (std::string)iocpData->socket->remoteAddr);
 | |
| 
 | |
|         int redSize = 0;
 | |
|         int headerSize = 8;
 | |
|         int totalRedSize = 0;
 | |
| 
 | |
|         while (totalRedSize < headerSize) {
 | |
|             redSize = iocpData->socket->recv(iocpData->buf + totalRedSize, headerSize - totalRedSize, 0);
 | |
|             
 | |
|             if (redSize == SOCKET_ERROR) {
 | |
|                 spdlog::error("recv() [{}]", strerror(errno));
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, iocpData->socket->sock, NULL);
 | |
|                 delete iocpData;
 | |
|                 return;
 | |
|             }
 | |
|             else if (redSize == 0) {
 | |
|                 spdlog::debug("Client disconnected. [{}]", (std::string)iocpData->socket->remoteAddr);
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, iocpData->socket->sock, NULL);
 | |
|                 delete iocpData;
 | |
|                 return;
 | |
|             }
 | |
|             totalRedSize += redSize;
 | |
|         }
 | |
| 
 | |
|         Packet packet;
 | |
|         ::memcpy(packet.serialized, iocpData->buf, headerSize);
 | |
| 
 | |
|         redSize = 0;
 | |
|         int dataLength = ntohs(packet.__data.packetLength);
 | |
| 
 | |
|         while (totalRedSize < dataLength + headerSize) {
 | |
|             redSize = iocpData->socket->recv(iocpData->buf + totalRedSize, dataLength + headerSize - totalRedSize, 0);
 | |
|             
 | |
|             if (redSize == SOCKET_ERROR) {
 | |
|                 spdlog::error("recv() [{}]", strerror(errno));
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, iocpData->socket->sock, NULL);
 | |
|                 delete iocpData;
 | |
|                 return;
 | |
|             }
 | |
|             else if (redSize == 0) {
 | |
|                 spdlog::debug("Client disconnected. [{}]", (std::string)iocpData->socket->remoteAddr);
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, iocpData->socket->sock, NULL);
 | |
|                 delete iocpData;
 | |
|                 return;
 | |
|             }
 | |
|             totalRedSize += redSize;
 | |
|         }
 | |
|         iocpData->transferredbytes = totalRedSize;
 | |
|         threadPool->enqueueJob(callback, iocpData);
 | |
|     };
 | |
|     static void socketWriter(ThreadPool* threadPool, epoll_event event, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         IOCPPASSINDATA* iocpData = (IOCPPASSINDATA*)event.data.ptr;
 | |
|         
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         if (iocpData == nullptr) {
 | |
|             spdlog::error("invalid call on {}", tid);
 | |
|             return;
 | |
|         }
 | |
| 
 | |
|         std::lock_guard<std::mutex> lock(iocpData->socket->writeMutex);
 | |
| 
 | |
|         spdlog::trace("Writing on tid: {} [{}]", tid, (std::string)iocpData->socket->remoteAddr);
 | |
| 
 | |
|         int packetSize = iocpData->wsabuf.len;
 | |
|         int totalSentSize = 0;
 | |
|         int sentSize = 0;
 | |
| 
 | |
|         spdlog::trace("Sending to: [{}]", (std::string)iocpData->socket->remoteAddr);
 | |
| 
 | |
|         while (totalSentSize < packetSize) {
 | |
|             sentSize = iocpData->socket->send(iocpData->buf + totalSentSize, packetSize - totalSentSize, 0);
 | |
| 
 | |
|             if (sentSize == SOCKET_ERROR) {
 | |
|                 spdlog::error("send() [{}]", strerror(errno));
 | |
|                 ::epoll_ctl(epollfd, EPOLL_CTL_DEL, iocpData->socket->sock, NULL);
 | |
|                 delete iocpData;
 | |
|                 return;
 | |
|             }
 | |
|             totalSentSize += sentSize;
 | |
|         }
 | |
|         iocpData->transferredbytes = totalSentSize;
 | |
|         threadPool->enqueueJob(callback, iocpData);
 | |
|     };
 | |
|     static void iocpWatcher(ThreadPool* threadPool, int epollfd, std::function<void(ThreadPool*, IOCPPASSINDATA*)> callback) {
 | |
|         struct epoll_event events[FD_SETSIZE];
 | |
|         pthread_t tid = pthread_self();
 | |
| 
 | |
|         spdlog::trace("epoll waiting on {}", tid);
 | |
|         int nready = ::epoll_wait(epollfd, events, FD_SETSIZE, -1);
 | |
| 
 | |
|         for (int i=0; i<nready; i++) {
 | |
|             struct epoll_event current_event = events[i];
 | |
| 
 | |
|             if (current_event.events & EPOLLIN) {
 | |
|                 std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(callback);
 | |
|                 threadPool->enqueueJob(socketReader, current_event, epollfd, task);
 | |
|             }
 | |
|             else if (current_event.events & EPOLLOUT) {
 | |
|                 std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(callback);
 | |
|                 threadPool->enqueueJob(socketWriter, current_event, epollfd, task);
 | |
|             }
 | |
|             if (--nready <= 0)
 | |
|                 break;
 | |
|         }
 | |
|         threadPool->enqueueJob(iocpWatcher, epollfd, callback);
 | |
|     };
 | |
| #endif
 | |
| 
 | |
|     IOCP();
 | |
|     ~IOCP();
 | |
| 
 | |
|     template<typename _Callable>
 | |
|     void init(ThreadPool* __IOCPThread, _Callable&& callback) {
 | |
|         IOCPThread_ = __IOCPThread;
 | |
| #ifdef _WIN32
 | |
|         completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
 | |
|         if (completionPort_ == NULL)
 | |
|             log::critical("CreateIoCompletionPort()");
 | |
| #elif __linux__
 | |
|         epollfd_ = ::epoll_create(1);
 | |
| #endif
 | |
|         auto boundFunc = [callback = std::move(callback)](ThreadPool* __IOCPThread, IOCPPASSINDATA* data) mutable {
 | |
|             callback(__IOCPThread, data);
 | |
|         };
 | |
| 
 | |
| #ifdef _WIN32
 | |
|         int tCount = __IOCPThread->threadCount;
 | |
| 
 | |
|         spdlog::info("Resizing threadpool size to: {}", tCount * 2);
 | |
| 
 | |
|         __IOCPThread->respawnWorker(tCount * 2);
 | |
| 
 | |
|         for (int i = 0; i < tCount; i++) {
 | |
|             std::function<void(ThreadPool*, IOCPPASSINDATA*)> task(boundFunc);
 | |
|             __IOCPThread->enqueueJob(iocpWather, completionPort_, task);
 | |
|         }
 | |
| #elif __linux__
 | |
|         __IOCPThread->respawnWorker(__IOCPThread->threadCount + 1);
 | |
|         spdlog::info("Spawning 1 Epoll Waiter...");
 | |
|         __IOCPThread->enqueueJob(iocpWatcher, epollfd_, boundFunc);
 | |
| #endif
 | |
|     }
 | |
| 
 | |
|     void destruct();
 | |
| 
 | |
|     void registerSocket(IOCPPASSINDATA* data);
 | |
| 
 | |
|     int recv(IOCPPASSINDATA* data, int bufferCount);
 | |
|     int send(IOCPPASSINDATA* data, int bufferCount, int __flags, bool client = false);
 | |
| 
 | |
| private:
 | |
|     struct WSAManager wsaManager;
 | |
|     ThreadPool* IOCPThread_;
 | |
| 
 | |
| #ifdef _WIN32
 | |
|     HANDLE completionPort_ = INVALID_HANDLE_VALUE;
 | |
| #elif __linux__
 | |
|     int epollfd_ = -1;
 | |
|     std::unordered_map<std::shared_ptr<TCPSocket>, std::mutex> writeMutex;
 | |
|     std::unordered_map<std::shared_ptr<TCPSocket>, std::queue<IOCPPASSINDATA*>> writeBuffer;
 | |
| #endif
 | |
| };
 | |
| 
 | |
| } |