|
| 1 | +""" |
| 2 | +Cypher sanitizer for custom (user-supplied) Attack Paths queries. |
| 3 | +
|
| 4 | +Two responsibilities: |
| 5 | +
|
| 6 | +1. **Validation** - reject queries containing SSRF or dangerous procedure |
| 7 | + patterns (defense-in-depth; the primary control is ``neo4j.READ_ACCESS``). |
| 8 | +
|
| 9 | +2. **Provider-scoped label injection** - inject a dynamic |
| 10 | + ``_Provider_{uuid}`` label into every node pattern so the database can |
| 11 | + use its native label index for provider isolation. |
| 12 | +
|
| 13 | +Label-injection pipeline: |
| 14 | +
|
| 15 | +1. **Protect** string literals and line comments (placeholder replacement). |
| 16 | +2. **Split** by top-level clause keywords to track clause context. |
| 17 | +3. **Pass A** - inject into *labeled* node patterns in ALL segments. |
| 18 | +4. **Pass B** - inject into *bare* node patterns in MATCH segments only. |
| 19 | +5. **Restore** protected regions. |
| 20 | +""" |
| 21 | + |
| 22 | +import re |
| 23 | + |
| 24 | +from rest_framework.exceptions import ValidationError |
| 25 | + |
| 26 | +from tasks.jobs.attack_paths.config import get_provider_label |
| 27 | + |
| 28 | + |
| 29 | +# Step 1 - String / comment protection |
| 30 | +# Single combined regex: strings first, then line comments. |
| 31 | +# The regex engine finds the leftmost match, so a string like 'https://prowler.com' |
| 32 | +# is consumed as a string before the // inside it can match as a comment. |
| 33 | +_PROTECTED_RE = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"|//[^\n]*") |
| 34 | + |
| 35 | +# Step 2 - Clause splitting |
| 36 | +# OPTIONAL MATCH must come before MATCH to avoid partial matching. |
| 37 | +_CLAUSE_RE = re.compile( |
| 38 | + r"\b(OPTIONAL\s+MATCH|MATCH|WHERE|RETURN|WITH|ORDER\s+BY" |
| 39 | + r"|SKIP|LIMIT|UNION|UNWIND|CALL)\b", |
| 40 | + re.IGNORECASE, |
| 41 | +) |
| 42 | + |
| 43 | +# Pass A - Labeled node patterns (all segments) |
| 44 | +# Matches node patterns that have at least one :Label. |
| 45 | +# (?<!\w)\( - open paren NOT preceded by a word char (excludes function calls). |
| 46 | +# Group 1: optional variable + one or more :Label |
| 47 | +# Group 2: optional {properties} + closing paren |
| 48 | +_LABELED_NODE_RE = re.compile( |
| 49 | + r"(?<!\w)\(" |
| 50 | + r"(" |
| 51 | + r"\s*(?:[a-zA-Z_]\w*)?" |
| 52 | + r"(?:\s*:\s*(?:`[^`]*`|[a-zA-Z_]\w*))+" |
| 53 | + r")" |
| 54 | + r"(" |
| 55 | + r"\s*(?:\{[^}]*\})?" |
| 56 | + r"\s*\)" |
| 57 | + r")" |
| 58 | +) |
| 59 | + |
| 60 | +# Pass B - Bare node patterns (MATCH segments only) |
| 61 | +# Matches (identifier) or (identifier {properties}) without any :Label. |
| 62 | +# Only applied in MATCH/OPTIONAL MATCH segments. |
| 63 | +_BARE_NODE_RE = re.compile( |
| 64 | + r"(?<!\w)\(" r"(\s*[a-zA-Z_]\w*)" r"(\s*(?:\{[^}]*\})?)" r"\s*\)" |
| 65 | +) |
| 66 | + |
| 67 | +_MATCH_CLAUSES = frozenset({"MATCH", "OPTIONAL MATCH"}) |
| 68 | + |
| 69 | + |
| 70 | +def _inject_labeled(segment: str, label: str) -> str: |
| 71 | + """Inject provider label into all node patterns that have existing labels.""" |
| 72 | + return _LABELED_NODE_RE.sub(rf"(\1:{label}\2", segment) |
| 73 | + |
| 74 | + |
| 75 | +def _inject_bare(segment: str, label: str) -> str: |
| 76 | + """Inject provider label into bare `(identifier)` node patterns.""" |
| 77 | + |
| 78 | + def _replace(match): |
| 79 | + var = match.group(1) |
| 80 | + props = match.group(2).strip() |
| 81 | + if props: |
| 82 | + return f"({var}:{label} {props})" |
| 83 | + return f"({var}:{label})" |
| 84 | + |
| 85 | + return _BARE_NODE_RE.sub(_replace, segment) |
| 86 | + |
| 87 | + |
| 88 | +def inject_provider_label(cypher: str, provider_id: str) -> str: |
| 89 | + """Rewrite a Cypher query to scope every node pattern to a provider. |
| 90 | +
|
| 91 | + Args: |
| 92 | + cypher: The original Cypher query string. |
| 93 | + provider_id: The provider UUID (will be converted to a label via |
| 94 | + `get_provider_label`). |
| 95 | +
|
| 96 | + Returns: |
| 97 | + The rewritten Cypher with `:_Provider_{uuid}` appended to every |
| 98 | + node pattern. |
| 99 | + """ |
| 100 | + label = get_provider_label(provider_id) |
| 101 | + |
| 102 | + # Step 1: Protect strings and comments (single pass, leftmost-first) |
| 103 | + protected: list[str] = [] |
| 104 | + |
| 105 | + def _save(match): |
| 106 | + protected.append(match.group(0)) |
| 107 | + return f"\x00P{len(protected) - 1}\x00" |
| 108 | + |
| 109 | + work = _PROTECTED_RE.sub(_save, cypher) |
| 110 | + |
| 111 | + # Step 2: Split by clause keywords |
| 112 | + parts = _CLAUSE_RE.split(work) |
| 113 | + |
| 114 | + # Steps 3-4: Apply injection passes per segment |
| 115 | + result: list[str] = [] |
| 116 | + current_clause: str | None = None |
| 117 | + |
| 118 | + for i, part in enumerate(parts): |
| 119 | + if i % 2 == 1: |
| 120 | + # Keyword token - normalize for clause tracking |
| 121 | + current_clause = re.sub(r"\s+", " ", part.strip()).upper() |
| 122 | + result.append(part) |
| 123 | + else: |
| 124 | + # Content segment - apply injection based on clause context |
| 125 | + part = _inject_labeled(part, label) |
| 126 | + if current_clause in _MATCH_CLAUSES: |
| 127 | + part = _inject_bare(part, label) |
| 128 | + result.append(part) |
| 129 | + |
| 130 | + work = "".join(result) |
| 131 | + |
| 132 | + # Step 5: Restore protected regions |
| 133 | + for i, original in enumerate(protected): |
| 134 | + work = work.replace(f"\x00P{i}\x00", original) |
| 135 | + |
| 136 | + return work |
| 137 | + |
| 138 | + |
| 139 | +# --------------------------------------------------------------------------- |
| 140 | +# Validation |
| 141 | +# --------------------------------------------------------------------------- |
| 142 | + |
| 143 | +# Patterns that indicate SSRF or dangerous procedure calls |
| 144 | +# Defense-in-depth layer - the primary control is `neo4j.READ_ACCESS` |
| 145 | +_BLOCKED_PATTERNS = [ |
| 146 | + re.compile(r"\bLOAD\s+CSV\b", re.IGNORECASE), |
| 147 | + re.compile(r"\bapoc\.load\b", re.IGNORECASE), |
| 148 | + re.compile(r"\bapoc\.import\b", re.IGNORECASE), |
| 149 | + re.compile(r"\bapoc\.export\b", re.IGNORECASE), |
| 150 | + re.compile(r"\bapoc\.cypher\b", re.IGNORECASE), |
| 151 | + re.compile(r"\bapoc\.systemdb\b", re.IGNORECASE), |
| 152 | + re.compile(r"\bapoc\.config\b", re.IGNORECASE), |
| 153 | + re.compile(r"\bapoc\.periodic\b", re.IGNORECASE), |
| 154 | + re.compile(r"\bapoc\.do\b", re.IGNORECASE), |
| 155 | + re.compile(r"\bapoc\.trigger\b", re.IGNORECASE), |
| 156 | + re.compile(r"\bapoc\.custom\b", re.IGNORECASE), |
| 157 | +] |
| 158 | + |
| 159 | + |
| 160 | +def validate_custom_query(cypher: str) -> None: |
| 161 | + """Reject queries containing known SSRF or dangerous procedure patterns. |
| 162 | +
|
| 163 | + Raises ValidationError if a blocked pattern is found. |
| 164 | + String literals and comments are stripped before matching to avoid |
| 165 | + false positives. |
| 166 | + """ |
| 167 | + stripped = _PROTECTED_RE.sub("", cypher) |
| 168 | + for pattern in _BLOCKED_PATTERNS: |
| 169 | + if pattern.search(stripped): |
| 170 | + raise ValidationError({"query": "Query contains a blocked operation"}) |
0 commit comments