Skip to content

Commit c43a272

Browse files
committed
refactor(backend): add type hints for improved clarity and validation
Modified files (6): - backend/api/main.py: Updated root and health check endpoints - backend/api/routers/models.py: Enhanced delete job endpoint - backend/api/routers/predict.py: Improved prediction info endpoint - backend/api/services/batch_service.py: Refined job status method - backend/training/preprocessor.py: Updated feature metadata method - backend/training/train.py: Enhanced job status update method
1 parent 1e55316 commit c43a272

9 files changed

Lines changed: 20 additions & 24 deletions

File tree

backend/api/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fastapi.middleware.cors import CORSMiddleware
33
from fastapi.responses import JSONResponse
44
from mangum import Mangum
5+
from typing import Dict
56
from .routers import upload, training, models, datasets, predict
67
from .utils.helpers import get_settings
78

@@ -34,7 +35,7 @@
3435

3536

3637
@app.get("/")
37-
async def root() -> dict[str, str]:
38+
async def root() -> Dict[str, str]:
3839
"""Health check endpoint"""
3940
return {
4041
"message": "AWS AutoML Lite API",
@@ -44,7 +45,7 @@ async def root() -> dict[str, str]:
4445

4546

4647
@app.get("/health")
47-
async def health_check() -> dict[str, str]:
48+
async def health_check() -> Dict[str, str]:
4849
"""Detailed health check"""
4950
return {
5051
"status": "healthy",

backend/api/routers/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapi import APIRouter, HTTPException, status, Query
2-
from typing import Optional
2+
from typing import Dict, Optional, Any
33
from ..models.schemas import (
44
JobListResponse, JobResponse, JobStatus, ProblemType, JobUpdateRequest,
55
DeployRequest, DeployResponse, PreprocessingInfo
@@ -120,7 +120,7 @@ async def get_job_status(job_id: str) -> JobResponse:
120120

121121

122122
@router.delete("/{job_id}")
123-
async def delete_job(job_id: str, delete_data: bool = True) -> dict:
123+
async def delete_job(job_id: str, delete_data: bool = True) -> Dict[str, Any]:
124124
"""
125125
Delete a training job and optionally all associated data (model, report, dataset)
126126
"""

backend/api/routers/predict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from fastapi import APIRouter, HTTPException, status
1414
import logging
1515
import time
16-
from typing import Dict, Any
16+
from typing import Dict, Any, Optional
1717
import numpy as np
1818
import tempfile
1919
import os
@@ -128,7 +128,7 @@ def _prepare_input(
128128
features: Dict[str, Any],
129129
feature_columns: list,
130130
session: Any,
131-
preprocessing_info: Dict[str, Any] | None = None
131+
preprocessing_info: Optional[Dict[str, Any]]
132132
) -> np.ndarray:
133133
"""
134134
Prepare input features for ONNX model inference.
@@ -306,7 +306,7 @@ async def make_prediction(job_id: str, request: PredictionInput) -> PredictionRe
306306

307307

308308
@router.get("/{job_id}/info")
309-
async def get_prediction_info(job_id: str) -> dict:
309+
async def get_prediction_info(job_id: str) -> Dict[str, Any]:
310310
"""
311311
Get information about a deployed model for making predictions.
312312

backend/api/services/batch_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import boto3
2-
from typing import Dict, Any
2+
from typing import Dict, Any, Optional
33
from botocore.exceptions import ClientError
44
from ..utils.helpers import get_settings
55

@@ -49,7 +49,7 @@ def submit_training_job(
4949
except ClientError as e:
5050
raise Exception(f"Error submitting batch job: {str(e)}")
5151

52-
def get_job_status(self, batch_job_id: str) -> Dict[str, Any] | None:
52+
def get_job_status(self, batch_job_id: str) -> Optional[Dict[str, Any]]:
5353
"""Get the status of a batch job"""
5454
try:
5555
response = self.batch_client.describe_jobs(jobs=[batch_job_id])

backend/tests/api/test_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
without requiring actual AWS services.
66
"""
77
import pytest
8-
from unittest.mock import patch, MagicMock
8+
from unittest.mock import patch
99
import sys
1010
from pathlib import Path
1111

backend/tests/api/test_s3_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test_object_exists(self, s3_bucket):
287287
try:
288288
s3.head_object(Bucket=bucket, Key="not-exists.csv")
289289
not_exists = True
290-
except:
290+
except s3.exceptions.ClientError:
291291
not_exists = False
292292

293293
assert not_exists == False

backend/tests/api/test_schemas.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import pytest
88
from pydantic import ValidationError
9-
from datetime import datetime
109
import sys
1110
from pathlib import Path
1211

@@ -20,14 +19,10 @@
2019
DatasetMetadata,
2120
ColumnStats,
2221
JobStatus,
23-
JobDetails,
24-
JobResponse,
2522
PredictionInput,
2623
PredictionResponse,
27-
UploadRequest,
2824
UploadResponse,
2925
ProblemType,
30-
TrainingMetrics,
3126
)
3227

3328

backend/training/preprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from decimal import Decimal
44
from sklearn.model_selection import train_test_split
55
from sklearn.preprocessing import LabelEncoder, StandardScaler
6-
from typing import Tuple, List
6+
from typing import Dict, Any, Optional, Tuple, List
77

88
# Feature-engine for robust feature selection
99
from feature_engine.selection import DropConstantFeatures, DropDuplicateFeatures
@@ -154,7 +154,7 @@ def encode_categorical(self, df: pd.DataFrame, fit: bool = True) -> pd.DataFrame
154154

155155
return df
156156

157-
def get_feature_metadata(self, df: pd.DataFrame | None = None) -> dict:
157+
def get_feature_metadata(self, df: Optional[pd.DataFrame]) -> Dict[str, Any]:
158158
"""
159159
Get metadata about features for inference.
160160

backend/training/train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
from datetime import datetime, timezone
66
import traceback
7-
from typing import Any, Dict, List
7+
from typing import Any, Dict, List, Optional
88

99
from preprocessor import AutoPreprocessor
1010
from eda import generate_eda_report
@@ -202,7 +202,7 @@ def update_job_status(
202202
table: Any,
203203
job_id: str,
204204
status: str,
205-
error_message: str | None = None
205+
error_message: Optional[str]
206206
) -> None:
207207
"""Update job status in DynamoDB"""
208208
now = datetime.now(timezone.utc).isoformat()
@@ -241,14 +241,14 @@ def update_job_completion(
241241
target_column: str,
242242
problem_type: str,
243243
model_path: str,
244-
onnx_model_path: str | None,
244+
onnx_model_path: Optional[str],
245245
eda_report_s3_path: str,
246246
training_report_s3_path: str,
247247
metrics: Dict[str, Any],
248248
feature_importance: Dict[str, float],
249-
dropped_columns: List[str] | None = None,
250-
feature_columns: List[str] | None = None,
251-
feature_metadata: Dict[str, Any] | None = None
249+
dropped_columns: Optional[List[str]],
250+
feature_columns: Optional[List[str]],
251+
feature_metadata: Optional[Dict[str, Any]]
252252
) -> None:
253253
"""Update job with completion details"""
254254
from decimal import Decimal

0 commit comments

Comments
 (0)