Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 46 additions & 39 deletions astro-airflow-mcp/src/astro_airflow_mcp/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astro_airflow_mcp.adapters.airflow_v3 import AirflowV3Adapter
from astro_airflow_mcp.adapters.base import AirflowAdapter, NotFoundError
from astro_airflow_mcp.astro_pat import AstroPATError
from astro_airflow_mcp.utils import normalize_airflow_url


def detect_version(
Expand All @@ -33,6 +34,8 @@ def detect_version(
Raises:
RuntimeError: If version detection fails
"""
airflow_url = normalize_airflow_url(airflow_url)

headers: dict[str, str] = {}
auth: tuple[str, str] | httpx.Auth | None = None

Expand All @@ -46,48 +49,52 @@ def detect_version(
if basic_auth_getter:
auth = basic_auth_getter()

# Try Airflow 3 API first (/api/v2/version)
try:
with httpx.Client(timeout=10.0, verify=verify) as client:
response = client.get(
f"{airflow_url}/api/v2/version",
headers=headers,
auth=auth,
)
if response.status_code == 200:
data = response.json()
version = data.get("version", "3.0.0")
major = int(version.split(".")[0])
return (major, version)
except AstroPATError:
# PAT misconfiguration (no astro login, refresh failed, etc) —
# surface to the caller rather than masking as a version detection
# failure.
raise
except Exception: # nosec B110 - try v1 API next
pass

# Try Airflow 2 API (/api/v1/version)
try:
with httpx.Client(timeout=10.0, verify=verify) as client:
response = client.get(
f"{airflow_url}/api/v1/version",
headers=headers,
auth=auth,
)
if response.status_code == 200:
probe_failures: list[str] = []

def _probe(api_path: str, default_version: str) -> tuple[int, str] | None:
try:
with httpx.Client(timeout=10.0, verify=verify) as client:
response = client.get(
f"{airflow_url}{api_path}/version",
headers=headers,
auth=auth,
)
except AstroPATError:
# PAT misconfiguration (no astro login, refresh failed, etc) —
# surface to the caller rather than masking as a version
# detection failure.
raise
except Exception as e:
probe_failures.append(f"{api_path}: {type(e).__name__}: {e}")
return None

if response.status_code == 200:
try:
data = response.json()
version = data.get("version", "2.0.0")
major = int(version.split(".")[0])
return (major, version)
except AstroPATError:
raise
except Exception: # nosec B110 - raise RuntimeError below
pass

except ValueError as e:
probe_failures.append(
f"{api_path}: 200 but non-JSON body ({type(e).__name__}); "
f"got Content-Type={response.headers.get('content-type', '?')}"
)
return None
version = data.get("version", default_version)
major = int(version.split(".")[0])
return (major, version)

probe_failures.append(f"{api_path}: HTTP {response.status_code}")
return None

result = _probe("/api/v2", "3.0.0")
if result is not None:
return result
result = _probe("/api/v1", "2.0.0")
if result is not None:
return result

detail = "; ".join(probe_failures) if probe_failures else "no response"
raise RuntimeError(
f"Failed to detect Airflow version at {airflow_url}. "
"Ensure Airflow is running and accessible."
f"Probes: {detail}. Ensure Airflow is running and accessible."
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import httpx

from astro_airflow_mcp.adapters.base import AirflowAdapter, NotFoundError
from astro_airflow_mcp.utils import normalize_airflow_url


class AirflowV3Adapter(AirflowAdapter):
Expand Down Expand Up @@ -71,7 +72,7 @@ def _exchange_for_token(
try:
with httpx.Client(timeout=10.0, verify=verify) as client:
response = client.post(
f"{airflow_url}/auth/token",
f"{normalize_airflow_url(airflow_url)}/auth/token",
data={"username": username, "password": password},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
Expand Down
3 changes: 2 additions & 1 deletion astro-airflow-mcp/src/astro_airflow_mcp/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import httpx

from astro_airflow_mcp.constants import READ_ONLY_ENV_VAR
from astro_airflow_mcp.utils import normalize_airflow_url


class ReadOnlyError(Exception):
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
verify: SSL verification setting. True (default) enables verification,
False disables it, or a string path to a CA bundle file.
"""
self.airflow_url = airflow_url
self.airflow_url = normalize_airflow_url(airflow_url)
self.version = version
self._token_getter = token_getter
self._basic_auth_getter = basic_auth_getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import yaml

from astro_airflow_mcp.utils import normalize_airflow_url


class AstroCliError(Exception):
"""Base exception for Astro CLI errors."""
Expand Down Expand Up @@ -46,6 +48,10 @@ def from_inspect_yaml(cls, data: dict) -> AstroDeployment:
webserver_url = metadata.get("webserver_url", "")
if webserver_url and not webserver_url.startswith("http"):
webserver_url = f"https://{webserver_url}"
# Strip any query string / fragment that the control plane may
# include (eg ``?orgId=…``). API URLs are built by string
# concatenation downstream, so a stray ``?`` corrupts the path.
webserver_url = normalize_airflow_url(webserver_url)

return cls(
id=metadata.get("deployment_id", ""),
Expand Down
17 changes: 17 additions & 0 deletions astro-airflow-mcp/src/astro_airflow_mcp/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
"""Shared utility functions for CLI and MCP server."""

from typing import Any
from urllib.parse import urlsplit, urlunsplit

from astro_airflow_mcp.constants import FAILED_TASK_STATES


def normalize_airflow_url(url: str) -> str:
"""Strip query string, fragment, and trailing slash from an Airflow base URL.

API URLs are built by concatenation (eg ``f"{airflow_url}/api/v2/version"``),
so a stored URL like ``https://host/dep?orgId=foo`` produces the malformed
``https://host/dep?orgId=foo/api/v2/version`` — the path stays at ``/dep``
and ``/api/...`` ends up inside the query string. Normalizing once at the
boundary keeps every downstream call safe.
"""
if not url:
return url
parts = urlsplit(url)
path = parts.path.rstrip("/")
return urlunsplit((parts.scheme, parts.netloc, path, "", ""))


def filter_connection_passwords(connections: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Filter sensitive fields from connections.

Expand Down
50 changes: 50 additions & 0 deletions astro-airflow-mcp/tests/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def test_api_base_path(self):
)
assert adapter.api_base_path == "/api/v1"

def test_constructor_normalizes_airflow_url(self):
"""A query string on the stored URL must not corrupt API URLs.
Existing configs with ?orgId=… should keep working without re-discovery."""
adapter = AirflowV2Adapter(
"https://host.example.com/dep?orgId=org_abc",
"2.9.0",
)
assert adapter.airflow_url == "https://host.example.com/dep"

def test_setup_auth_with_token_getter(self):
"""Test auth setup with token getter."""
adapter = AirflowV2Adapter(
Expand Down Expand Up @@ -537,6 +546,47 @@ def test_detect_version_with_token_getter(self, mocker):
call_kwargs = mock_client.get.call_args[1]
assert call_kwargs["headers"]["Authorization"] == "Bearer test_token"

def test_detect_version_strips_query_string_from_base_url(self, mocker):
"""A stored URL with ?orgId=… (eg from Astro discovery) must not
corrupt the probe path. See bug recreated against an Astro deployment
where the saved URL had ?orgId=org_… appended."""
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"version": "3.2.1"}

mock_client = mocker.Mock()
mock_client.get.return_value = mock_response
mock_client.__enter__ = mocker.Mock(return_value=mock_client)
mock_client.__exit__ = mocker.Mock(return_value=False)

mocker.patch("httpx.Client", return_value=mock_client)

major, full = detect_version("https://host.example.com/dep?orgId=org_abc")

assert (major, full) == (3, "3.2.1")
called_url = mock_client.get.call_args[0][0]
assert called_url == "https://host.example.com/dep/api/v2/version"

def test_detect_version_failure_includes_probe_detail(self, mocker):
"""RuntimeError must surface the actual probe failure (status code or
exception) so users don't get a black-box 'Failed to detect' error."""
bad_v2 = mocker.Mock(status_code=404)
bad_v1 = mocker.Mock(status_code=502)

mock_client = mocker.Mock()
mock_client.get.side_effect = [bad_v2, bad_v1]
mock_client.__enter__ = mocker.Mock(return_value=mock_client)
mock_client.__exit__ = mocker.Mock(return_value=False)

mocker.patch("httpx.Client", return_value=mock_client)

with pytest.raises(RuntimeError) as exc_info:
detect_version("http://localhost:8080")

msg = str(exc_info.value)
assert "/api/v2: HTTP 404" in msg
assert "/api/v1: HTTP 502" in msg


class TestAdapterFactory:
"""Tests for adapter factory."""
Expand Down
16 changes: 16 additions & 0 deletions astro-airflow-mcp/tests/test_astro_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_from_inspect_yaml_minimal(self):
assert deployment.status == "UNKNOWN"
assert deployment.airflow_api_url == ""

def test_from_inspect_yaml_strips_query_string(self):
"""Some Astro deployments return webserver_url with ?orgId=… —
if we store that, every API call concatenates /api/v1/... into
the query string and breaks. Strip query/fragment at the boundary."""
data = {
"deployment": {
"configuration": {"name": "t"},
"metadata": {
"deployment_id": "dep-123",
"webserver_url": "https://xyz.astronomer.run/abc?orgId=org_abc",
},
}
}
deployment = AstroDeployment.from_inspect_yaml(data)
assert deployment.airflow_api_url == "https://xyz.astronomer.run/abc"


class TestAstroCliInstallation:
"""Tests for CLI installation detection."""
Expand Down
28 changes: 28 additions & 0 deletions astro-airflow-mcp/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from astro_airflow_mcp.utils import (
extract_failed_tasks,
filter_connection_passwords,
normalize_airflow_url,
wrap_list_response,
)

Expand Down Expand Up @@ -196,3 +197,30 @@ def test_different_key_names(self):

assert "total_dag_runs" in result
assert "dag_runs" in result


class TestNormalizeAirflowUrl:
"""Tests for normalize_airflow_url."""

def test_strips_query_string(self):
# Astro stored some webserver_urls with ?orgId=… — concatenating
# /api/v2/version onto a URL with a query string corrupts the path.
url = "https://host.example.com/dep?orgId=org_abc"
assert normalize_airflow_url(url) == "https://host.example.com/dep"

def test_strips_fragment(self):
assert normalize_airflow_url("https://h/p#frag") == "https://h/p"

def test_strips_trailing_slash(self):
assert normalize_airflow_url("https://h/p/") == "https://h/p"

def test_passthrough_clean_url(self):
url = "https://h.example.com/dep"
assert normalize_airflow_url(url) == url

def test_handles_empty(self):
assert normalize_airflow_url("") == ""

def test_strips_query_and_fragment_and_slash(self):
url = "https://h/p/?x=1#y"
assert normalize_airflow_url(url) == "https://h/p"