Skip to content

Commit cc197ea

Browse files
authored
feat(api): add periodic cleanup of stale Attack Paths scans with dead-worker detection (#10387)
1 parent 2b5d015 commit cc197ea

File tree

10 files changed

+676
-74
lines changed

10 files changed

+676
-74
lines changed

.github/workflows/api-codeql.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ jobs:
5050
github.com:443
5151
release-assets.githubusercontent.com:443
5252
uploads.github.com:443
53+
release-assets.githubusercontent.com:443
54+
objects.githubusercontent.com:443
5355
5456
- name: Checkout repository
5557
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2

.github/workflows/api-container-checks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ jobs:
8686
production.cloudflare.docker.com:443
8787
debian.map.fastlydns.net:80
8888
release-assets.githubusercontent.com:443
89+
objects.githubusercontent.com:443
8990
pypi.org:443
9091
files.pythonhosted.org:443
9192
www.powershellgallery.com:443

api/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ All notable changes to the **Prowler API** are documented in this file.
3030

3131
- Finding groups support `check_title` substring filtering [(#10377)](https://github.com/prowler-cloud/prowler/pull/10377)
3232

33+
### 🔄 Changed
34+
35+
- Attack Paths: Periodic cleanup of stale scans with dead-worker detection via Celery inspect, marking orphaned `EXECUTING` scans as `FAILED` and recovering `graph_data_ready` [(#10387)](https://github.com/prowler-cloud/prowler/pull/10387)
36+
3337
### 🐞 Fixed
3438

3539
- Finding groups latest endpoint now aggregates the latest snapshot per provider before check-level totals, keeping impacted resources aligned across providers [(#10419)](https://github.com/prowler-cloud/prowler/pull/10419)

api/docker-entrypoint.sh

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,28 @@ start_prod_server() {
3030
poetry run gunicorn -c config/guniconf.py config.wsgi:application
3131
}
3232

33+
resolve_worker_hostname() {
    # Build a unique Celery node name of the form "<task-id>@<hostname>".
    # On ECS the task ID comes from the container metadata endpoint, so the
    # node name is stable for the task's lifetime and unique across replicas;
    # anywhere else a random UUID is used instead.
    TASK_ID=""

    if [ -n "$ECS_CONTAINER_METADATA_URI_V4" ]; then
        # Best effort: a failed fetch or malformed JSON leaves TASK_ID empty.
        TASK_ID=$(wget -qO- --timeout=2 "${ECS_CONTAINER_METADATA_URI_V4}/task" | \
            python3 -c "import sys,json; print(json.load(sys.stdin)['TaskARN'].split('/')[-1])" 2>/dev/null)
    fi

    # Fallback for non-ECS environments (or a failed metadata lookup).
    if [ -z "$TASK_ID" ]; then
        TASK_ID=$(python3 -c "import uuid; print(uuid.uuid4().hex)")
    fi

    echo "${TASK_ID}@$(hostname)"
}
47+
3348
start_worker() {
    # Launch the Celery worker on every application queue. The explicit -n
    # node name lets dead-worker detection ping this worker individually.
    echo "Starting the worker..."
    WORKER_QUEUES="celery,scans,scan-reports,deletion,backfill,overview,integrations,compliance,attack-paths-scans"
    poetry run python -m celery -A config.celery worker \
        -n "$(resolve_worker_hostname)" \
        -l "${DJANGO_LOGGING_LEVEL:-info}" \
        -Q "$WORKER_QUEUES" \
        -E --max-tasks-per-child 1
}
3756

3857
start_worker_beat() {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from django.db import migrations
2+
3+
4+
TASK_NAME = "attack-paths-cleanup-stale-scans"
5+
INTERVAL_HOURS = 1
6+
7+
8+
def create_periodic_task(apps, schema_editor):
    """Register the hourly beat entry that triggers stale-scan cleanup.

    Idempotent: reuses an existing matching ``IntervalSchedule`` and upserts
    the ``PeriodicTask`` row keyed by its name.
    """
    IntervalSchedule = apps.get_model("django_celery_beat", "IntervalSchedule")
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")

    # Reuse any schedule row that already matches the interval.
    interval, _created = IntervalSchedule.objects.get_or_create(
        every=INTERVAL_HOURS,
        period="hours",
    )

    task_defaults = {
        "task": TASK_NAME,
        "interval": interval,
        "enabled": True,
    }
    PeriodicTask.objects.update_or_create(name=TASK_NAME, defaults=task_defaults)
25+
26+
27+
def delete_periodic_task(apps, schema_editor):
    """Reverse migration: drop the beat entry and its now-orphaned schedule."""
    IntervalSchedule = apps.get_model("django_celery_beat", "IntervalSchedule")
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")

    PeriodicTask.objects.filter(name=TASK_NAME).delete()

    # Only remove the interval when no remaining periodic task references it.
    orphaned_schedules = IntervalSchedule.objects.filter(
        every=INTERVAL_HOURS,
        period="hours",
        periodictask__isnull=True,
    )
    orphaned_schedules.delete()
39+
40+
41+
class Migration(migrations.Migration):
    # Depends on the latest api migration (ordering) and on django_celery_beat,
    # whose tables the RunPython callables write to.
    dependencies = [
        ("api", "0085_finding_group_daily_summary_trgm_indexes"),
        ("django_celery_beat", "0019_alter_periodictasks_options"),
    ]

    operations = [
        # Forward: register the cleanup beat entry; reverse: remove it again.
        migrations.RunPython(create_periodic_task, delete_periodic_task),
    ]

api/src/backend/config/django/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,8 @@
299299
# SAML requirement
300300
CSRF_COOKIE_SECURE = True
301301
SESSION_COOKIE_SECURE = True
302+
303+
# Attack Paths
304+
ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES = env.int(
305+
"ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES", 2880
306+
) # 48h
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
from datetime import datetime, timedelta, timezone
2+
3+
from celery import current_app, states
4+
from celery.utils.log import get_task_logger
5+
from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES
6+
from tasks.jobs.attack_paths.db_utils import (
7+
_mark_scan_finished,
8+
recover_graph_data_ready,
9+
)
10+
11+
from api.attack_paths import database as graph_database
12+
from api.db_router import MainRouter
13+
from api.db_utils import rls_transaction
14+
from api.models import AttackPathsScan, StateChoices
15+
16+
logger = get_task_logger(__name__)
17+
18+
19+
def cleanup_stale_attack_paths_scans() -> dict:
    """
    Find `EXECUTING` `AttackPathsScan` scans whose workers are dead or that have
    exceeded the stale threshold, and mark them as `FAILED`.

    Two-pass detection:
    1. If `TaskResult.worker` exists, ping the worker.
       - Dead worker: cleanup immediately (any age).
       - Alive + past threshold: revoke the task, then cleanup.
       - Alive + within threshold: skip.
    2. If no worker field: fall back to time-based heuristic only.
    """
    stale_cutoff = datetime.now(tz=timezone.utc) - timedelta(
        minutes=ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES
    )

    scans = list(
        AttackPathsScan.all_objects.using(MainRouter.admin_db)
        .filter(state=StateChoices.EXECUTING)
        .select_related("task__task_runner_task")
    )

    def _task_result_of(scan):
        # The TaskResult row, when the scan has an associated task.
        if not scan.task:
            return None
        return getattr(scan.task, "task_runner_task", None)

    # Ping each distinct worker at most once, caching the liveness result.
    liveness: dict[str, bool] = {}
    for scan in scans:
        task_result = _task_result_of(scan)
        if task_result and task_result.worker and task_result.worker not in liveness:
            liveness[task_result.worker] = _is_worker_alive(task_result.worker)

    cleaned_up: list[str] = []

    for scan in scans:
        task_result = _task_result_of(scan)
        worker = task_result.worker if task_result else None
        # "Fresh" means started within the stale threshold window.
        is_fresh = bool(scan.started_at and scan.started_at >= stale_cutoff)

        if worker and liveness.get(worker, True):
            if is_fresh:
                continue
            # Alive but stale — revoke before cleanup.
            _revoke_task(task_result)
            reason = "Scan exceeded stale threshold — cleaned up by periodic task"
        elif worker:
            # Worker recorded but did not answer the ping: clean up at any age.
            reason = "Worker dead — cleaned up by periodic task"
        else:
            # No worker recorded — time-based heuristic only.
            if is_fresh:
                continue
            reason = (
                "No worker recorded, scan exceeded stale threshold — "
                "cleaned up by periodic task"
            )

        if _cleanup_scan(scan, task_result, reason):
            cleaned_up.append(str(scan.id))

    logger.info(
        f"Stale `AttackPathsScan` cleanup: {len(cleaned_up)} scan(s) cleaned up"
    )
    return {"cleaned_up_count": len(cleaned_up), "scan_ids": cleaned_up}
89+
90+
91+
def _is_worker_alive(worker: str) -> bool:
    """Ping a specific Celery worker. Returns `True` if it responds or on error."""
    try:
        replies = current_app.control.inspect(
            destination=[worker], timeout=1.0
        ).ping()
    except Exception:
        # Fail open: an inspect error must not trigger a spurious cleanup.
        logger.exception(f"Failed to ping worker {worker}, treating as alive")
        return True
    # ping() returns a {worker_name: reply} mapping, or None on no response.
    return bool(replies) and worker in replies
99+
100+
101+
def _revoke_task(task_result) -> None:
    """Send `SIGTERM` to a hung Celery task. Non-fatal on failure."""
    task_id = task_result.task_id
    try:
        current_app.control.revoke(task_id, terminate=True, signal="SIGTERM")
    except Exception:
        # Revocation is best effort; the scan is cleaned up regardless.
        logger.exception(f"Failed to revoke task {task_id}")
    else:
        logger.info(f"Revoked task {task_id}")
110+
111+
112+
def _cleanup_scan(scan, task_result, reason: str) -> bool:
    """
    Clean up a single stale `AttackPathsScan`:
    drop temp DB, mark `FAILED`, update `TaskResult`, recover `graph_data_ready`.

    Returns `True` if the scan was actually cleaned up, `False` if skipped.
    """
    scan_id_str = str(scan.id)

    # 1. Best-effort removal of the temporary Neo4j database.
    tmp_db_name = graph_database.get_database_name(scan.id, temporary=True)
    try:
        graph_database.drop_database(tmp_db_name)
    except Exception:
        logger.exception(f"Failed to drop temp database {tmp_db_name}")

    # 2. Re-read under a row lock so a concurrently finishing scan is not
    #    clobbered; verify it is still EXECUTING, then mark FAILED atomically.
    with rls_transaction(str(scan.tenant_id)):
        fresh_scan = (
            AttackPathsScan.objects.select_for_update().filter(id=scan.id).first()
        )
        if fresh_scan is None:
            logger.warning(f"Scan {scan_id_str} no longer exists, skipping")
            return False

        if fresh_scan.state != StateChoices.EXECUTING:
            logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
            return False

        _mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})

    # 3. Mark `TaskResult` as `FAILURE` (not RLS-protected, outside lock).
    if task_result:
        task_result.status = states.FAILURE
        task_result.date_done = datetime.now(tz=timezone.utc)
        task_result.save(update_fields=["status", "date_done"])

    # 4. Recover graph_data_ready if provider data still exists.
    recover_graph_data_ready(fresh_scan)

    logger.info(f"Cleaned up stale scan {scan_id_str}: {reason}")
    return True

api/src/backend/tasks/jobs/attack_paths/db_utils.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -88,34 +88,41 @@ def starting_attack_paths_scan(
8888
)
8989

9090

91-
def finish_attack_paths_scan(
91+
def _mark_scan_finished(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    state: StateChoices,
    ingestion_exceptions: dict[str, Any],
) -> None:
    """Set terminal fields on a scan. Caller must be inside a transaction."""
    finished_at = datetime.now(tz=timezone.utc)
    if attack_paths_scan.started_at:
        elapsed_seconds = int(
            (finished_at - attack_paths_scan.started_at).total_seconds()
        )
    else:
        # No recorded start time — report a zero duration rather than failing.
        elapsed_seconds = 0

    terminal_fields = {
        "state": state,
        "progress": 100,
        "completed_at": finished_at,
        "duration": elapsed_seconds,
        "ingestion_exceptions": ingestion_exceptions,
    }
    for field_name, value in terminal_fields.items():
        setattr(attack_paths_scan, field_name, value)
    attack_paths_scan.save(update_fields=list(terminal_fields))
103117

104-
attack_paths_scan.state = state
105-
attack_paths_scan.progress = 100
106-
attack_paths_scan.completed_at = now
107-
attack_paths_scan.duration = duration
108-
attack_paths_scan.ingestion_exceptions = ingestion_exceptions
109118

110-
attack_paths_scan.save(
111-
update_fields=[
112-
"state",
113-
"progress",
114-
"completed_at",
115-
"duration",
116-
"ingestion_exceptions",
117-
]
118-
)
119+
def finish_attack_paths_scan(
    attack_paths_scan: ProwlerAPIAttackPathsScan,
    state: StateChoices,
    ingestion_exceptions: dict[str, Any],
) -> None:
    """Mark a scan finished, wrapping the update in the tenant's RLS transaction."""
    with rls_transaction(attack_paths_scan.tenant_id):
        _mark_scan_finished(attack_paths_scan, state, ingestion_exceptions)
119126

120127

121128
def update_attack_paths_scan_progress(
@@ -194,25 +201,26 @@ def fail_attack_paths_scan(
194201
Used as a safety net when the Celery task fails outside the job's own error handling.
195202
"""
196203
attack_paths_scan = retrieve_attack_paths_scan(tenant_id, scan_id)
197-
if attack_paths_scan and attack_paths_scan.state not in (
198-
StateChoices.COMPLETED,
199-
StateChoices.FAILED,
200-
):
201-
tmp_db_name = graph_database.get_database_name(
202-
attack_paths_scan.id, temporary=True
204+
if not attack_paths_scan:
205+
return
206+
207+
tmp_db_name = graph_database.get_database_name(attack_paths_scan.id, temporary=True)
208+
try:
209+
graph_database.drop_database(tmp_db_name)
210+
except Exception:
211+
logger.exception(
212+
f"Failed to drop temp database {tmp_db_name} during failure handling"
203213
)
204-
try:
205-
graph_database.drop_database(tmp_db_name)
206214

207-
except Exception:
208-
logger.exception(
209-
f"Failed to drop temp database {tmp_db_name} during failure handling"
215+
with rls_transaction(tenant_id):
216+
try:
217+
fresh = ProwlerAPIAttackPathsScan.objects.select_for_update().get(
218+
id=attack_paths_scan.id
210219
)
220+
except ProwlerAPIAttackPathsScan.DoesNotExist:
221+
return
222+
if fresh.state in (StateChoices.COMPLETED, StateChoices.FAILED):
223+
return
224+
_mark_scan_finished(fresh, StateChoices.FAILED, {"global_error": error})
211225

212-
finish_attack_paths_scan(
213-
attack_paths_scan,
214-
StateChoices.FAILED,
215-
{"global_error": error},
216-
)
217-
218-
recover_graph_data_ready(attack_paths_scan)
226+
recover_graph_data_ready(fresh)

api/src/backend/tasks/tasks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
can_provider_run_attack_paths_scan,
1414
)
1515
from tasks.jobs.attack_paths import db_utils as attack_paths_db_utils
16+
from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans
1617
from tasks.jobs.backfill import (
1718
backfill_compliance_summaries,
1819
backfill_daily_severity_summaries,
@@ -406,6 +407,11 @@ def perform_attack_paths_scan_task(self, tenant_id: str, scan_id: str):
406407
)
407408

408409

410+
@shared_task(name="attack-paths-cleanup-stale-scans", queue="attack-paths-scans")
def cleanup_stale_attack_paths_scans_task():
    """Beat-scheduled entry point delegating to the stale-scan cleanup job."""
    return cleanup_stale_attack_paths_scans()
413+
414+
409415
@shared_task(name="tenant-deletion", queue="deletion", autoretry_for=(Exception,))
410416
def delete_tenant_task(tenant_id: str):
411417
return delete_tenant(pk=tenant_id)

0 commit comments

Comments
 (0)