|
| 1 | +from datetime import datetime, timedelta, timezone |
| 2 | + |
| 3 | +from celery import current_app, states |
| 4 | +from celery.utils.log import get_task_logger |
| 5 | +from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES |
| 6 | +from tasks.jobs.attack_paths.db_utils import ( |
| 7 | + _mark_scan_finished, |
| 8 | + recover_graph_data_ready, |
| 9 | +) |
| 10 | + |
| 11 | +from api.attack_paths import database as graph_database |
| 12 | +from api.db_router import MainRouter |
| 13 | +from api.db_utils import rls_transaction |
| 14 | +from api.models import AttackPathsScan, StateChoices |
| 15 | + |
| 16 | +logger = get_task_logger(__name__) |
| 17 | + |
| 18 | + |
def cleanup_stale_attack_paths_scans() -> dict:
    """
    Mark stuck `EXECUTING` `AttackPathsScan` rows as `FAILED`.

    Detection strategy (two passes):
    1. When the backing `TaskResult` records a worker, ping that worker.
       - Unresponsive worker: clean the scan up regardless of its age.
       - Responsive worker, scan older than the stale threshold: revoke the
         Celery task first, then clean up.
       - Responsive worker, scan still within the threshold: leave it alone.
    2. Without a recorded worker, rely purely on the age heuristic.

    Returns:
        Summary dict with `cleaned_up_count` and the affected `scan_ids`.
    """
    stale_after = timedelta(minutes=ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES)
    deadline = datetime.now(tz=timezone.utc) - stale_after

    candidates = list(
        AttackPathsScan.all_objects.using(MainRouter.admin_db)
        .filter(state=StateChoices.EXECUTING)
        .select_related("task__task_runner_task")
    )

    # Memoize liveness so each distinct worker is pinged at most once.
    liveness: dict = {}
    for scan in candidates:
        runner = getattr(scan.task, "task_runner_task", None) if scan.task else None
        if runner and runner.worker and runner.worker not in liveness:
            liveness[runner.worker] = _is_worker_alive(runner.worker)

    cleaned_up: list = []

    for scan in candidates:
        runner = getattr(scan.task, "task_runner_task", None) if scan.task else None
        worker_name = runner.worker if runner else None

        if not worker_name:
            # No worker recorded — the age heuristic is all we have.
            if scan.started_at and scan.started_at >= deadline:
                continue
            reason = (
                "No worker recorded, scan exceeded stale threshold — "
                "cleaned up by periodic task"
            )
        elif liveness.get(worker_name, True):
            # Worker responds; only act once the scan is past the threshold.
            if scan.started_at and scan.started_at >= deadline:
                continue
            # Alive but stale — revoke the hung task before cleaning up.
            _revoke_task(runner)
            reason = "Scan exceeded stale threshold — cleaned up by periodic task"
        else:
            reason = "Worker dead — cleaned up by periodic task"

        if _cleanup_scan(scan, runner, reason):
            cleaned_up.append(str(scan.id))

    logger.info(
        f"Stale `AttackPathsScan` cleanup: {len(cleaned_up)} scan(s) cleaned up"
    )
    return {"cleaned_up_count": len(cleaned_up), "scan_ids": cleaned_up}
| 89 | + |
| 90 | + |
def _is_worker_alive(worker: str) -> bool:
    """
    Ping one Celery worker by name.

    Returns `True` when the worker answers the ping — and also `True` when
    the ping itself errors out, so a broker hiccup never causes healthy
    scans to be reaped.
    """
    try:
        replies = current_app.control.inspect(
            destination=[worker], timeout=1.0
        ).ping()
    except Exception:
        logger.exception(f"Failed to ping worker {worker}, treating as alive")
        return True
    return replies is not None and worker in replies
| 99 | + |
| 100 | + |
def _revoke_task(task_result) -> None:
    """
    Terminate a hung Celery task by revoking it with `SIGTERM`.

    Failures are logged and swallowed so the caller's cleanup can proceed.
    """
    try:
        current_app.control.revoke(
            task_result.task_id, terminate=True, signal="SIGTERM"
        )
    except Exception:
        logger.exception(f"Failed to revoke task {task_result.task_id}")
    else:
        logger.info(f"Revoked task {task_result.task_id}")
| 110 | + |
| 111 | + |
def _cleanup_scan(scan: AttackPathsScan, task_result, reason: str) -> bool:
    """
    Clean up a single stale `AttackPathsScan`.

    Steps, in order:
    1. Drop the scan's temporary Neo4j database (best-effort).
    2. Re-read the row under `select_for_update` inside an RLS transaction,
       verify it is still `EXECUTING`, and mark it `FAILED` with *reason*.
    3. Mark the associated Celery `TaskResult` as `FAILURE` — done outside
       the lock because that table is not RLS-protected.
    4. Recover `graph_data_ready` if provider data still exists.

    Args:
        scan: The stale scan row, as read outside any lock (may be out of
            date — the state is re-checked under the lock before acting).
        task_result: The scan's `TaskResult`, or `None` if not recorded.
        reason: Human-readable explanation stored as the scan's global error.

    Returns:
        `True` if the scan was actually cleaned up, `False` if skipped
        (row vanished, or it already left the `EXECUTING` state).
    """
    scan_id_str = str(scan.id)

    # 1. Drop the temp Neo4j database. Best-effort: a failure here must not
    # prevent the scan row from being failed, so we only log it.
    tmp_db_name = graph_database.get_database_name(scan.id, temporary=True)
    try:
        graph_database.drop_database(tmp_db_name)
    except Exception:
        logger.exception(f"Failed to drop temp database {tmp_db_name}")

    # 2. Lock the row, verify it is still EXECUTING, mark it FAILED — all
    # atomic within the tenant's RLS transaction so a concurrently finishing
    # worker cannot race this cleanup.
    with rls_transaction(str(scan.tenant_id)):
        try:
            fresh_scan = AttackPathsScan.objects.select_for_update().get(id=scan.id)
        except AttackPathsScan.DoesNotExist:
            logger.warning(f"Scan {scan_id_str} no longer exists, skipping")
            return False

        # The scan may have finished (or been cleaned up) since the
        # unlocked read in the caller — skip if so.
        if fresh_scan.state != StateChoices.EXECUTING:
            logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
            return False

        _mark_scan_finished(fresh_scan, StateChoices.FAILED, {"global_error": reason})

    # 3. Mark the `TaskResult` as `FAILURE`. Its table is not RLS-protected,
    # so this happens outside the transaction/lock above.
    if task_result:
        task_result.status = states.FAILURE
        task_result.date_done = datetime.now(tz=timezone.utc)
        task_result.save(update_fields=["status", "date_done"])

    # 4. Recover `graph_data_ready` if the provider data still exists.
    recover_graph_data_ready(fresh_scan)

    logger.info(f"Cleaned up stale scan {scan_id_str}: {reason}")
    return True
0 commit comments