diff --git a/Server/src/server.cpp b/Server/src/server.cpp index 2837cb6..743cde8 100644 --- a/Server/src/server.cpp +++ b/Server/src/server.cpp @@ -6,6 +6,7 @@ #include "Utils/ThreadPool.hpp" #include "Utils/StringTokenizer.hpp" #include "Utils/Snowflake.hpp" +#include "Socket/IOCP.hpp" #include "Packet/Packet.hpp" #include @@ -13,13 +14,16 @@ #include "precomp.hpp" -void _TCPRecvClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr::Address addr); -void _TCPSendClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr::Address addr, std::queue packets); +void _TCPRecvClient(Chattr::ThreadPool* threadPool, Chattr::IOCPPASSINDATA* data); +void _TCPSendClient(Chattr::ThreadPool* threadPool, Chattr::TCPSocket sock, std::queue packets); int main() { - struct Chattr::WSAManager wsaManager; auto config = Chattr::ConfigManager::load(); Chattr::log::setDefaultLogger(config.logLevel, config.logFileName, config.logfileSize, config.logfileCount); + Chattr::ThreadPool threadPool(0); + Chattr::IOCP iocp; + iocp.init(&threadPool, _TCPRecvClient); + // struct Chattr::WSAManager wsaManager; Chattr::TCPSocket sock; struct Chattr::Address serveraddr; @@ -43,61 +47,84 @@ int main() { #elif __linux__ pid_t pid = getpid(); #endif - spdlog::info("PID : {}", pid); + spdlog::debug("PID : {}", pid); - - Chattr::ThreadPool threadPool(0); - /*int returnedIntager = 0; - int passvalue = 2; - threadPool.enqueueJob([](int& i){ - spdlog::info("JobTest"); - if (i == 2) { - i = 1; - return 2; - } - return 1; - }, returnedIntager, std::ref(passvalue));*/ + DWORD recvbytes = 0, flags = 0; while (true) { spdlog::info("Waiting for connection..."); sock.accept(clientSock, clientAddr); + Chattr::IOCPPASSINDATA* ptr = new Chattr::IOCPPASSINDATA; + ZeroMemory(&ptr->overlapped, sizeof(OVERLAPPED)); + ptr->socket = std::move(clientSock); + ptr->recvbytes = ptr->sendbytes = 0; + ptr->wsabuf.buf = ptr->buf; + ptr->wsabuf.len = 1500; - threadPool.enqueueJob(_TCPRecvClient, std::move(clientSock), clientAddr); + iocp.registerSocket(ptr->socket.sock); + + int returnData = WSARecv(ptr->socket.sock, &ptr->wsabuf, 1, &recvbytes, &flags, &ptr->overlapped, NULL); } } -void _TCPRecvClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr::Address addr) { +void _TCPRecvClient(Chattr::ThreadPool* thread, Chattr::IOCPPASSINDATA* data) { +#ifdef _WIN32 + DWORD tid = GetCurrentThreadId(); +#elif __linux__ + pthread_t tid = pthread_self(); +#endif + spdlog::info("entered recvfunc on TID: {}", tid); Chattr::Packet pack; int packetSize = 1500; int totalRedSize = 0; int redSize = 0; - while (totalRedSize < packetSize) { - redSize = sock.recv(&pack.serialized, 1500 - totalRedSize, 0); - - if (redSize <= 0) { - spdlog::info("Client disconnected. [{}]", (std::string)addr); - return; - } - totalRedSize += redSize; - } + // spdlog::info("Receving from: [{}]", (std::string)sock.remoteAddr); + + //while (totalRedSize < packetSize) { + // redSize = data->socket.recv(&pack.serialized, 1500 - totalRedSize, 0); + // + // if (redSize <= 0) { + // // spdlog::info("Client disconnected. [{}]", (std::string)sock.remoteAddr); + // return; + // } + // totalRedSize += redSize; + //} + + memcpy(pack.serialized, data->wsabuf.buf, data->wsabuf.len); std::string recvString; bool packetError = false; Chattr::DataPostPacket dataPostPacket; + + Chattr::ResponsePacket packet; + std::queue responsePackets; + switch (pack.__data.packetType) { case Chattr::PacketType::PACKET_POST: + if (pack.__data.requestType != Chattr::RequestType::DATA){ + packetError = true; + break; + } switch (pack.__data.dataType) { case Chattr::DataType::TEXT: std::memcpy(&dataPostPacket.serialized, &pack, 1500); dataPostPacket.convToH(); - recvString = std::string((char *)dataPostPacket.__data.data, dataPostPacket.__data.packetLength - sizeof(std::uint16_t)); - spdlog::info("Red size : {}, {} from : [{}]", redSize, recvString, (std::string)addr); + dataPostPacket.__data.packetLength = (dataPostPacket.__data.packetLength > 1487) ? 1487 : dataPostPacket.__data.packetLength; + recvString = std::string((char *)dataPostPacket.__data.data, dataPostPacket.__data.packetLength - (sizeof(std::uint16_t)*4)); + spdlog::info("Red size : {}, {} from : [{}]", totalRedSize, recvString, (std::string)data->socket.remoteAddr); break; case Chattr::DataType::BINARY: break; default: + packet.__data.packetType = Chattr::PacketType::PACKET_RESPONSE; + packet.__data.requestType = Chattr::RequestType::DATA; + packet.__data.dataType = Chattr::DataType::TEXT; + packet.__data.packetLength = sizeof(Chattr::ResponseStatusCode); + packet.__data.responseStatusCode = Chattr::ResponseStatusCode::BAD_REQUEST; + packet.convToN(); + responsePackets.push(packet); packetError = true; break; } @@ -117,6 +144,13 @@ void _TCPRecvClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr:: case Chattr::RequestType::USERS_LIST: break; default: + packet.__data.packetType = Chattr::PacketType::PACKET_RESPONSE; + packet.__data.requestType = Chattr::RequestType::DATA; + packet.__data.dataType = Chattr::DataType::TEXT; + packet.__data.packetLength = sizeof(Chattr::ResponseStatusCode); + packet.__data.responseStatusCode = Chattr::ResponseStatusCode::BAD_REQUEST; + packet.convToN(); + responsePackets.push(packet); packetError = true; break; } @@ -124,42 +158,48 @@ void _TCPRecvClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr:: case Chattr::PacketType::PACKET_CONTINUE: break; default: + packet.__data.packetType = Chattr::PacketType::PACKET_RESPONSE; + packet.__data.requestType = Chattr::RequestType::DATA; + packet.__data.dataType = Chattr::DataType::TEXT; + packet.__data.packetLength = sizeof(Chattr::ResponseStatusCode); + packet.__data.responseStatusCode = Chattr::ResponseStatusCode::BAD_REQUEST; + packet.convToN(); + responsePackets.push(packet); packetError = true; break; } - if (packetError) { - Chattr::ResponsePacket packet; - packet.__data.responseStatusCode = Chattr::ResponseStatusCode::BAD_REQUEST; - std::queue packs; - packet.convToN(); - packs.push(packet); - thread->enqueueJob(_TCPSendClient, std::move(sock), addr, packs); + Sleep(10000); + + /*if (packetError) { + thread->enqueueJob(_TCPRecvClient, data);; } else - thread->enqueueJob(_TCPRecvClient, std::move(sock), addr); + thread->enqueueJob(_TCPRecvClient, data);*/ } -void _TCPSendClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, Chattr::Address addr, std::queue packets) { +void _TCPSendClient(Chattr::ThreadPool* thread, Chattr::TCPSocket sock, std::queue packets) { Chattr::ResponsePacket pack = packets.front(); packets.pop(); int packetSize = 1500; int totalSentSize = 0; int sentSize = 0; + spdlog::info("Sending to: [{}]", (std::string)sock.remoteAddr); + while (totalSentSize < packetSize) { sentSize = sock.send(&pack.serialized, 1500 - totalSentSize, 0); if (sentSize <= 0) { - spdlog::info("Client disconnected. [{}]", (std::string)addr); + spdlog::info("Client disconnected. [{}]", (std::string)sock.remoteAddr); return; } totalSentSize += sentSize; } - if (packets.empty()) - thread->enqueueJob(_TCPRecvClient, std::move(sock), addr); + /*if (packets.empty()) + thread->enqueueJob(_TCPRecvClient, std::move(sock)); else - thread->enqueueJob(_TCPSendClient, std::move(sock), addr, packets); + thread->enqueueJob(_TCPSendClient, std::move(sock), packets);*/ } \ No newline at end of file diff --git a/impl/Socket/IOCP.cpp b/impl/Socket/IOCP.cpp index c731702..3629591 100644 --- a/impl/Socket/IOCP.cpp +++ b/impl/Socket/IOCP.cpp @@ -1,24 +1,15 @@ #include "Socket/IOCP.hpp" -#include "Socket/WSAManager.hpp" -#include "Socket/Log.hpp" +#include "Utils/ThreadPool.hpp" + #include "precomp.hpp" namespace Chattr { -IOCP::IOCP(std::shared_ptr __IOCPThread) { - init(__IOCPThread); -} - -IOCP::~IOCP() { -} - -void IOCP::init(std::shared_ptr __IOCPThread) { - IOCPThread_ = __IOCPThread; +void IOCP::registerSocket(SOCKET sock) { #ifdef _WIN32 - struct Chattr::WSAManager wsaManager; - completinPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); - if (completinPort_ == NULL) - log::critical("CreateIoCompletionPort()"); + HANDLE returnData = ::CreateIoCompletionPort((HANDLE)sock, completionPort_, sock, 0); + if (returnData == 0) + completionPort_ = returnData; #elif __linux__ #endif diff --git a/impl/Socket/Socket.cpp b/impl/Socket/Socket.cpp index fe1348f..3f65bd9 100644 --- a/impl/Socket/Socket.cpp +++ b/impl/Socket/Socket.cpp @@ -15,8 +15,8 @@ Socket::~Socket() { int Socket::init(int domain, int type, int protocol) { this->domain = domain; - sock_ = ::socket(domain, type, protocol); - if (sock_ == INVALID_SOCKET) + sock = ::socket(domain, type, protocol); + if (sock == INVALID_SOCKET) log::critical("socket()"); valid_ = true; @@ -28,9 +28,9 @@ void Socket::destruct() { if (!valid_) return; #ifdef _WIN32 - ::closesocket(sock_); + ::closesocket(sock); #elif __linux__ - ::close(sock_); + ::close(sock); #endif valid_ = false; } @@ -38,7 +38,7 @@ void Socket::destruct() { Socket::operator SOCKET() { if (valid_) { valid_ = false; - return sock_; + return sock; } spdlog::critical("No valid socket created."); return INVALID_SOCKET; @@ -50,41 +50,41 @@ void Socket::set(const SOCKET __sock, int __domain) { destruct(); - sock_ = __sock; + sock = __sock; valid_ = true; }; int Socket::bind(Address __addr) { bindAddr = __addr; - int retVal = ::bind(sock_, &__addr.addr, __addr.length); + int retVal = ::bind(sock, &__addr.addr, __addr.length); if (retVal == INVALID_SOCKET) log::critical("bind()"); return retVal; } int Socket::recvfrom(void *__restrict __buf, size_t __n, int __flags, struct Address& __addr) { - int retVal = ::recvfrom(sock_, (char*)__buf, __n, __flags, &__addr.addr, &__addr.length); + int retVal = ::recvfrom(sock, (char*)__buf, __n, __flags, &__addr.addr, &__addr.length); if (retVal == SOCKET_ERROR) log::error("recvfrom()"); return retVal; } int Socket::sendto(const void *__buf, size_t __n, int __flags, struct Address __addr) { - int retVal = ::sendto(sock_, (char*)__buf, __n, __flags, &__addr.addr, __addr.length); + int retVal = ::sendto(sock, (char*)__buf, __n, __flags, &__addr.addr, __addr.length); if (retVal == SOCKET_ERROR) log::error("sendto()"); return retVal; } -Socket::Socket(Socket &&other_) { +Socket::Socket(Socket &&other_) noexcept { other_.valid_ = false; - sock_ = other_.sock_; + memcpy(this, &other_, sizeof(Socket)); valid_ = true; } -Socket& Socket::operator=(Socket && other_) { +Socket& Socket::operator=(Socket && other_) noexcept { other_.valid_ = false; - sock_ = other_.sock_; + memcpy(this, &other_, sizeof(Socket)); valid_ = true; return *this; diff --git a/impl/Socket/TCPSocket.cpp b/impl/Socket/TCPSocket.cpp index 0462f19..a88f0db 100644 --- a/impl/Socket/TCPSocket.cpp +++ b/impl/Socket/TCPSocket.cpp @@ -9,35 +9,36 @@ int TCPSocket::init(int domain) { } int TCPSocket::listen(int __n) { - int retVal = ::listen(sock_, __n); + int retVal = ::listen(sock, __n); if (retVal == INVALID_SOCKET) log::error("listen()"); return retVal; } void TCPSocket::accept(TCPSocket& newSock, Address& __addr) { - newSock.set(::accept(sock_, &__addr.addr, &__addr.length), domain); - if (newSock == INVALID_SOCKET) + newSock.set(::accept(sock, &__addr.addr, &__addr.length), domain); + memcpy(&newSock.remoteAddr, &__addr, sizeof(Chattr::Address)); + if (newSock.sock == INVALID_SOCKET) log::error("accept()"); } int TCPSocket::connect(Address& serveraddr) { - int retVal = ::connect(sock_, (struct sockaddr *)&serveraddr.addr, serveraddr.length); - remoteAddr = serveraddr; + int retVal = ::connect(sock, (struct sockaddr *)&serveraddr.addr, serveraddr.length); + memcpy(&remoteAddr, &serveraddr, sizeof(Chattr::Address)); if (retVal == INVALID_SOCKET) log::error("connect()"); return retVal; } int TCPSocket::recv(void *__restrict __buf, size_t __n, int __flags) { - int retVal = ::recv(sock_, (char *)__buf, __n, __flags); + int retVal = ::recv(sock, (char *)__buf, __n, __flags); if (retVal == SOCKET_ERROR) log::error("recv()"); return retVal; } int TCPSocket::send(const void *__buf, size_t __n, int __flags) { - int retVal = ::send(sock_, (char*)__buf, __n, __flags); + int retVal = ::send(sock, (char*)__buf, __n, __flags); if (retVal == SOCKET_ERROR) log::error("send()"); return retVal; diff --git a/impl/Utils/ThreadPool.cpp b/impl/Utils/ThreadPool.cpp index 5033dcd..4ce4ef2 100644 --- a/impl/Utils/ThreadPool.cpp +++ b/impl/Utils/ThreadPool.cpp @@ -11,11 +11,7 @@ ThreadPool::ThreadPool(std::uint32_t numThreads) { } ThreadPool::~ThreadPool() { - terminate_ = true; - jobQueueCV_.notify_all(); - - for (auto& t : workers_) - t.join(); + terminate(); } void ThreadPool::init(std::uint32_t numThreads) { @@ -29,12 +25,36 @@ void ThreadPool::init(std::uint32_t numThreads) { numCPU = sysconf(_SC_NPROCESSORS_ONLN); #endif spdlog::info("Auto-detected cpu count: {}", numCPU); - spdlog::info("Set ThreadPool Worker count to: {}", numCPU); + if (numCPU == 1 || numCPU == 2) { + numCPU = 4; + spdlog::info("Set ThreadPool Worker count to: {} due to program to oprate concurrently", numCPU); + } + else { + spdlog::info("Set ThreadPool Worker count to: {}", numCPU); + } } + threadCount = numCPU; workers_.reserve(numCPU); while (numCPU--) - workers_.push_back([this]() { this->Worker(); }); + workers_.push_back([this]() { + this->Worker(); + }); +} + +void ThreadPool::terminate() { + terminate_ = true; + jobQueueCV_.notify_all(); + + spdlog::debug("waiting for threads to end their jobs..."); + for (auto& t : workers_) + t.join(); +} + +void ThreadPool::respawnWorker(std::uint32_t numThreads) { + terminate(); + terminate_ = false; + init(numThreads); } void* ThreadPool::Worker() { @@ -43,18 +63,17 @@ void* ThreadPool::Worker() { #elif __linux__ pthread_t pid = pthread_self(); #endif - spdlog::info("ThreadPool Worker : {} up.", pid); while (!terminate_) { std::unique_lock lock(jobQueueMutex); jobQueueCV_.wait(lock, [this]() { return !this->jobs_.empty() || terminate_; }); if (this->jobs_.empty()) - return nullptr; + break; auto job = std::move(jobs_.front()); jobs_.pop(); lock.unlock(); - spdlog::info("ThreadPool Worker : {} Executing a job", pid); + spdlog::debug("ThreadPool Worker : {} Executing a job", pid); job(); } diff --git a/include/Socket/IOCP.hpp b/include/Socket/IOCP.hpp index 40fce48..79a4ccc 100644 --- a/include/Socket/IOCP.hpp +++ b/include/Socket/IOCP.hpp @@ -1,24 +1,81 @@ #pragma once #include "Utils/ThreadPool.hpp" +#include "Socket/WSAManager.hpp" +#include "Socket/TCPSocket.hpp" +#include "Socket/Log.hpp" +#include namespace Chattr { + +struct IOCPPASSINDATA { + OVERLAPPED overlapped; + TCPSocket socket; + char buf[1501]; + int recvbytes; + int sendbytes; + WSABUF wsabuf; +}; + class IOCP { public: - IOCP(std::shared_ptr __IOCPThread); - ~IOCP(); + static void iocpWather(ThreadPool* threadPool, HANDLE completionPort_, std::function callback) { + DWORD tid = GetCurrentThreadId(); + spdlog::debug("Waiting IO to complete on TID: {}.", tid); + IOCPPASSINDATA* data; + SOCKET sock; + DWORD cbTransfrred; + int retVal = GetQueuedCompletionStatus(completionPort_, &cbTransfrred, (PULONG_PTR)&sock, (LPOVERLAPPED*)&data, INFINITE); + if (retVal == 0 || cbTransfrred == 0) { + spdlog::info("Client disconnected. [{}]", (std::string)(data->socket.remoteAddr)); + threadPool->enqueueJob(iocpWather, completionPort_, callback); + return; + } + threadPool->enqueueJob(callback, data); + threadPool->enqueueJob(iocpWather, completionPort_, callback); + }; - void init(std::shared_ptr __IOCPThread); + template + void init(ThreadPool* __IOCPThread, _Callable&& callback) { + IOCPThread_ = __IOCPThread; +#ifdef _WIN32 + completionPort_ = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); + if (completionPort_ == NULL) + log::critical("CreateIoCompletionPort()"); + + auto boundFunc = [callback = std::move(callback)](ThreadPool* __IOCPThread, IOCPPASSINDATA* data) mutable { + callback(__IOCPThread, data); + }; + + int tCount = __IOCPThread->threadCount; + + spdlog::info("Resizing threadpool size to: {}", tCount * 2); + + __IOCPThread->respawnWorker(tCount * 2); + + spdlog::info("Set IOCP Worker count to: {}", tCount); + for (int i = 0; i < tCount; i++) { + std::function task(boundFunc); + __IOCPThread->enqueueJob(iocpWather, completionPort_, task); + } +#elif __linux__ + +#endif + } + + void registerSocket(SOCKET sock); int recv(void* __restrict __buf, size_t __n, int __flags); int send(const void* __buf, size_t __n, int __flags); private: - std::shared_ptr IOCPThread_; + struct Chattr::WSAManager wsaManager; + ThreadPool* IOCPThread_; #ifdef _WIN32 - HANDLE completinPort_; + HANDLE completionPort_ = INVALID_HANDLE_VALUE; #elif __linux__ #endif }; + } \ No newline at end of file diff --git a/include/Socket/Socket.hpp b/include/Socket/Socket.hpp index 3f24e36..edf869d 100644 --- a/include/Socket/Socket.hpp +++ b/include/Socket/Socket.hpp @@ -23,17 +23,17 @@ public: int sendto(const void *__buf, size_t __n, int __flags, struct Address __addr); Socket(const Socket&) = delete; - Socket(Socket&&); + Socket(Socket&&) noexcept; Socket& operator=(const Socket&) = delete; - Socket& operator=(Socket&&); + Socket& operator=(Socket&&) noexcept; struct Address bindAddr = {}; struct Address remoteAddr = {}; int domain = 0; + SOCKET sock = INVALID_SOCKET; protected: bool valid_ = false; - SOCKET sock_ = INVALID_SOCKET; }; } \ No newline at end of file diff --git a/include/Utils/Thread.hpp b/include/Utils/Thread.hpp index 13c3711..94e5ae9 100644 --- a/include/Utils/Thread.hpp +++ b/include/Utils/Thread.hpp @@ -39,7 +39,7 @@ public: requires (!std::is_same_v, Thread>) && (!std::is_void_v>) Thread(_Callable&& __f, _Args&&... __args) { - auto boundFunc = [this, __f = std::move(__f), ... __args = std::move(__args)]() mutable { + auto boundFunc = [this, __f, ... __args = std::move(__args)]() mutable { returnValuePtr = new std::invoke_result_t<_Callable, _Args...>(__f(std::move(__args)...)); }; std::packaged_task* funcPtr = new std::packaged_task(std::move(boundFunc)); @@ -53,7 +53,7 @@ public: requires (!std::is_same_v, Thread>) && std::is_void_v> Thread(_Callable&& __f, _Args&&... __args) { - auto boundFunc = [this, __f = std::move(__f), ... __args = std::move(__args)]() mutable { + auto boundFunc = [this, __f, ... __args = std::move(__args)]() mutable { __f(std::move(__args)...); }; std::packaged_task* funcPtr = new std::packaged_task(std::move(boundFunc)); diff --git a/include/Utils/ThreadPool.hpp b/include/Utils/ThreadPool.hpp index b903e03..710d45e 100644 --- a/include/Utils/ThreadPool.hpp +++ b/include/Utils/ThreadPool.hpp @@ -16,6 +16,9 @@ public: ~ThreadPool(); void init(std::uint32_t numThreads); + void terminate(); + + void respawnWorker(std::uint32_t numThreads); template requires (!std::is_void_v>) @@ -26,7 +29,7 @@ public: } std::lock_guard lock(jobQueueMutex); - auto boundFunc = [this, &retVal, __job = std::move(__job), ... __args = std::move(__args)]() mutable { + auto boundFunc = [this, &retVal, __job, ... __args = std::move(__args)]() mutable { retVal = __job(this, std::move(__args)...); }; auto task = std::packaged_task(std::move(boundFunc)); @@ -44,7 +47,7 @@ public: } std::lock_guard lock(jobQueueMutex); - auto boundFunc = [this, __job = std::move(__job), ... __args = std::move(__args)]() mutable { + auto boundFunc = [this, __job, ... __args = std::move(__args)]() mutable { __job(this, std::move(__args)...); }; auto task = std::packaged_task(std::move(boundFunc)); @@ -54,6 +57,7 @@ public: return 0; } + int threadCount = 0; private: void* Worker(); diff --git a/include/precomp.hpp b/include/precomp.hpp index 005a5aa..09deefa 100644 --- a/include/precomp.hpp +++ b/include/precomp.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #define in_addr_t ULONG #elif __linux__