Skip to content

Commit 9d61a78

Browse files
committed
feat(inference): add serverless model deployment and ONNX predictions
- Add POST /jobs/{job_id}/deploy endpoint for one-click model deploy
- Add POST /predict/{job_id} endpoint with ONNX Runtime inference
- Add GET /predict/{job_id}/info endpoint for model metadata
- Add prediction playground UI on results page
- Cache ONNX models in Lambda memory for fast subsequent predictions
- Update ONNX Runtime to 1.20.1 for Docker compatibility
- Add privileged mode to docker-compose for local development

Cost: $0 idle vs ~$50-100/month SageMaker endpoint
1 parent c9c3994 commit 9d61a78

12 files changed

Lines changed: 800 additions & 8 deletions

File tree

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99

1010
### Added
11+
- **Serverless Model Inference** - Deploy and make predictions without SageMaker
12+
- One-click model deploy button on results page
13+
- `POST /jobs/{job_id}/deploy` endpoint to deploy/undeploy models
14+
- `POST /predict/{job_id}` endpoint for making predictions with ONNX Runtime
15+
- `GET /predict/{job_id}/info` endpoint for model metadata
16+
- ONNX model caching in Lambda memory for fast subsequent predictions
17+
- Prediction Playground UI with interactive feature input form
18+
- Real-time prediction results with confidence and probabilities
19+
- Cost comparison panel: Lambda ($0 idle) vs SageMaker (~$50-100/month)
20+
- ONNX Runtime 1.20.1 for serverless inference (compatible with Docker local dev)
21+
1122
- **Dark Mode Support** - Full dark/light/system theme support across all pages
1223
- Integrated `next-themes` for flicker-free theme switching
1324
- `ThemeToggle` component with 3-way cycling (Light → Dark → System)

backend/api/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from fastapi.middleware.cors import CORSMiddleware
33
from fastapi.responses import JSONResponse
44
from mangum import Mangum
5-
from .routers import upload, training, models, datasets
5+
from .routers import upload, training, models, datasets, predict
66
from .utils.helpers import get_settings
77

88
settings = get_settings()
@@ -30,6 +30,7 @@
3030
app.include_router(datasets.router)
3131
app.include_router(training.router)
3232
app.include_router(models.router)
33+
app.include_router(predict.router)
3334

3435

3536
@app.get("/")

backend/api/models/schemas.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ class JobDetails(BaseModel):
119119
model_config = {"protected_namespaces": ()}
120120

121121

122+
class PreprocessingInfo(BaseModel):
    """Preprocessing information for inference.

    Attached to ``JobResponse.preprocessing_info`` so clients (e.g. the
    prediction UI) know which input features the deployed model expects.
    Every field is optional: jobs created before this feature existed
    carry no preprocessing metadata.
    """
    feature_columns: Optional[List[str]] = None  # names of the model's input features
    feature_count: Optional[int] = None          # number of input features
    dropped_columns: Optional[List[str]] = None  # columns removed during preprocessing
    dropped_count: Optional[int] = None          # number of dropped columns
128+
129+
122130
class JobResponse(BaseModel):
123131
job_id: str
124132
dataset_id: str
@@ -139,6 +147,8 @@ class JobResponse(BaseModel):
139147
error_message: Optional[str] = None
140148
tags: Optional[List[str]] = None # Custom labels for filtering
141149
notes: Optional[str] = None # User notes for experiment tracking
150+
deployed: bool = False # Whether the model is deployed for inference
151+
preprocessing_info: Optional[PreprocessingInfo] = None # Feature info for inference
142152

143153
model_config = {"protected_namespaces": ()}
144154

@@ -149,6 +159,35 @@ class JobUpdateRequest(BaseModel):
149159
notes: Optional[str] = Field(default=None, max_length=1000, description="User notes for experiment tracking")
150160

151161

162+
class DeployRequest(BaseModel):
    """Request schema for deploying/undeploying a model.

    Body of ``POST /jobs/{job_id}/deploy``; a single required flag.
    """
    deploy: bool = Field(..., description="True to deploy, False to undeploy")
165+
166+
167+
class DeployResponse(BaseModel):
    """Response schema for deploy/undeploy operations.

    Returned by ``POST /jobs/{job_id}/deploy``.
    """
    job_id: str     # job whose deployment flag was changed
    deployed: bool  # deployment state after the operation
    message: str    # human-readable confirmation (e.g. "Model successfully deployed")
