1616from sqlalchemy import create_engine , text
1717from sqlalchemy .orm import sessionmaker , Session
1818import uuid
19+ import ipaddress
20+ import socket
1921
2022# Configuration
2123DATABASE_URL = os .getenv ("DATABASE_URL" , "postgresql://postgres:postgres@db:5432/m3u8_db" )
2224REDIS_URL = os .getenv ("REDIS_URL" , "redis://redis:6379/0" )
2325API_KEY = os .getenv ("API_KEY" )
26+ RATE_LIMIT_PER_MINUTE = int (os .getenv ("RATE_LIMIT_PER_MINUTE" , "0" ) or "0" )
27+ ALLOWED_CLIENT_CIDRS_RAW = os .getenv ("ALLOWED_CLIENT_CIDRS" , "" ).strip ()
28+ SSRF_GUARD_ENABLED = os .getenv ("SSRF_GUARD" , "false" ).strip ().lower () in ("1" , "true" , "yes" , "y" , "on" )
2429
2530# Backward-compatible default: allow all origins unless explicitly restricted.
2631_allowed_origins_raw = os .getenv ("ALLOWED_ORIGINS" , "*" ).strip ()
5863# Redis setup
5964redis_client = redis .from_url (REDIS_URL , decode_responses = True )
6065
66+ # Security helpers
67+ def _get_client_ip (request : Request ) -> str :
68+ """
69+ Best-effort client IP for rate limiting / allowlisting.
70+ If you're behind a reverse proxy, ensure it is trusted before relying on X-Forwarded-For.
71+ """
72+ xff = (request .headers .get ("x-forwarded-for" ) or "" ).strip ()
73+ if xff :
74+ # Use the left-most (original) IP
75+ return xff .split ("," )[0 ].strip ()
76+ if request .client and request .client .host :
77+ return request .client .host
78+ return "unknown"
79+
80+
81+ def _parse_allowed_client_networks () -> list [ipaddress ._BaseNetwork ]:
82+ if not ALLOWED_CLIENT_CIDRS_RAW :
83+ return []
84+ networks : list [ipaddress ._BaseNetwork ] = []
85+ for token in [t .strip () for t in ALLOWED_CLIENT_CIDRS_RAW .split ("," ) if t .strip ()]:
86+ try :
87+ networks .append (ipaddress .ip_network (token , strict = False ))
88+ except ValueError :
89+ raise HTTPException (status_code = 503 , detail = f"Server misconfigured: invalid ALLOWED_CLIENT_CIDRS entry: { token } " )
90+ return networks
91+
92+
93+ _ALLOWED_CLIENT_NETWORKS = _parse_allowed_client_networks ()
94+
95+
96+ def _enforce_client_allowlist (request : Request ) -> None :
97+ if not _ALLOWED_CLIENT_NETWORKS :
98+ return
99+ client_ip_str = _get_client_ip (request )
100+ try :
101+ client_ip = ipaddress .ip_address (client_ip_str )
102+ except ValueError :
103+ raise HTTPException (status_code = 403 , detail = "Client IP not allowed" )
104+ for net in _ALLOWED_CLIENT_NETWORKS :
105+ if client_ip in net :
106+ return
107+ raise HTTPException (status_code = 403 , detail = "Client IP not allowed" )
108+
109+
110+ def _rate_limit (request : Request , bucket : str ) -> None :
111+ if RATE_LIMIT_PER_MINUTE <= 0 :
112+ return
113+ client_ip = _get_client_ip (request )
114+ window = int (datetime .utcnow ().timestamp () // 60 )
115+ key = f"rl:{ bucket } :{ client_ip } :{ window } "
116+ try :
117+ count = redis_client .incr (key )
118+ redis_client .expire (key , 90 )
119+ except Exception :
120+ # If Redis is unavailable, skip rate limiting (avoid breaking core API).
121+ return
122+ if count > RATE_LIMIT_PER_MINUTE :
123+ raise HTTPException (status_code = 429 , detail = "Rate limit exceeded" )
124+
125+
126+ def _resolve_host_ips (hostname : str ) -> list [ipaddress ._BaseAddress ]:
127+ # Resolve A/AAAA records; if it fails, we treat as invalid for SSRF protection.
128+ infos = socket .getaddrinfo (hostname , None , proto = socket .IPPROTO_TCP )
129+ ips : list [ipaddress ._BaseAddress ] = []
130+ for info in infos :
131+ sockaddr = info [4 ]
132+ ip_str = sockaddr [0 ]
133+ ips .append (ipaddress .ip_address (ip_str ))
134+ return ips
135+
136+
137+ def _is_ip_public (ip : ipaddress ._BaseAddress ) -> bool :
138+ # Block common SSRF targets: loopback, link-local, RFC1918/ULA, multicast, reserved, etc.
139+ if ip .is_loopback or ip .is_private or ip .is_link_local or ip .is_multicast or ip .is_reserved or ip .is_unspecified :
140+ return False
141+ return True
142+
143+
144+ def _enforce_ssrf_guard (url : HttpUrl ) -> None :
145+ if not SSRF_GUARD_ENABLED :
146+ return
147+ hostname = url .host
148+ if not hostname :
149+ raise HTTPException (status_code = 400 , detail = "Invalid URL host" )
150+ if hostname .lower () in ("localhost" ,):
151+ raise HTTPException (status_code = 400 , detail = "URL host not allowed" )
152+ try :
153+ ips = _resolve_host_ips (hostname )
154+ except Exception :
155+ raise HTTPException (status_code = 400 , detail = "URL host could not be resolved" )
156+ if not ips :
157+ raise HTTPException (status_code = 400 , detail = "URL host could not be resolved" )
158+ for ip in ips :
159+ if not _is_ip_public (ip ):
160+ raise HTTPException (status_code = 400 , detail = "URL host not allowed" )
161+
61162# Pydantic models
62163class DownloadRequest (BaseModel ):
63164 url : HttpUrl
@@ -73,6 +174,8 @@ def validate_video_url(cls, v):
73174 is_valid = '.m3u8' in url_str or '.mp4' in url_str
74175 if not is_valid :
75176 raise ValueError ('URL must contain .m3u8 or .mp4' )
177+ # Optional SSRF protection for public deployments
178+ _enforce_ssrf_guard (v )
76179 return v
77180
78181class JobResponse (BaseModel ):
@@ -102,8 +205,10 @@ def get_db():
102205 finally :
103206 db .close ()
104207
105- def verify_api_key (authorization : Optional [str ] = Header (None )):
208+ def verify_api_key (request : Request , authorization : Optional [str ] = Header (None )):
106209 """Verify API key from Authorization header"""
210+ _enforce_client_allowlist (request )
211+ _rate_limit (request , bucket = "auth" )
107212 if not API_KEY or API_KEY .strip () == "" or API_KEY .strip () == "change-this-key" :
108213 raise HTTPException (status_code = 503 , detail = "Server not configured: API_KEY is not set" )
109214 if not authorization :
@@ -128,9 +233,15 @@ async def root():
128233 }
129234
130235@app .get ("/api/health" )
131- async def health_check ():
236+ async def health_check (request : Request , authorization : Optional [ str ] = Header ( None ) ):
132237 """Health check endpoint"""
133238 try :
239+ # Avoid exposing internal status to the public internet.
240+ # Allow localhost checks (Docker healthcheck) without auth; require API key otherwise.
241+ client_ip = _get_client_ip (request )
242+ if client_ip not in ("127.0.0.1" , "::1" ):
243+ verify_api_key (request = request , authorization = authorization )
244+
134245 # Check database
135246 db = SessionLocal ()
136247 db .execute (text ("SELECT 1" ))
0 commit comments