-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathcog_rfdetr_repro_predictor.py
More file actions
63 lines (57 loc) · 2.38 KB
/
cog_rfdetr_repro_predictor.py
File metadata and controls
63 lines (57 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tempfile
from pathlib import Path
from cog import BasePredictor, Input, Path as CogPath
from core.rf_detr_runtime import (
DEFAULT_RF_DETR_MODEL_ID,
ensure_python_nvidia_libs_preferred,
sync_python_nvidia_runtime_libs_to_system,
supported_rf_detr_model_ids,
)
from core.rf_detr_repro import bundle_repro_artifacts, run_rf_detr_repro
# Executed once at module import time: after the imports above, before the
# Predictor class below is defined, and therefore before Cog can call
# Predictor.setup(). Presumably this makes NVIDIA runtime libraries installed
# as Python packages take precedence for later native library loads.
# NOTE(review): confirm the exact effect in core.rf_detr_runtime.
ensure_python_nvidia_libs_preferred()
class Predictor(BasePredictor):
    """Cog predictor that runs an RF-DETR repro pass and returns its artifacts bundled into one file."""

    def setup(self) -> None:
        """One-time setup hook: sync the Python-packaged NVIDIA runtime libs to the system."""
        sync_python_nvidia_runtime_libs_to_system()

    def predict(
        self,
        media: CogPath = Input(description="Input still image or short video for RF-DETR repro."),
        modelId: str = Input(
            description="RF-DETR segmentation model to load.",
            choices=supported_rf_detr_model_ids(),
            default=DEFAULT_RF_DETR_MODEL_ID,
        ),
        device: str = Input(
            description="Device to request for RF-DETR inference.",
            choices=["auto", "cpu", "cuda", "mps"],
            default="auto",
        ),
        threshold: float = Input(description="Detection threshold.", ge=0.05, le=0.95, default=0.4),
        maxFrames: int = Input(description="Maximum number of video frames to process.", ge=1, le=24, default=8),
        cropMode: str = Input(
            description="Crop mode before inference.",
            choices=["full", "left_half", "right_half", "center_square"],
            default="full",
        ),
        writeOverlayVideo: bool = Input(description="Write a tiny overlay MP4 when the input is a video.", default=True),
    ) -> CogPath:
        """Run the RF-DETR repro on ``media`` and return the bundled artifacts.

        Parameter names are camelCase because they are the public Cog input
        names; they are forwarded unchanged to ``run_rf_detr_repro``.

        Args:
            media: Input still image or short video.
            modelId: RF-DETR model id; one of ``supported_rf_detr_model_ids()``.
            device: Requested inference device ("auto", "cpu", "cuda" or "mps").
            threshold: Detection threshold in [0.05, 0.95].
            maxFrames: Maximum number of video frames to process (1-24).
            cropMode: Crop applied before inference.
            writeOverlayVideo: Whether to write an overlay MP4 for video input.

        Returns:
            The path returned by ``bundle_repro_artifacts``. Its destination
            is a ``.zip``-suffixed temp file outside the scratch directory.

        Raises:
            Whatever ``run_rf_detr_repro`` or ``bundle_repro_artifacts`` raise.
            If bundling fails, the partially written bundle file is removed
            before the exception propagates.
        """
        media_path = Path(media)
        with tempfile.TemporaryDirectory(prefix="rf-detr-repro-") as tmpdir:
            # Raw artifacts live in a scratch directory that is deleted when
            # this ``with`` block exits (including on the ``return`` below).
            output_dir = Path(tmpdir) / "artifacts"
            run_rf_detr_repro(
                input_path=media_path,
                output_dir=output_dir,
                model_id=modelId,
                requested_device=device,
                threshold=threshold,
                max_frames=maxFrames,
                crop_mode=cropMode,
                write_overlay_video=writeOverlayVideo,
            )
            # The bundle must survive the scratch directory's deletion, so it is
            # created outside it with delete=False. The handle is closed
            # immediately so the bundler can open the path itself.
            bundle_file = tempfile.NamedTemporaryFile(
                prefix="rf-detr-repro-artifacts-",
                suffix=".zip",
                delete=False,
            )
            bundle_file.close()
            bundle_path = Path(bundle_file.name)
            try:
                bundled = bundle_repro_artifacts(output_dir, bundle_path)
            except BaseException:
                # delete=False means nothing else will remove this file; do not
                # leak a partial bundle on every failed prediction.
                bundle_path.unlink(missing_ok=True)
                raise
            return CogPath(bundled)