Skip to content

Commit bf85d02

Browse files
Revert "Aggregate remote cold start data in workflow headers (#2209)" (#2222)
This reverts commit 8899554.
1 parent (8899554) — commit bf85d02

11 files changed

Lines changed: 40 additions & 678 deletions

File tree

inference/core/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
KEYPOINTS_DETECTION_TASK = "keypoint-detection"
55
PROCESSING_TIME_HEADER = "X-Processing-Time"
66
MODEL_COLD_START_HEADER = "X-Model-Cold-Start"
7-
MODEL_COLD_START_COUNT_HEADER = "X-Model-Cold-Start-Count"
87
MODEL_LOAD_TIME_HEADER = "X-Model-Load-Time"
98
MODEL_LOAD_DETAILS_HEADER = "X-Model-Load-Details"
109
MODEL_ID_HEADER = "X-Model-Id"

inference/core/interfaces/http/http_api.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
from inference.core import logger
3434
from inference.core.constants import (
35-
MODEL_COLD_START_COUNT_HEADER,
3635
MODEL_COLD_START_HEADER,
3736
MODEL_ID_HEADER,
3837
MODEL_LOAD_DETAILS_HEADER,
@@ -230,12 +229,6 @@
230229
orjson_response,
231230
orjson_response_keeping_parent_id,
232231
)
233-
from inference.core.interfaces.http.request_metrics import (
234-
REMOTE_PROCESSING_TIME_HEADER,
235-
REMOTE_PROCESSING_TIMES_HEADER,
236-
GCPServerlessMiddleware,
237-
build_model_response_headers,
238-
)
239232
from inference.core.interfaces.stream_manager.api.entities import (
240233
CommandContext,
241234
CommandResponse,
@@ -323,9 +316,23 @@
323316
from inference.core.version import __version__
324317

325318
try:
326-
from inference_sdk.config import EXECUTION_ID_HEADER
319+
from inference_sdk.config import (
320+
EXECUTION_ID_HEADER,
321+
INTERNAL_REMOTE_EXEC_REQ_HEADER,
322+
INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER,
323+
RemoteProcessingTimeCollector,
324+
apply_duration_minimum,
325+
execution_id,
326+
remote_processing_times,
327+
)
327328
except ImportError:
329+
execution_id = None
330+
remote_processing_times = None
331+
RemoteProcessingTimeCollector = None
328332
EXECUTION_ID_HEADER = None
333+
INTERNAL_REMOTE_EXEC_REQ_HEADER = None
334+
INTERNAL_REMOTE_EXEC_REQ_VERIFIED_HEADER = None
335+
apply_duration_minimum = None
329336

330337

331338
def get_content_type(request: Request) -> str:
@@ -503,7 +510,6 @@ async def on_shutdown():
503510
REMOTE_PROCESSING_TIME_HEADER,
504511
REMOTE_PROCESSING_TIMES_HEADER,
505512
MODEL_COLD_START_HEADER,
506-
MODEL_COLD_START_COUNT_HEADER,
507513
MODEL_LOAD_TIME_HEADER,
508514
MODEL_LOAD_DETAILS_HEADER,
509515
MODEL_ID_HEADER,
@@ -814,35 +820,17 @@ async def track_model_load(request: Request, call_next):
814820
ids_collector = RequestModelIds()
815821
request_model_ids.set(ids_collector)
816822
response = await call_next(request)
817-
remote_processing_collector = getattr(
818-
request.state, "remote_processing_time_collector", None
819-
)
820-
if remote_processing_collector is not None:
821-
remote_model_ids = remote_processing_collector.snapshot_model_ids()
822-
remote_cold_start_entries = (
823-
remote_processing_collector.snapshot_cold_start_entries()
824-
)
825-
remote_cold_start_count = (
826-
remote_processing_collector.snapshot_cold_start_count()
827-
)
828-
remote_cold_start_total_load_time = (
829-
remote_processing_collector.snapshot_cold_start_total_load_time()
830-
)
823+
if load_collector.has_data():
824+
total, detail = load_collector.summarize()
825+
response.headers[MODEL_COLD_START_HEADER] = "true"
826+
response.headers[MODEL_LOAD_TIME_HEADER] = str(total)
827+
if detail is not None:
828+
response.headers[MODEL_LOAD_DETAILS_HEADER] = detail
831829
else:
832-
remote_model_ids = set()
833-
remote_cold_start_entries = []
834-
remote_cold_start_count = 0
835-
remote_cold_start_total_load_time = 0.0
836-
response.headers.update(
837-
build_model_response_headers(
838-
local_model_ids=ids_collector.get_ids(),
839-
local_cold_start_entries=load_collector.snapshot_entries(),
840-
remote_model_ids=remote_model_ids,
841-
remote_cold_start_entries=remote_cold_start_entries,
842-
remote_cold_start_count=remote_cold_start_count,
843-
remote_cold_start_total_load_time=remote_cold_start_total_load_time,
844-
)
845-
)
830+
response.headers[MODEL_COLD_START_HEADER] = "false"
831+
model_ids = ids_collector.get_ids()
832+
if model_ids:
833+
response.headers[MODEL_ID_HEADER] = ",".join(sorted(model_ids))
846834
wf_id = request_workflow_id.get(None)
847835
if wf_id:
848836
response.headers[WORKFLOW_ID_HEADER] = wf_id
@@ -868,7 +856,6 @@ async def structured_access_log(request: Request, call_next):
868856
"request_id": CORRELATION_ID_HEADER,
869857
"processing_time": PROCESSING_TIME_HEADER,
870858
"model_cold_start": MODEL_COLD_START_HEADER,
871-
"model_cold_start_count": MODEL_COLD_START_COUNT_HEADER,
872859
"model_load_time": MODEL_LOAD_TIME_HEADER,
873860
"model_id": MODEL_ID_HEADER,
874861
"workflow_id": WORKFLOW_ID_HEADER,

inference/core/interfaces/http/request_metrics.py

Lines changed: 0 additions & 134 deletions
This file was deleted.

inference/core/managers/model_load_collector.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,15 @@ def has_data(self) -> bool:
2525
with self._lock:
2626
return len(self._entries) > 0
2727

28-
def snapshot_entries(self) -> list:
29-
with self._lock:
30-
return list(self._entries)
31-
3228
def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]:
3329
"""Return (total_load_time, entries_json_or_none).
3430
3531
Returns the total model load time and a JSON string of individual
3632
entries. If the JSON exceeds *max_detail_bytes*, the detail string
3733
is omitted (None).
3834
"""
39-
entries = self.snapshot_entries()
35+
with self._lock:
36+
entries = list(self._entries)
4037
total = sum(t for _, t in entries)
4138
detail = json.dumps([{"m": m, "t": t} for m, t in entries])
4239
if len(detail) > max_detail_bytes:

inference/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.2"
1+
__version__ = "1.2.1"
22

33

44
if __name__ == "__main__":

inference_sdk/config.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
import threading
5-
from typing import Iterable, Optional, Tuple
5+
from typing import Optional, Tuple
66

77
from inference_sdk.utils.environment import str2bool
88

@@ -23,90 +23,23 @@ class RemoteProcessingTimeCollector:
2323

2424
def __init__(self):
2525
self._entries: list = [] # list of (model_id, time) tuples
26-
self._model_ids: set = set()
27-
self._cold_start_entries: list = [] # list of (model_id, load_time) tuples
28-
self._cold_start_total_load_time: float = 0.0
29-
self._cold_start_count: int = 0
3026
self._lock = threading.Lock()
3127

3228
def add(self, processing_time: float, model_id: str = "unknown") -> None:
3329
with self._lock:
3430
self._entries.append((model_id, processing_time))
3531

36-
def add_model_id(self, model_id: Optional[str]) -> None:
37-
if model_id in (None, "", "unknown"):
38-
return
39-
with self._lock:
40-
self._model_ids.add(model_id)
41-
42-
def add_model_ids(self, model_ids: Iterable[str]) -> None:
43-
filtered_ids = {
44-
model_id for model_id in model_ids if model_id not in (None, "", "unknown")
45-
}
46-
if not filtered_ids:
47-
return
48-
with self._lock:
49-
self._model_ids.update(filtered_ids)
50-
51-
def record_cold_start(
52-
self,
53-
load_time: float,
54-
model_id: Optional[str] = None,
55-
count: int = 1,
56-
) -> None:
57-
with self._lock:
58-
self._cold_start_total_load_time += load_time
59-
self._cold_start_count += count
60-
if model_id not in (None, "", "unknown"):
61-
self._cold_start_entries.append((model_id, load_time))
62-
self._model_ids.add(model_id)
63-
6432
def drain(self) -> list:
6533
"""Atomically return all entries and clear the internal list."""
6634
with self._lock:
6735
entries = self._entries
6836
self._entries = []
6937
return entries
7038

71-
def snapshot_entries(self) -> list:
72-
with self._lock:
73-
return list(self._entries)
74-
75-
def snapshot_model_ids(self) -> set:
76-
with self._lock:
77-
return set(self._model_ids)
78-
79-
def snapshot_cold_start_entries(self) -> list:
80-
with self._lock:
81-
return list(self._cold_start_entries)
82-
83-
def snapshot_cold_start_total_load_time(self) -> float:
84-
with self._lock:
85-
return self._cold_start_total_load_time
86-
87-
def snapshot_cold_start_count(self) -> int:
88-
with self._lock:
89-
return self._cold_start_count
90-
9139
def has_data(self) -> bool:
9240
with self._lock:
9341
return len(self._entries) > 0
9442

95-
def has_cold_start_data(self) -> bool:
96-
with self._lock:
97-
return self._cold_start_count > 0
98-
99-
def snapshot_summary(
100-
self, max_detail_bytes: int = 4096
101-
) -> Tuple[float, Optional[str]]:
102-
"""Return (total_time, entries_json_or_none) without clearing entries."""
103-
entries = self.snapshot_entries()
104-
total = sum(t for _, t in entries)
105-
detail = json.dumps([{"m": m, "t": t} for m, t in entries])
106-
if len(detail) > max_detail_bytes:
107-
detail = None
108-
return total, detail
109-
11043
def summarize(self, max_detail_bytes: int = 4096) -> Tuple[float, Optional[str]]:
11144
"""Atomically drain entries and return (total_time, entries_json_or_none).
11245

0 commit comments

Comments (0)