4343#endif
4444
4545#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
46- defined (USE_MACA) || defined(USE_UBSHMEM)
46+ defined (USE_MACA) || defined(USE_UBSHMEM) || defined(USE_SUNRISE)
4747#include < cassert>
4848
4949#if defined(USE_MNNVL) || defined(USE_UBSHMEM)
@@ -88,7 +88,8 @@ DEFINE_string(mode, "initiator",
8888DEFINE_string (operation, " read" , " Operation type: read or write" );
8989
9090DEFINE_string (protocol, " rdma" ,
91- " Transfer protocol: rdma|barex|tcp|efa|nvlink|nvlink_intra|hip" );
91+ " Transfer protocol: "
92+ " rdma|barex|tcp|efa|nvlink|nvlink_intra|hip|sunrise_link" );
9293
9394DEFINE_string (device_name, " mlx5_2" ,
9495 " Device name to use, valid if protocol=rdma" );
@@ -107,7 +108,7 @@ DEFINE_uint32(report_precision, 2, "Report precision");
107108DEFINE_string (backend, " classic" , " Backend to use: classic|tent" );
108109
109110#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
110- defined (USE_MACA) || defined(USE_UBSHMEM)
111+ defined (USE_MACA) || defined(USE_UBSHMEM) || defined(USE_SUNRISE)
111112DEFINE_bool(use_vram, true , " Allocate memory from GPU/NPU VRAM" );
112113DEFINE_bool (init_mem, true , " Initialize allocated memory" );
113114DEFINE_int32 (gpu_id, 0 ,
@@ -119,7 +120,7 @@ using namespace mooncake;
119120static void *allocateMemoryPool (size_t size, int buffer_id,
120121 bool from_vram = false ) {
121122#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
122- defined (USE_MACA) || defined (USE_UBSHMEM)
123+ defined (USE_MACA) || defined (USE_UBSHMEM) || defined (USE_SUNRISE)
123124 if (from_vram) {
124125 int gpu_id;
125126 if (FLAGS_gpu_id == -1 ) {
@@ -190,7 +191,7 @@ static void *allocateMemoryPool(size_t size, int buffer_id,
190191
191192static void freeMemoryPool (void *addr, size_t size) {
192193#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
193- defined (USE_MACA) || defined (USE_UBSHMEM)
194+ defined (USE_MACA) || defined (USE_UBSHMEM) || defined (USE_SUNRISE)
194195 if (FLAGS_protocol == " nvlink" || FLAGS_protocol == " hip" ) {
195196#ifdef USE_MNNVL
196197 if (FLAGS_use_vram) {
@@ -214,6 +215,10 @@ static void freeMemoryPool(void *addr, size_t size) {
214215#endif
215216 } else {
216217#ifndef USE_UBSHMEM
218+ if (!FLAGS_use_vram) {
219+ numa_free (addr, size);
220+ return ;
221+ }
217222 // check pointer on GPU
218223 cudaPointerAttributes attributes;
219224 checkCudaError (cudaPointerGetAttributes (&attributes, addr),
@@ -271,10 +276,21 @@ static inline std::string calculateRate(uint64_t data_bytes, double duration) {
271276volatile bool running = true ;
272277std::atomic<size_t > total_batch_count (0 );
273278
279+ // Ensure each worker thread has a valid GPU context before issuing transfers.
280+ static inline void setWorkerDeviceIfNeeded () {
281+ #if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
282+ defined (USE_MACA) || defined (USE_SUNRISE)
283+ if (FLAGS_use_vram && FLAGS_gpu_id >= 0 ) {
284+ checkCudaError (cudaSetDevice (FLAGS_gpu_id),
285+ " Failed to set device in worker" );
286+ }
287+ #endif
288+ }
289+
274290// Common helper to determine buffer count based on GPU/NUMA configuration
275291static int determineBufferCount () {
276292#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
277- defined (USE_MACA)
293+ defined (USE_MACA) || defined (USE_SUNRISE)
278294 if (FLAGS_use_vram) {
279295 int gpu_num;
280296 LOG (INFO) << " VRAM is used" ;
@@ -305,7 +321,7 @@ static std::vector<void *> allocateBuffers() {
305321 buffer_num = determineBufferCount ();
306322 std::vector<void *> addr (buffer_num);
307323#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
308- defined (USE_MACA) || defined (USE_UBSHMEM)
324+ defined (USE_MACA) || defined (USE_UBSHMEM) || defined (USE_SUNRISE)
309325 for (int i = 0 ; i < buffer_num; ++i) {
310326 addr[i] = allocateMemoryPool (FLAGS_buffer_size, i, FLAGS_use_vram);
311327 }
@@ -328,7 +344,7 @@ static void freeBuffers(std::vector<void *> &addr) {
328344// Helper to get location name for classic backend
329345static std::string getLocationName (int buffer_id) {
330346#if defined(USE_CUDA) || defined(USE_MUSA) || defined(USE_HIP) || \
331- defined (USE_MACA) || defined (USE_UBSHMEM)
347+ defined (USE_MACA) || defined (USE_UBSHMEM) || defined (USE_SUNRISE)
332348 if (FLAGS_use_vram) {
333349 int name_suffix = (FLAGS_gpu_id == -1 ) ? buffer_id : FLAGS_gpu_id;
334350 return std::string (GPU_PREFIX) + std::to_string (name_suffix);
@@ -476,7 +492,8 @@ static Transport *installTransportFromFlags(TransferEngine *engine) {
476492 xport = engine->installTransport (" efa" , nullptr );
477493 } else if (FLAGS_protocol == " tcp" || FLAGS_protocol == " nvlink" ||
478494 FLAGS_protocol == " hip" || FLAGS_protocol == " nvlink_intra" ||
479- FLAGS_protocol == " ubshmem" ) {
495+ FLAGS_protocol == " ubshmem" ||
496+ FLAGS_protocol == " sunrise_link" ) {
480497 xport = engine->installTransport (FLAGS_protocol.c_str (), nullptr );
481498 } else {
482499 LOG (ERROR) << " Unsupported protocol: " << FLAGS_protocol;
@@ -644,6 +661,7 @@ void initiatorWorker(mooncake::tent::TransferEngine *engine,
644661 void *addr,
645662 const mooncake::tent::SegmentInfo &segment_info) {
646663 bindToSocket (thread_id % NR_SOCKETS);
664+ setWorkerDeviceIfNeeded ();
647665 mooncake::tent::Request::OpCode opcode;
648666 if (FLAGS_operation == " read" )
649667 opcode = mooncake::tent::Request::READ;
0 commit comments