Skip to content

Commit fd6dd46

Browse files
authored
af: strip query string from airflow_url so version detection works (#210)
1 parent 92e1120 commit fd6dd46

8 files changed

Lines changed: 167 additions & 41 deletions

File tree

astro-airflow-mcp/src/astro_airflow_mcp/adapters/__init__.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astro_airflow_mcp.adapters.airflow_v3 import AirflowV3Adapter
99
from astro_airflow_mcp.adapters.base import AirflowAdapter, NotFoundError
1010
from astro_airflow_mcp.astro_pat import AstroPATError
11+
from astro_airflow_mcp.utils import normalize_airflow_url
1112

1213

1314
def detect_version(
@@ -33,6 +34,8 @@ def detect_version(
3334
Raises:
3435
RuntimeError: If version detection fails
3536
"""
37+
airflow_url = normalize_airflow_url(airflow_url)
38+
3639
headers: dict[str, str] = {}
3740
auth: tuple[str, str] | httpx.Auth | None = None
3841

@@ -46,48 +49,52 @@ def detect_version(
4649
if basic_auth_getter:
4750
auth = basic_auth_getter()
4851

49-
# Try Airflow 3 API first (/api/v2/version)
50-
try:
51-
with httpx.Client(timeout=10.0, verify=verify) as client:
52-
response = client.get(
53-
f"{airflow_url}/api/v2/version",
54-
headers=headers,
55-
auth=auth,
56-
)
57-
if response.status_code == 200:
58-
data = response.json()
59-
version = data.get("version", "3.0.0")
60-
major = int(version.split(".")[0])
61-
return (major, version)
62-
except AstroPATError:
63-
# PAT misconfiguration (no astro login, refresh failed, etc) —
64-
# surface to the caller rather than masking as a version detection
65-
# failure.
66-
raise
67-
except Exception: # nosec B110 - try v1 API next
68-
pass
69-
70-
# Try Airflow 2 API (/api/v1/version)
71-
try:
72-
with httpx.Client(timeout=10.0, verify=verify) as client:
73-
response = client.get(
74-
f"{airflow_url}/api/v1/version",
75-
headers=headers,
76-
auth=auth,
77-
)
78-
if response.status_code == 200:
52+
probe_failures: list[str] = []
53+
54+
def _probe(api_path: str, default_version: str) -> tuple[int, str] | None:
    """Probe ``{airflow_url}{api_path}/version`` and parse the reported version.

    Args:
        api_path: API prefix to probe, e.g. ``"/api/v2"``.
        default_version: Version string assumed when the JSON body has no
            ``"version"`` key.

    Returns:
        ``(major, version)`` when the endpoint answers 200 with a usable
        payload, otherwise ``None``. Every ``None`` outcome appends a short
        reason to ``probe_failures`` so the caller can report it.

    Raises:
        AstroPATError: Re-raised unchanged so PAT misconfiguration is not
            masked as a version detection failure.
    """
    try:
        with httpx.Client(timeout=10.0, verify=verify) as client:
            response = client.get(
                f"{airflow_url}{api_path}/version",
                headers=headers,
                auth=auth,
            )
    except AstroPATError:
        # PAT misconfiguration (no astro login, refresh failed, etc) —
        # surface to the caller rather than masking as a version
        # detection failure.
        raise
    except Exception as e:
        probe_failures.append(f"{api_path}: {type(e).__name__}: {e}")
        return None

    if response.status_code != 200:
        probe_failures.append(f"{api_path}: HTTP {response.status_code}")
        return None

    try:
        data = response.json()
    except ValueError as e:
        probe_failures.append(
            f"{api_path}: 200 but non-JSON body ({type(e).__name__}); "
            f"got Content-Type={response.headers.get('content-type', '?')}"
        )
        return None

    # A 200 with valid JSON can still be unusable (body is not an object,
    # or "version" is not a dotted-integer string). Record it as a probe
    # failure so the next probe still runs, instead of letting a raw
    # AttributeError/TypeError/ValueError escape to the caller.
    try:
        version = data.get("version", default_version)
        major = int(version.split(".")[0])
    except (AttributeError, TypeError, ValueError) as e:
        probe_failures.append(
            f"{api_path}: 200 but unusable version payload ({type(e).__name__}: {e})"
        )
        return None
    return (major, version)
86+
87+
result = _probe("/api/v2", "3.0.0")
88+
if result is not None:
89+
return result
90+
result = _probe("/api/v1", "2.0.0")
91+
if result is not None:
92+
return result
93+
94+
detail = "; ".join(probe_failures) if probe_failures else "no response"
8895
raise RuntimeError(
8996
f"Failed to detect Airflow version at {airflow_url}. "
90-
"Ensure Airflow is running and accessible."
97+
f"Probes: {detail}. Ensure Airflow is running and accessible."
9198
)
9299

93100

astro-airflow-mcp/src/astro_airflow_mcp/adapters/airflow_v3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import httpx
77

88
from astro_airflow_mcp.adapters.base import AirflowAdapter, NotFoundError
9+
from astro_airflow_mcp.utils import normalize_airflow_url
910

1011

1112
class AirflowV3Adapter(AirflowAdapter):
@@ -71,7 +72,7 @@ def _exchange_for_token(
7172
try:
7273
with httpx.Client(timeout=10.0, verify=verify) as client:
7374
response = client.post(
74-
f"{airflow_url}/auth/token",
75+
f"{normalize_airflow_url(airflow_url)}/auth/token",
7576
data={"username": username, "password": password},
7677
headers={"Content-Type": "application/x-www-form-urlencoded"},
7778
)

astro-airflow-mcp/src/astro_airflow_mcp/adapters/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import httpx
99

1010
from astro_airflow_mcp.constants import READ_ONLY_ENV_VAR
11+
from astro_airflow_mcp.utils import normalize_airflow_url
1112

1213

1314
class ReadOnlyError(Exception):
@@ -65,7 +66,7 @@ def __init__(
6566
verify: SSL verification setting. True (default) enables verification,
6667
False disables it, or a string path to a CA bundle file.
6768
"""
68-
self.airflow_url = airflow_url
69+
self.airflow_url = normalize_airflow_url(airflow_url)
6970
self.version = version
7071
self._token_getter = token_getter
7172
self._basic_auth_getter = basic_auth_getter

astro-airflow-mcp/src/astro_airflow_mcp/discovery/astro_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import yaml
1010

11+
from astro_airflow_mcp.utils import normalize_airflow_url
12+
1113

1214
class AstroCliError(Exception):
1315
"""Base exception for Astro CLI errors."""
@@ -46,6 +48,10 @@ def from_inspect_yaml(cls, data: dict) -> AstroDeployment:
4648
webserver_url = metadata.get("webserver_url", "")
4749
if webserver_url and not webserver_url.startswith("http"):
4850
webserver_url = f"https://{webserver_url}"
51+
# Strip any query string / fragment that the control plane may
52+
# include (eg ``?orgId=…``). API URLs are built by string
53+
# concatenation downstream, so a stray ``?`` corrupts the path.
54+
webserver_url = normalize_airflow_url(webserver_url)
4955

5056
return cls(
5157
id=metadata.get("deployment_id", ""),

astro-airflow-mcp/src/astro_airflow_mcp/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,27 @@
11
"""Shared utility functions for CLI and MCP server."""
22

33
from typing import Any
4+
from urllib.parse import urlsplit, urlunsplit
45

56
from astro_airflow_mcp.constants import FAILED_TASK_STATES
67

78

9+
def normalize_airflow_url(url: str) -> str:
    """Strip whitespace, query string, fragment, and trailing slash from an Airflow base URL.

    API URLs are built by concatenation (e.g. ``f"{airflow_url}/api/v2/version"``),
    so a stored URL like ``https://host/dep?orgId=foo`` produces the malformed
    ``https://host/dep?orgId=foo/api/v2/version`` — the path stays at ``/dep``
    and ``/api/...`` ends up inside the query string. Normalizing once at the
    boundary keeps every downstream call safe.

    Args:
        url: Base URL as configured or discovered; may be empty.

    Returns:
        The URL reduced to scheme, netloc, and path with trailing slashes
        removed. A falsy input is returned unchanged.
    """
    if not url:
        return url
    # urlsplit only drops *leading* whitespace; a trailing space or newline
    # would otherwise stay in the path and defeat the trailing-slash strip.
    url = url.strip()
    if not url:
        return ""
    parts = urlsplit(url)
    path = parts.path.rstrip("/")
    return urlunsplit((parts.scheme, parts.netloc, path, "", ""))
23+
24+
825
def filter_connection_passwords(connections: list[dict[str, Any]]) -> list[dict[str, Any]]:
926
"""Filter sensitive fields from connections.
1027

astro-airflow-mcp/tests/test_adapters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ def test_api_base_path(self):
3434
)
3535
assert adapter.api_base_path == "/api/v1"
3636

37+
def test_constructor_normalizes_airflow_url(self):
    """The adapter must store a base URL without the query string.

    Configs saved with ``?orgId=…`` should keep working without re-discovery.
    """
    stored_url = "https://host.example.com/dep?orgId=org_abc"
    adapter = AirflowV2Adapter(stored_url, "2.9.0")
    assert adapter.airflow_url == "https://host.example.com/dep"
45+
3746
def test_setup_auth_with_token_getter(self):
3847
"""Test auth setup with token getter."""
3948
adapter = AirflowV2Adapter(
@@ -537,6 +546,47 @@ def test_detect_version_with_token_getter(self, mocker):
537546
call_kwargs = mock_client.get.call_args[1]
538547
assert call_kwargs["headers"]["Authorization"] == "Bearer test_token"
539548

549+
def test_detect_version_strips_query_string_from_base_url(self, mocker):
    """A base URL carrying ``?orgId=…`` must still yield a clean probe URL.

    Reproduces the case of a saved Astro deployment URL that had
    ``?orgId=org_…`` appended.
    """
    ok_response = mocker.Mock(status_code=200)
    ok_response.json.return_value = {"version": "3.2.1"}

    fake_client = mocker.Mock()
    fake_client.get.return_value = ok_response
    # httpx.Client is used as a context manager, so emulate that protocol.
    fake_client.__enter__ = mocker.Mock(return_value=fake_client)
    fake_client.__exit__ = mocker.Mock(return_value=False)
    mocker.patch("httpx.Client", return_value=fake_client)

    major, full = detect_version("https://host.example.com/dep?orgId=org_abc")

    assert (major, full) == (3, "3.2.1")
    positional_args = fake_client.get.call_args[0]
    assert positional_args[0] == "https://host.example.com/dep/api/v2/version"
569+
570+
def test_detect_version_failure_includes_probe_detail(self, mocker):
    """The RuntimeError message must name each probe's actual failure.

    Users should see the status code (or exception) per probe rather than a
    black-box 'Failed to detect' error.
    """
    fake_client = mocker.Mock()
    # First call answers the v2 probe, second answers the v1 probe.
    fake_client.get.side_effect = [
        mocker.Mock(status_code=404),
        mocker.Mock(status_code=502),
    ]
    fake_client.__enter__ = mocker.Mock(return_value=fake_client)
    fake_client.__exit__ = mocker.Mock(return_value=False)
    mocker.patch("httpx.Client", return_value=fake_client)

    with pytest.raises(RuntimeError) as raised:
        detect_version("http://localhost:8080")

    message = str(raised.value)
    for expected in ("/api/v2: HTTP 404", "/api/v1: HTTP 502"):
        assert expected in message
589+
540590

541591
class TestAdapterFactory:
542592
"""Tests for adapter factory."""

astro-airflow-mcp/tests/test_astro_cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,22 @@ def test_from_inspect_yaml_minimal(self):
7676
assert deployment.status == "UNKNOWN"
7777
assert deployment.airflow_api_url == ""
7878

79+
def test_from_inspect_yaml_strips_query_string(self):
    """A ``webserver_url`` with ``?orgId=…`` must be stored without it.

    Otherwise every API call would concatenate ``/api/v1/...`` into the
    query string and break, so query/fragment are stripped at the boundary.
    """
    metadata = {
        "deployment_id": "dep-123",
        "webserver_url": "https://xyz.astronomer.run/abc?orgId=org_abc",
    }
    inspect_payload = {
        "deployment": {
            "configuration": {"name": "t"},
            "metadata": metadata,
        }
    }

    deployment = AstroDeployment.from_inspect_yaml(inspect_payload)

    assert deployment.airflow_api_url == "https://xyz.astronomer.run/abc"
94+
7995

8096
class TestAstroCliInstallation:
8197
"""Tests for CLI installation detection."""

astro-airflow-mcp/tests/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from astro_airflow_mcp.utils import (
1111
extract_failed_tasks,
1212
filter_connection_passwords,
13+
normalize_airflow_url,
1314
wrap_list_response,
1415
)
1516

@@ -196,3 +197,30 @@ def test_different_key_names(self):
196197

197198
assert "total_dag_runs" in result
198199
assert "dag_runs" in result
200+
201+
202+
class TestNormalizeAirflowUrl:
    """Behavioural checks for ``normalize_airflow_url``."""

    def test_strips_query_string(self):
        # Some Astro webserver URLs were stored with ?orgId=…; appending
        # /api/v2/version to such a URL would corrupt the request path.
        normalized = normalize_airflow_url("https://host.example.com/dep?orgId=org_abc")
        assert normalized == "https://host.example.com/dep"

    def test_strips_fragment(self):
        normalized = normalize_airflow_url("https://h/p#frag")
        assert normalized == "https://h/p"

    def test_strips_trailing_slash(self):
        normalized = normalize_airflow_url("https://h/p/")
        assert normalized == "https://h/p"

    def test_passthrough_clean_url(self):
        clean = "https://h.example.com/dep"
        assert normalize_airflow_url(clean) == clean

    def test_handles_empty(self):
        normalized = normalize_airflow_url("")
        assert normalized == ""

    def test_strips_query_and_fragment_and_slash(self):
        normalized = normalize_airflow_url("https://h/p/?x=1#y")
        assert normalized == "https://h/p"

0 commit comments

Comments
 (0)