diff --git a/catkit2/bindings.cpp b/catkit2/bindings.cpp index 46c014043..1d357ac9c 100644 --- a/catkit2/bindings.cpp +++ b/catkit2/bindings.cpp @@ -514,7 +514,9 @@ typedef std::function PythonRequestHandler; PYBIND11_MODULE(catkit_bindings, m) { py::class_(m, "Server") - .def(py::init()) + .def(py::init(), + py::arg("port"), + py::arg("num_workers") = 1) .def("register_request_handler", [](Server &server, std::string type, PythonRequestHandler request_handler) { server.RegisterRequestHandler(type, [request_handler](const std::string &data) diff --git a/catkit_core/Server.cpp b/catkit_core/Server.cpp index 001321d2e..51e98464f 100644 --- a/catkit_core/Server.cpp +++ b/catkit_core/Server.cpp @@ -16,8 +16,13 @@ using namespace std; using namespace zmq; -Server::Server(int port) - : m_Port(port), m_IsRunning(false), m_ShouldShutDown(false) +//#define DEBUG_PRINT(msg) std::cerr << "[DEBUG] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl +//#define ERROR_PRINT(msg) std::cerr << "[ERROR] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl +#define DEBUG_PRINT(msg) +#define ERROR_PRINT(msg) + +Server::Server(int port, int num_workers) + : m_Port(port), m_NumWorkers(num_workers), m_IsRunning(false), m_ShouldShutDown(false) { } @@ -41,17 +46,58 @@ void Server::Start() m_IsRunning = true; - m_RunThread = thread(&Server::RunInternal, this); + // Initialize ZMQ context and socket + m_Context = std::make_unique(); + m_Socket = std::make_unique(*m_Context, ZMQ_ROUTER); + m_Socket->bind("tcp://*:"s + std::to_string(m_Port)); + m_Socket->set(zmq::sockopt::rcvtimeo, 20); + m_Socket->set(zmq::sockopt::linger, 0); + + LOG_INFO("Starting server on port "s + to_string(m_Port) + " with " + to_string(m_NumWorkers) + " worker(s)."); + + // Start receive thread + m_ReceiveThread = thread(&Server::ReceiveLoop, this); + + // Start worker threads + m_WorkerThreads.reserve(m_NumWorkers); + for (int i = 0; i < m_NumWorkers; i++) { + m_WorkerThreads.emplace_back(&Server::WorkerLoop, this, i); + } } void Server::Stop() { m_ShouldShutDown = true; - if (m_RunThread.joinable()) - m_RunThread.join(); + // Wake up all waiting workers + { + std::lock_guard lock(m_QueueMutex); + m_QueueCV.notify_all(); + } + + // Join receive thread + if (m_ReceiveThread.joinable()) + m_ReceiveThread.join(); + + // Join all worker threads + for (auto& worker : m_WorkerThreads) { + if (worker.joinable()) + worker.join(); + } + + // Clean up ZMQ + if (m_Socket) { + m_Socket->close(); + m_Socket.reset(); + } + if (m_Context) { + m_Context.reset(); + } CleanupRequestHandlers(); + + m_IsRunning = false; + LOG_INFO("Server has shut down."); } void Server::CleanupRequestHandlers() @@ -59,106 +105,141 @@ void Server::CleanupRequestHandlers() m_RequestHandlers.clear(); } -void Server::RunInternal() +void Server::ReceiveLoop() { - LOG_INFO("Starting server on port "s + to_string(m_Port) + "."); - - zmq::context_t context; - - zmq::socket_t socket(context, ZMQ_ROUTER); - socket.bind("tcp://*:"s + std::to_string(m_Port)); - socket.set(zmq::sockopt::rcvtimeo, 20); - socket.set(zmq::sockopt::linger, 0); + LOG_DEBUG("Receive loop started."); + + while (!m_ShouldShutDown) + { + zmq::multipart_t request_msg; + auto res = zmq::recv_multipart(*m_Socket, std::back_inserter(request_msg)); + + if (!res.has_value()) + { + // Server has received no message (timeout). + continue; + } + + if (request_msg.size() != 5) + { + LOG_ERROR("The server has received a message with "s + std::to_string(request_msg.size()) + " frames instead of five. Ignoring."); + continue; + } + + PendingRequest req; + req.client_identity = request_msg.popstr(); + req.request_id = request_msg.popstr(); + std::string empty = request_msg.popstr(); // Empty delimiter frame + req.request_type = request_msg.popstr(); + req.request_data = request_msg.popstr(); + + DEBUG_PRINT("received: type=" << req.request_type << " client=" << req.client_identity); + LOG_DEBUG("Request received: "s + req.request_type); + + // Enqueue for workers + { + std::lock_guard lock(m_QueueMutex); + m_RequestQueue.push(std::move(req)); + } + m_QueueCV.notify_one(); + } + + LOG_DEBUG("Receive loop ended."); +} - Finally finally([this, &socket]() - { - socket.close(); +void Server::WorkerLoop(int worker_id) +{ + LOG_DEBUG("Worker "s + to_string(worker_id) + " started."); + + while (!m_ShouldShutDown) + { + PendingRequest req; + + // Dequeue (blocking with timeout to check shutdown periodically) + { + std::unique_lock lock(m_QueueMutex); + bool has_request = m_QueueCV.wait_for(lock, std::chrono::milliseconds(100), [this] { + return !m_RequestQueue.empty() || m_ShouldShutDown.load(); + }); + + if (!has_request || m_ShouldShutDown) + continue; + + req = std::move(m_RequestQueue.front()); + m_RequestQueue.pop(); + DEBUG_PRINT("Worker " << worker_id << " dequeued request: type=" << req.request_type); + } + + // Process request (this can take a long time, but doesn't block other workers) + string reply_data; + string reply_type = "OK"; + + auto handler = m_RequestHandlers.find(req.request_type); + + if (handler == m_RequestHandlers.end()) + { + LOG_ERROR("An unknown request type was received: "s + req.request_type + "."); + reply_type = "ERROR"; + reply_data = "Unknown request type"; + } + else + { + DEBUG_PRINT("Worker " << worker_id << " calling handler for: " << req.request_type); + try + { + // Move request_data to handler to avoid copy (handler takes const& but we don't need it after) + reply_data = handler->second(std::move(req.request_data)); + DEBUG_PRINT("Worker " << worker_id << " handler completed for: " << req.request_type); + } + catch (std::exception &e) + { + ERROR_PRINT("Worker " << worker_id << " exception in handler: " << e.what()); + LOG_ERROR("Encountered error during handling of request: "s + e.what()); + reply_type = "ERROR"; + reply_data = e.what(); + } + } + + // Send reply (move reply_data since we don't need it after) + SendResponse(req.client_identity, req.request_id, reply_type, std::move(reply_data)); + + LOG_DEBUG("Worker "s + to_string(worker_id) + " sent reply: " + reply_type); + } + + LOG_DEBUG("Worker "s + to_string(worker_id) + " ended."); +} - this->m_ShouldShutDown = true; - this->m_IsRunning = false; +void Server::SendResponse(const std::string& client_identity, const std::string& request_id, + const std::string& reply_type, std::string reply_data) +{ + multipart_t msg; - LOG_INFO("Server has shut down."); - }); + msg.addstr(client_identity); + msg.addstr(request_id); + msg.addstr(""); + msg.addstr(reply_type); + msg.addstr(std::move(reply_data)); // Move into ZMQ message - while (!m_ShouldShutDown) - { - zmq::multipart_t request_msg; - auto res = zmq::recv_multipart(socket, std::back_inserter(request_msg)); - - if (!res.has_value()) - { - // Server has received no message. - continue; - } - - if (request_msg.size() != 5) - { - // Each message should have five frames: request_id, identity, empty, type and data. - LOG_ERROR("The server has received a message with "s + std::to_string(request_msg.size()) + " frames instead of five. Ignoring."); - continue; - } - - std::string client_identity = request_msg.popstr(); - std::string request_id = request_msg.popstr(); - std::string empty = request_msg.popstr(); - std::string request_type = request_msg.popstr(); - std::string request_data = request_msg.popstr(); - - LOG_DEBUG("Request received: "s + request_type); - - // Call the request handler and return the result if no error occurred. - string reply_data; - string reply_type = "OK"; - - // Find the correct request handler. - auto handler = m_RequestHandlers.find(request_type); - - if (handler == m_RequestHandlers.end()) - { - LOG_ERROR("An unknown request type was received: "s + request_type + "."); - reply_type = "ERROR"; - reply_data = "Unknown request type"; - } - else - { - try - { - reply_data = handler->second(request_data); - } - catch (std::exception &e) - { - LOG_ERROR("Encountered error during handling of request: "s + e.what()); - - reply_type = "ERROR"; - reply_data = e.what(); - } - } - - // Send reply to the client. - multipart_t msg; - - msg.addstr(client_identity); - msg.addstr(request_id); - msg.addstr(""); - msg.addstr(reply_type); - msg.addstr(reply_data); - - msg.send(socket); - - LOG_DEBUG("Sent reply: "s + reply_type); - } + // ZMQ sockets are not thread-safe - must protect with mutex + std::lock_guard lock(m_SocketMutex); + msg.send(*m_Socket); } -bool Server::IsRunning() +bool Server::IsRunning() const { return m_IsRunning; } -int Server::GetPort() +int Server::GetPort() const { return m_Port; } +int Server::GetNumWorkers() const +{ + return m_NumWorkers; +} + void Server::Sleep(double sleep_time_in_sec, void (*error_check)()) { ::Sleep(sleep_time_in_sec, [this, error_check]() -> bool diff --git a/catkit_core/Server.h b/catkit_core/Server.h index 73b3024ee..d7c8b3a11 100644 --- a/catkit_core/Server.h +++ b/catkit_core/Server.h @@ -6,11 +6,29 @@ #include #include #include +#include +#include +#include +#include +#include + +// Forward declaration for ZMQ +namespace zmq { + class socket_t; + class context_t; +} + +struct PendingRequest { + std::string client_identity; + std::string request_id; + std::string request_type; + std::string request_data; +}; class Server { public: - Server(int port); + Server(int port, int num_workers = 1); virtual ~Server(); typedef std::function RequestHandler; @@ -20,9 +38,10 @@ class Server void Start(); void Stop(); - bool IsRunning(); + bool IsRunning() const; - int GetPort(); + int GetPort() const; + int GetNumWorkers() const; void Sleep(double sleep_time_in_sec, void (*error_check)()=nullptr); @@ -30,11 +49,27 @@ class Server protected: int m_Port; + int m_NumWorkers; private: - void RunInternal(); + void ReceiveLoop(); + void WorkerLoop(int worker_id); + void SendResponse(const std::string& client_identity, const std::string& request_id, + const std::string& reply_type, std::string reply_data); + + // Thread management + std::thread m_ReceiveThread; + std::vector m_WorkerThreads; + + // Thread pool queue + std::queue m_RequestQueue; + std::mutex m_QueueMutex; + std::condition_variable m_QueueCV; - std::thread m_RunThread; + // ZMQ context and socket (owned by Server) + std::unique_ptr m_Context; + std::unique_ptr m_Socket; + std::mutex m_SocketMutex; // Protects socket operations (ZMQ sockets are not thread-safe) std::map m_RequestHandlers;