Skip to content

Commit 44e522e

Browse files
SujeethJineshOrbax Authors
authored andcommitted
No public description
PiperOrigin-RevId: 889041970
1 parent ccb60c2 commit 44e522e

19 files changed

+2610
-273
lines changed

checkpoint/orbax/checkpoint/_src/futures/signaling_client.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,45 @@
1717
import functools
1818
import threading
1919
import time
20-
from typing import Sequence
20+
from typing import Optional, Sequence
2121
from absl import logging
2222
import jax
2323
from orbax.checkpoint._src.multihost import multihost
2424
from typing_extensions import Protocol
2525

2626

27+
_FORCE_THREADSAFE_SIGNALING_CLIENT = False
28+
_THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT = 0
29+
_OVERRIDE_LOCK = threading.RLock()
30+
31+
32+
def set_force_threadsafe_signaling_client(force: bool) -> None:
33+
"""Overrides signaling-client selection for the current process."""
34+
global _FORCE_THREADSAFE_SIGNALING_CLIENT
35+
with _OVERRIDE_LOCK:
36+
_FORCE_THREADSAFE_SIGNALING_CLIENT = force
37+
get_signaling_client.cache_clear()
38+
39+
40+
def acquire_threadsafe_signaling_client_override() -> None:
41+
"""Enables the threadsafe signaling override for the current process."""
42+
global _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT
43+
with _OVERRIDE_LOCK:
44+
_THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT += 1
45+
set_force_threadsafe_signaling_client(True)
46+
47+
48+
def release_threadsafe_signaling_client_override() -> None:
49+
"""Releases one threadsafe signaling override reference for this process."""
50+
global _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT
51+
with _OVERRIDE_LOCK:
52+
_THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT = max(
53+
0, _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT - 1
54+
)
55+
if _THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT == 0:
56+
set_force_threadsafe_signaling_client(False)
57+
58+
2759
class SignalingClient(Protocol):
2860
"""Client that supports signaling between threads and processes."""
2961

@@ -50,7 +82,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
5082
"""
5183
...
5284

53-
def key_value_try_get(self, key: str) -> str | None:
85+
def key_value_try_get(self, key: str) -> Optional[str]:
5486
"""Tries to get the value for a given key in the client without blocking.
5587
5688
Args:
@@ -82,7 +114,7 @@ def wait_at_barrier(
82114
key: str,
83115
*,
84116
timeout_secs: int,
85-
process_ids: Sequence[int] | None = None,
117+
process_ids: Optional[Sequence[int]] = None,
86118
):
87119
"""Waits at a barrier identified by key.
88120
@@ -136,7 +168,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
136168
"""
137169
return str(self._client.blocking_key_value_get(key, timeout_secs * 1000))
138170

139-
def key_value_try_get(self, key: str) -> str | None:
171+
def key_value_try_get(self, key: str) -> Optional[str]:
140172
"""Tries to get the value for a given key in the client without blocking.
141173
142174
Args:
@@ -187,7 +219,7 @@ def wait_at_barrier(
187219
key: str,
188220
*,
189221
timeout_secs: int,
190-
process_ids: Sequence[int] | None = None,
222+
process_ids: Optional[Sequence[int]] = None,
191223
):
192224
"""Waits at a barrier identified by key.
193225
@@ -284,7 +316,7 @@ def blocking_key_value_get(self, key: str, timeout_secs: int) -> str:
284316

285317
return self._data[key]
286318

