Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion catkit2/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ typedef std::function<std::string(py::bytes)> PythonRequestHandler;
PYBIND11_MODULE(catkit_bindings, m)
{
py::class_<Server>(m, "Server")
.def(py::init<int>())
.def(py::init<int, int>(),
py::arg("port"),
py::arg("num_workers") = 1)
.def("register_request_handler", [](Server &server, std::string type, PythonRequestHandler request_handler)
{
server.RegisterRequestHandler(type, [request_handler](const std::string &data)
Expand Down
261 changes: 171 additions & 90 deletions catkit_core/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
using namespace std;
using namespace zmq;

Server::Server(int port)
: m_Port(port), m_IsRunning(false), m_ShouldShutDown(false)
//#define DEBUG_PRINT(msg) std::cerr << "[DEBUG] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl
//#define ERROR_PRINT(msg) std::cerr << "[ERROR] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl
#define DEBUG_PRINT(msg)
#define ERROR_PRINT(msg)

Server::Server(int port, int num_workers)
: m_Port(port), m_NumWorkers(num_workers), m_IsRunning(false), m_ShouldShutDown(false)
{
}

Expand All @@ -41,124 +46,200 @@ void Server::Start()

m_IsRunning = true;

m_RunThread = thread(&Server::RunInternal, this);
// Initialize ZMQ context and socket
m_Context = std::make_unique<zmq::context_t>();
m_Socket = std::make_unique<zmq::socket_t>(*m_Context, ZMQ_ROUTER);
m_Socket->bind("tcp://*:"s + std::to_string(m_Port));
m_Socket->set(zmq::sockopt::rcvtimeo, 20);
m_Socket->set(zmq::sockopt::linger, 0);

LOG_INFO("Starting server on port "s + to_string(m_Port) + " with " + to_string(m_NumWorkers) + " worker(s).");

// Start receive thread
m_ReceiveThread = thread(&Server::ReceiveLoop, this);

// Start worker threads
m_WorkerThreads.reserve(m_NumWorkers);
for (int i = 0; i < m_NumWorkers; i++) {
m_WorkerThreads.emplace_back(&Server::WorkerLoop, this, i);
}
}

void Server::Stop()
{
m_ShouldShutDown = true;

if (m_RunThread.joinable())
m_RunThread.join();
// Wake up all waiting workers
{
std::lock_guard<std::mutex> lock(m_QueueMutex);
m_QueueCV.notify_all();
}

// Join receive thread
if (m_ReceiveThread.joinable())
m_ReceiveThread.join();

// Join all worker threads
for (auto& worker : m_WorkerThreads) {
if (worker.joinable())
worker.join();
}

// Clean up ZMQ
if (m_Socket) {
m_Socket->close();
m_Socket.reset();
}
if (m_Context) {
m_Context.reset();
}

CleanupRequestHandlers();

m_IsRunning = false;
LOG_INFO("Server has shut down.");
}

void Server::CleanupRequestHandlers()
{
m_RequestHandlers.clear();
}

void Server::RunInternal()
void Server::ReceiveLoop()
{
LOG_INFO("Starting server on port "s + to_string(m_Port) + ".");

zmq::context_t context;

zmq::socket_t socket(context, ZMQ_ROUTER);
socket.bind("tcp://*:"s + std::to_string(m_Port));
socket.set(zmq::sockopt::rcvtimeo, 20);
socket.set(zmq::sockopt::linger, 0);
LOG_DEBUG("Receive loop started.");

while (!m_ShouldShutDown)
{
zmq::multipart_t request_msg;
auto res = zmq::recv_multipart(*m_Socket, std::back_inserter(request_msg));

if (!res.has_value())
{
// Server has received no message (timeout).
continue;
}

if (request_msg.size() != 5)
{
LOG_ERROR("The server has received a message with "s + std::to_string(request_msg.size()) + " frames instead of five. Ignoring.");
continue;
}

PendingRequest req;
req.client_identity = request_msg.popstr();
req.request_id = request_msg.popstr();
std::string empty = request_msg.popstr(); // Empty delimiter frame
req.request_type = request_msg.popstr();
req.request_data = request_msg.popstr();

DEBUG_PRINT("received: type=" << req.request_type << " client=" << req.client_identity);
LOG_DEBUG("Request received: "s + req.request_type);

// Enqueue for workers
{
std::lock_guard<std::mutex> lock(m_QueueMutex);
m_RequestQueue.push(std::move(req));
}
m_QueueCV.notify_one();
}

LOG_DEBUG("Receive loop ended.");
}

Finally finally([this, &socket]()
{
socket.close();
void Server::WorkerLoop(int worker_id)
{
LOG_DEBUG("Worker "s + to_string(worker_id) + " started.");

while (!m_ShouldShutDown)
{
PendingRequest req;

// Dequeue (blocking with timeout to check shutdown periodically)
{
std::unique_lock<std::mutex> lock(m_QueueMutex);
bool has_request = m_QueueCV.wait_for(lock, std::chrono::milliseconds(100), [this] {
return !m_RequestQueue.empty() || m_ShouldShutDown.load();
});

if (!has_request || m_ShouldShutDown)
continue;

req = std::move(m_RequestQueue.front());
m_RequestQueue.pop();
DEBUG_PRINT("Worker " << worker_id << " dequeued request: type=" << req.request_type);
}

// Process request (this can take a long time, but doesn't block other workers)
string reply_data;
string reply_type = "OK";

auto handler = m_RequestHandlers.find(req.request_type);

if (handler == m_RequestHandlers.end())
{
LOG_ERROR("An unknown request type was received: "s + req.request_type + ".");
reply_type = "ERROR";
reply_data = "Unknown request type";
}
else
{
DEBUG_PRINT("Worker " << worker_id << " calling handler for: " << req.request_type);
try
{
// Move request_data to handler to avoid copy (handler takes const& but we don't need it after)
reply_data = handler->second(std::move(req.request_data));
DEBUG_PRINT("Worker " << worker_id << " handler completed for: " << req.request_type);
}
catch (std::exception &e)
{
ERROR_PRINT("Worker " << worker_id << " exception in handler: " << e.what());
LOG_ERROR("Encountered error during handling of request: "s + e.what());
reply_type = "ERROR";
reply_data = e.what();
}
}

// Send reply (move reply_data since we don't need it after)
SendResponse(req.client_identity, req.request_id, reply_type, std::move(reply_data));

LOG_DEBUG("Worker "s + to_string(worker_id) + " sent reply: " + reply_type);
}

LOG_DEBUG("Worker "s + to_string(worker_id) + " ended.");
}

this->m_ShouldShutDown = true;
this->m_IsRunning = false;
void Server::SendResponse(const std::string& client_identity, const std::string& request_id,
const std::string& reply_type, std::string reply_data)
{
multipart_t msg;

LOG_INFO("Server has shut down.");
});
msg.addstr(client_identity);
msg.addstr(request_id);
msg.addstr("");
msg.addstr(reply_type);
msg.addstr(std::move(reply_data)); // Move into ZMQ message

while (!m_ShouldShutDown)
{
zmq::multipart_t request_msg;
auto res = zmq::recv_multipart(socket, std::back_inserter(request_msg));

if (!res.has_value())
{
// Server has received no message.
continue;
}

if (request_msg.size() != 5)
{
// Each message should have five frames: request_id, identity, empty, type and data.
LOG_ERROR("The server has received a message with "s + std::to_string(request_msg.size()) + " frames instead of five. Ignoring.");
continue;
}

std::string client_identity = request_msg.popstr();
std::string request_id = request_msg.popstr();
std::string empty = request_msg.popstr();
std::string request_type = request_msg.popstr();
std::string request_data = request_msg.popstr();

LOG_DEBUG("Request received: "s + request_type);

// Call the request handler and return the result if no error occurred.
string reply_data;
string reply_type = "OK";

// Find the correct request handler.
auto handler = m_RequestHandlers.find(request_type);

if (handler == m_RequestHandlers.end())
{
LOG_ERROR("An unknown request type was received: "s + request_type + ".");
reply_type = "ERROR";
reply_data = "Unknown request type";
}
else
{
try
{
reply_data = handler->second(request_data);
}
catch (std::exception &e)
{
LOG_ERROR("Encountered error during handling of request: "s + e.what());

reply_type = "ERROR";
reply_data = e.what();
}
}

// Send reply to the client.
multipart_t msg;

msg.addstr(client_identity);
msg.addstr(request_id);
msg.addstr("");
msg.addstr(reply_type);
msg.addstr(reply_data);

msg.send(socket);

LOG_DEBUG("Sent reply: "s + reply_type);
}
// ZMQ sockets are not thread-safe - must protect with mutex
std::lock_guard<std::mutex> lock(m_SocketMutex);
msg.send(*m_Socket);
}

bool Server::IsRunning()
bool Server::IsRunning() const
{
return m_IsRunning;
}

int Server::GetPort()
int Server::GetPort() const
{
return m_Port;
}

int Server::GetNumWorkers() const
{
return m_NumWorkers;
}

void Server::Sleep(double sleep_time_in_sec, void (*error_check)())
{
::Sleep(sleep_time_in_sec, [this, error_check]() -> bool
Expand Down
45 changes: 40 additions & 5 deletions catkit_core/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,29 @@
#include <functional>
#include <map>
#include <thread>
#include <queue>
#include <mutex>
#include <condition_variable>
#include <vector>
#include <memory>

// Forward declaration for ZMQ
namespace zmq {
class socket_t;
class context_t;
}

struct PendingRequest {
std::string client_identity;
std::string request_id;
std::string request_type;
std::string request_data;
};

class Server
{
public:
Server(int port);
Server(int port, int num_workers = 1);
virtual ~Server();

typedef std::function<std::string(const std::string&)> RequestHandler;
Expand All @@ -20,21 +38,38 @@ class Server
void Start();
void Stop();

bool IsRunning();
bool IsRunning() const;

int GetPort();
int GetPort() const;
int GetNumWorkers() const;

void Sleep(double sleep_time_in_sec, void (*error_check)()=nullptr);

void CleanupRequestHandlers();

protected:
int m_Port;
int m_NumWorkers;

private:
void RunInternal();
void ReceiveLoop();
void WorkerLoop(int worker_id);
void SendResponse(const std::string& client_identity, const std::string& request_id,
const std::string& reply_type, std::string reply_data);

// Thread management
std::thread m_ReceiveThread;
std::vector<std::thread> m_WorkerThreads;

// Thread pool queue
std::queue<PendingRequest> m_RequestQueue;
std::mutex m_QueueMutex;
std::condition_variable m_QueueCV;

std::thread m_RunThread;
// ZMQ context and socket (owned by Server)
std::unique_ptr<zmq::context_t> m_Context;
std::unique_ptr<zmq::socket_t> m_Socket;
std::mutex m_SocketMutex; // Protects socket operations (ZMQ sockets are not thread-safe)

std::map<std::string, RequestHandler> m_RequestHandlers;

Expand Down
Loading