172+
173+
174+
class PredictionInput(BaseModel):
    """Request schema for making predictions.

    Body of ``POST /predict/{job_id}``.
    """
    # Mapping of feature name -> raw value; values may be numeric or
    # string-valued, per the declared union type.
    features: Dict[str, float | int | str] = Field(..., description="Input features for prediction")
177+
178+
179+
class PredictionResponse(BaseModel):
    """Response schema for predictions.

    Returned by ``POST /predict/{job_id}``. ``probability`` and
    ``probabilities`` are optional — presumably only populated for
    classification models; confirm against the predict endpoint.
    """
    job_id: str                       # job whose model produced this prediction
    prediction: float | int | str     # model output (numeric value or class label)
    probability: Optional[float] = None               # confidence score, when available
    probabilities: Optional[Dict[str, float]] = None  # per-class probabilities, when available
    inference_time_ms: float          # inference duration in milliseconds
    model_type: str                   # type of the underlying model

    # Allow "model_"-prefixed field names (Pydantic reserves that namespace by default).
    model_config = {"protected_namespaces": ()}
189+
190+
152191
class JobListResponse(BaseModel):
153192
jobs: List[JobDetails]
154193
next_token: Optional[str] = None

backend/api/routers/models.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from fastapi import APIRouter, HTTPException, status, Query
22
from typing import Optional
3-
from ..models.schemas import JobListResponse, JobResponse, JobStatus, ProblemType, JobUpdateRequest
3+
from ..models.schemas import (
4+
JobListResponse, JobResponse, JobStatus, ProblemType, JobUpdateRequest,
5+
DeployRequest, DeployResponse, PreprocessingInfo
6+
)
47
from ..services.dynamo_service import dynamodb_service
58
from ..services.s3_service import s3_service
69
from ..utils.helpers import get_settings
@@ -22,6 +25,16 @@ async def get_job_status(job_id: str):
2225
detail="Job not found"
2326
)
2427

28+
# Build preprocessing_info if available
29+
preprocessing_info = None
30+
if job.get('preprocessing_info'):
31+
preprocessing_info = PreprocessingInfo(
32+
feature_columns=job['preprocessing_info'].get('feature_columns'),
33+
feature_count=job['preprocessing_info'].get('feature_count'),
34+
dropped_columns=job['preprocessing_info'].get('dropped_columns'),
35+
dropped_count=job['preprocessing_info'].get('dropped_count')
36+
)
37+
2538
response = JobResponse(
2639
job_id=job['job_id'],
2740
dataset_id=job['dataset_id'],
@@ -36,7 +49,9 @@ async def get_job_status(job_id: str):
3649
metrics=job.get('metrics'),
3750
error_message=job.get('error_message'),
3851
tags=job.get('tags'),
39-
notes=job.get('notes')
52+
notes=job.get('notes'),
53+
deployed=job.get('deployed', False),
54+
preprocessing_info=preprocessing_info
4055
)
4156

4257
# Generate download URLs if job is completed
@@ -234,6 +249,54 @@ async def update_job_metadata(job_id: str, request: JobUpdateRequest):
234249
)
235250

236251

252+
@router.post("/{job_id}/deploy", response_model=DeployResponse)
async def deploy_model(job_id: str, request: DeployRequest):
    """
    Deploy or undeploy a trained model for inference.

    Only completed jobs with ONNX models can be deployed. Validation
    failures surface as 400/404; anything unexpected becomes a 500.
    """
    try:
        record = dynamodb_service.get_job(job_id)

        # The job must exist at all.
        if not record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Job not found"
            )

        # Only finished training runs are eligible (applies to deploy and undeploy).
        if record['status'] != JobStatus.COMPLETED.value:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Cannot deploy job with status '{record['status']}'. Only completed jobs can be deployed."
            )

        # Deploying additionally requires an exported ONNX artifact.
        if request.deploy and not record.get('onnx_model_path'):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No ONNX model available for this job. Only jobs with ONNX export can be deployed."
            )

        # Persist the new flag, then report which action was taken.
        dynamodb_service.update_job_deployed(job_id, request.deploy)
        verb = "deployed" if request.deploy else "undeployed"
        return DeployResponse(
            job_id=job_id,
            deployed=request.deploy,
            message=f"Model successfully {verb}"
        )

    except HTTPException:
        # Re-raise our own validation errors untouched.
        raise
    except Exception as e:
        # Wrap anything unexpected in a 500 instead of leaking a traceback.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error deploying model: {str(e)}"
        )
298+
299+
237300
@router.get("", response_model=JobListResponse)
238301
async def list_jobs(
239302
limit: int = Query(default=20, ge=1, le=100),

0 commit comments

Comments
 (0)