287-
def key_value_try_get(self, key: str) -> str | None:
319+
def key_value_try_get(self, key: str) -> Optional[str]:
288320
"""Tries to get the value for a key without blocking.
289321
290322
Args:
@@ -342,7 +374,7 @@ def wait_at_barrier(
342374
key: str,
343375
*,
344376
timeout_secs: int,
345-
process_ids: Sequence[int] | None = None,
377+
process_ids: Optional[Sequence[int]] = None,
346378
):
347379
"""Waits at a barrier identified by key.
348380
@@ -363,12 +395,24 @@ def get_signaling_client() -> SignalingClient:
363395
if multihost.is_jax_distributed_client_initialized():
364396
logging.info("Using JaxDistributedSignalingClient")
365397
return JaxDistributedSignalingClient()
366-
else:
367-
process_count = multihost.process_count()
368-
if process_count > 1:
369-
raise RuntimeError(
370-
"ThreadSafeKeyValueSignalingClient should only be used in a single"
371-
f" controller setup, process count: {process_count}."
372-
)
398+
399+
process_count = multihost.process_count()
400+
if process_count == 1:
373401
logging.info("Using ThreadSafeKeyValueSignalingClient")
374402
return ThreadSafeKeyValueSignalingClient()
403+
404+
# Pathways sidecars run with process_count > 1 but do not initialize the JAX
405+
# distributed client. Allow the in-process client only when that sidecar path
406+
# explicitly opts in via `set_force_threadsafe_signaling_client(True)`.
407+
if _FORCE_THREADSAFE_SIGNALING_CLIENT:
408+
logging.warning(
409+
"Using ThreadSafeKeyValueSignalingClient with process_count=%s because"
410+
" force override is enabled and JAX distributed client is unavailable.",
411+
process_count,
412+
)
413+
return ThreadSafeKeyValueSignalingClient()
414+
415+
raise RuntimeError(
416+
"ThreadSafeKeyValueSignalingClient should only be used in a single"
417+
f" controller setup, process count: {process_count}."
418+
)

checkpoint/orbax/checkpoint/_src/futures/signaling_client_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,15 @@ class TestGetSignalingClient(absltest.TestCase):
285285

286286
def setUp(self):
287287
super().setUp()
288+
signaling_client.set_force_threadsafe_signaling_client(False)
288289
signaling_client.get_signaling_client.cache_clear()
289290

290291
def tearDown(self):
292+
signaling_client.set_force_threadsafe_signaling_client(False)
293+
while getattr(
294+
signaling_client, "_THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT", 0
295+
):
296+
signaling_client.release_threadsafe_signaling_client_override()
291297
super().tearDown()
292298
signaling_client.get_signaling_client.cache_clear()
293299

@@ -344,6 +350,57 @@ def test_raises_error_when_multiprocess_and_not_initialized(
344350
mock_is_init.assert_called_once()
345351
mock_process_count.assert_called_once()
346352

353+
@mock.patch.object(
354+
multihost, "is_jax_distributed_client_initialized", return_value=False
355+
)
356+
@mock.patch.object(multihost, "process_count", return_value=2)
357+
def test_force_override_allows_threadsafe_client_in_multiprocess(
358+
self, mock_is_init, mock_process_count
359+
):
360+
signaling_client.set_force_threadsafe_signaling_client(True)
361+
362+
client = signaling_client.get_signaling_client()
363+
364+
self.assertIsInstance(
365+
client, signaling_client.ThreadSafeKeyValueSignalingClient
366+
)
367+
mock_is_init.assert_called_once()
368+
mock_process_count.assert_called_once()
369+
370+
@mock.patch.object(
371+
multihost, "is_jax_distributed_client_initialized", return_value=False
372+
)
373+
@mock.patch.object(multihost, "process_count", return_value=2)
374+
def test_acquire_release_override_is_reference_counted(
375+
self, mock_is_init, mock_process_count
376+
):
377+
signaling_client.acquire_threadsafe_signaling_client_override()
378+
signaling_client.acquire_threadsafe_signaling_client_override()
379+
self.addCleanup(
380+
lambda: setattr(
381+
signaling_client, "_THREADSAFE_SIGNALING_OVERRIDE_REFCOUNT", 0
382+
)
383+
)
384+
385+
client = signaling_client.get_signaling_client()
386+
self.assertIsInstance(
387+
client, signaling_client.ThreadSafeKeyValueSignalingClient
388+
)
389+
390+
signaling_client.release_threadsafe_signaling_client_override()
391+
signaling_client.get_signaling_client.cache_clear()
392+
client = signaling_client.get_signaling_client()
393+
self.assertIsInstance(
394+
client, signaling_client.ThreadSafeKeyValueSignalingClient
395+
)
396+
397+
signaling_client.release_threadsafe_signaling_client_override()
398+
signaling_client.get_signaling_client.cache_clear()
399+
with self.assertRaisesRegex(RuntimeError, "process count: 2"):
400+
signaling_client.get_signaling_client()
401+
mock_is_init.assert_called()
402+
mock_process_count.assert_called()
403+
347404

348405
if __name__ == "__main__":
349406
absltest.main()

0 commit comments

Comments
 (0)