1717import functools
1818import threading
1919import time
20- from typing import Sequence
20+ from typing import Optional , Sequence
2121from absl import logging
2222import jax
2323from orbax .checkpoint ._src .multihost import multihost
2424from typing_extensions import Protocol
2525
2626
27+ _FORCE_THREADSAFE_SIGNALING_CLIENT = False
28+ _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT = 0
29+ _OVERRIDE_LOCK = threading .RLock ()
30+
31+
32+ def set_force_threadsafe_signaling_client (force : bool ) -> None :
33+ """Overrides signaling-client selection for the current process."""
34+ global _FORCE_THREADSAFE_SIGNALING_CLIENT
35+ with _OVERRIDE_LOCK :
36+ _FORCE_THREADSAFE_SIGNALING_CLIENT = force
37+ get_signaling_client .cache_clear ()
38+
39+
40+ def acquire_threadsafe_signaling_client_override () -> None :
41+ """Enables the threadsafe signaling override for the current process."""
42+ global _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT
43+ with _OVERRIDE_LOCK :
44+ _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT += 1
45+ set_force_threadsafe_signaling_client (True )
46+
47+
48+ def release_threadsafe_signaling_client_override () -> None :
49+ """Releases one threadsafe signaling override reference for this process."""
50+ global _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT
51+ with _OVERRIDE_LOCK :
52+ _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT = max (
53+ 0 , _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT - 1
54+ )
55+ if _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT == 0 :
56+ set_force_threadsafe_signaling_client (False )
57+
58+
2759class SignalingClient (Protocol ):
2860 """Client that supports signaling between threads and processes."""
2961
@@ -50,7 +82,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
5082 """
5183 ...
5284
53- def key_value_try_get (self , key : str ) -> str | None :
85+ def key_value_try_get (self , key : str ) -> Optional [ str ] :
5486 """Tries to get the value for a given key in the client without blocking.
5587
5688 Args:
@@ -82,7 +114,7 @@ def wait_at_barrier(
82114 key : str ,
83115 * ,
84116 timeout_secs : int ,
85- process_ids : Sequence [int ] | None = None ,
117+ process_ids : Optional [ Sequence [int ]] = None ,
86118 ):
87119 """Waits at a barrier identified by key.
88120
@@ -136,7 +168,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
136168 """
137169 return str (self ._client .blocking_key_value_get (key , timeout_secs * 1000 ))
138170
139- def key_value_try_get (self , key : str ) -> str | None :
171+ def key_value_try_get (self , key : str ) -> Optional [ str ] :
140172 """Tries to get the value for a given key in the client without blocking.
141173
142174 Args:
@@ -187,7 +219,7 @@ def wait_at_barrier(
187219 key : str ,
188220 * ,
189221 timeout_secs : int ,
190- process_ids : Sequence [int ] | None = None ,
222+ process_ids : Optional [ Sequence [int ]] = None ,
191223 ):
192224 """Waits at a barrier identified by key.
193225
@@ -284,7 +316,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
284316
285317 return self ._data [key ]
286318
287- def key_value_try_get (self , key : str ) -> str | None :
319+ def key_value_try_get (self , key : str ) -> Optional [ str ] :
288320 """Tries to get the value for a key without blocking.
289321
290322 Args:
@@ -342,7 +374,7 @@ def wait_at_barrier(
342374 key : str ,
343375 * ,
344376 timeout_secs : int ,
345- process_ids : Sequence [int ] | None = None ,
377+ process_ids : Optional [ Sequence [int ]] = None ,
346378 ):
347379 """Waits at a barrier identified by key.
348380
@@ -363,12 +395,24 @@ def get_signaling_client() -> SignalingClient:
363395 if multihost .is_jax_distributed_client_initialized ():
364396 logging .info ("Using JaxDistributedSignalingClient" )
365397 return JaxDistributedSignalingClient ()
366- else :
367- process_count = multihost .process_count ()
368- if process_count > 1 :
369- raise RuntimeError (
370- "ThreadSafeKeyValueSignalingClient should only be used in a single"
371- f" controller setup, process count: { process_count } ."
372- )
398+
399+ process_count = multihost .process_count ()
400+ if process_count == 1 :
373401 logging .info ("Using ThreadSafeKeyValueSignalingClient" )
374402 return ThreadSafeKeyValueSignalingClient ()
403+
404+ # Pathways sidecars run with process_count > 1 but do not initialize the JAX
405+ # distributed client. Allow the in-process client only when that sidecar path
406+ # explicitly opts in via `set_force_threadsafe_signaling_client(True)`.
407+ if _FORCE_THREADSAFE_SIGNALING_CLIENT :
408+ logging .warning (
409+ "Using ThreadSafeKeyValueSignalingClient with process_count=%s because"
410+ " force override is enabled and JAX distributed client is unavailable." ,
411+ process_count ,
412+ )
413+ return ThreadSafeKeyValueSignalingClient ()
414+
415+ raise RuntimeError (
416+ "ThreadSafeKeyValueSignalingClient should only be used in a single"
417+ f" controller setup, process count: { process_count } ."
418+ )
0 commit comments