diff --git a/gnat/connectors/cortex_xdr/client.py b/gnat/connectors/cortex_xdr/client.py index 6b585f55..aae7d8f4 100644 --- a/gnat/connectors/cortex_xdr/client.py +++ b/gnat/connectors/cortex_xdr/client.py @@ -294,6 +294,73 @@ def isolate_endpoint(self, endpoint_id: str) -> dict[str, Any]: ) return resp if isinstance(resp, dict) else {} + # ── Investigation sub-API ────────────────────────────────────────────── + + def get_incident_alerts(self, incident_id: str) -> list[dict[str, Any]]: + """ + Return the alerts that belong to a Cortex XDR incident. + + Calls ``get_incident_extra_data`` and extracts the ``alerts`` list + from the enriched response. + + Parameters + ---------- + incident_id : str + Cortex XDR incident ID. + """ + extra = self.get_incident_extra_data(incident_id) + alerts = extra.get("alerts", {}) + if isinstance(alerts, dict): + return alerts.get("data", []) + if isinstance(alerts, list): + return alerts + return [] + + def get_incident_artifacts(self, incident_id: str) -> list[dict[str, Any]]: + """ + Return network / file artifacts observed in a Cortex XDR incident. + + Extracts ``network_artifacts`` and ``file_artifacts`` from the + enriched incident response and merges them into a single list. + + Parameters + ---------- + incident_id : str + Cortex XDR incident ID. + """ + extra = self.get_incident_extra_data(incident_id) + artifacts: list[dict[str, Any]] = [] + for key in ("network_artifacts", "file_artifacts"): + bucket = extra.get(key, {}) + if isinstance(bucket, dict): + artifacts.extend(bucket.get("data", [])) + elif isinstance(bucket, list): + artifacts.extend(bucket) + return artifacts + + def search_indicators_by_value(self, value: str) -> list[dict[str, Any]]: + """ + Search XDR/XSIAM threat indicators by exact value. + + Parameters + ---------- + value : str + Indicator value (IP, domain, hash, etc.) to search for. + """ + resp = self.post( + "/public_api/v1/indicators/", + json={ + "request_data": { + "filters": [ + {"field": "indicator_value", "operator": "eq", "value": value} + ], + "page_size": 100, + "page_number": 0, + } + }, + ) + return resp.get("reply", {}).get("indicators", []) if isinstance(resp, dict) else [] + def get_indicators( self, ioc_type: str | None = None, diff --git a/gnat/connectors/greymatter/client.py b/gnat/connectors/greymatter/client.py index bb751a45..d4f24c24 100644 --- a/gnat/connectors/greymatter/client.py +++ b/gnat/connectors/greymatter/client.py @@ -36,10 +36,21 @@ * ``/v1/observables`` — observable values (IPs, domains, hashes, URLs) * ``/v1/indicators`` — compound indicators with patterns -* ``/v1/incidents`` — security incidents +* ``/v1/incidents`` — security investigations / cases (``observed-data``) * ``/v1/threat-actors`` — threat actor entities * ``/v1/malware`` — malware families / samples * ``/v1/vulnerabilities`` — CVE / vulnerability records + +Investigation CRUD +------------------ +Pass ``stix_type="observed-data"`` to the standard CRUD methods to interact +with GreyMatter investigations (cases):: + + client.list_objects("observed-data") + client.get_object("observed-data", case_uuid) + client.upsert_object("observed-data", {"title": "APT28 Campaign"}) + +Use :meth:`link_investigation` to link a STIX observable to an existing case. 
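+
+For example, to attach an observable to a case (``case_uuid`` and
+``observable_id`` are placeholders, and the two-argument
+``link_investigation(case_id, stix_obj)`` call shape is assumed)::
+
+    stix_ioc = client.to_stix(client.get_object("indicator", observable_id))
+    client.link_investigation(case_uuid, stix_ioc)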
""" from __future__ import annotations @@ -67,17 +78,18 @@ class GreyMatterClient(BaseClient, ConnectorMixin): """ stix_type_map: dict[str, str] = { - "indicator": "observables", - "threat-actor": "threat-actors", - "malware": "malware", - "vulnerability": "vulnerabilities", + "indicator": "observables", + "threat-actor": "threat-actors", + "malware": "malware", + "vulnerability": "vulnerabilities", "attack-pattern": "attack-patterns", + "observed-data": "incidents", } # GreyMatter observable type → STIX pattern template _OBS_PATTERN: dict[str, str] = { - "ipv4": "[ipv4-addr:value = '{v}']", - "ipv6": "[ipv6-addr:value = '{v}']", + "ipv4": "[ipv4-addr:value = '{v}']", + "ipv6": "[ipv6-addr:value = '{v}']", "domain": "[domain-name:value = '{v}']", "url": "[url:value = '{v}']", "md5": "[file:hashes.MD5 = '{v}']", @@ -160,7 +172,7 @@ def list_objects( """ resource = self._resolve(stix_type) params: dict[str, Any] = { - "limit": page_size, + "limit": page_size, "offset": (page - 1) * page_size, } if filters: @@ -168,13 +180,9 @@ def list_objects( resp = self.get(f"/v1/{resource}", params=params) return resp.get("data", []) if isinstance(resp, dict) else [] - def upsert_object( - self, - stix_type: str, - payload: dict[str, Any], - linked_cases: list[str] | None = None, - **kwargs: Any, - ) -> dict[str, Any]: + def upsert_object(self, stix_type: str, payload: dict[str, Any], + linked_cases: list[str] | None = None, + **kwargs: Any) -> dict[str, Any]: """ Create or update a GreyMatter object. @@ -205,16 +213,26 @@ def delete_object(self, stix_type: str, object_id: str) -> None: def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: """ - Translate a GreyMatter observable/entity dict to STIX 2.1. + Translate a GreyMatter observable/entity or incident dict to STIX 2.1. + + Dispatches to :meth:`_incident_to_stix` when the native record looks + like an investigation/case (detected by ``case_number`` or + ``assigned_to`` fields), otherwise maps to a STIX Indicator. - Handles both observable-value records and full entity records. + Parameters + ---------- + native : dict + Raw GreyMatter API response. """ data = native.get("data", native) - gm_type = data.get("type", "") - value = data.get("value", data.get("name", "")) - pattern = self._OBS_PATTERN.get(gm_type, "[unknown:value = '{v}']").format( - v=value.replace("'", "\\'") - ) + # Investigations/cases have case_number or assigned_to; observables have type+value + if "case_number" in data or "assigned_to" in data: + return self._incident_to_stix(data) + gm_type = data.get("type", "") + value = data.get("value", data.get("name", "")) + pattern = self._OBS_PATTERN.get( + gm_type, "[unknown:value = '{v}']" + ).format(v=value.replace("'", "\\'")) return { "type": "indicator", @@ -233,6 +251,38 @@ def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: "x_tlp": data.get("tlp", "white"), } + @staticmethod + def _incident_to_stix(data: dict[str, Any]) -> dict[str, Any]: + """ + Map a GreyMatter investigation/case record to STIX ``observed-data``. + + Parameters + ---------- + data : dict + GreyMatter incident/case record (with ``case_number``, + ``assigned_to``, ``status``, ``severity`` etc.). 
+ """ + created = data.get("created_at", "") + modified = data.get("updated_at", created) + return { + "type": "observed-data", + "id": f"observed-data--{data.get('id', '')}", + "created": created, + "modified": modified, + "first_observed": created, + "last_observed": modified, + "number_observed": 1, + "object_refs": [], + "name": data.get("title", data.get("name", "")), + "description": data.get("description", ""), + "x_gm_case_number": data.get("case_number", ""), + "x_gm_status": data.get("status", ""), + "x_gm_severity": data.get("severity", ""), + "x_gm_assigned_to": data.get("assigned_to", ""), + "x_gm_tags": data.get("tags", []), + "x_tlp": data.get("tlp", "white"), + } + def from_stix(self, stix_dict: dict[str, Any]) -> dict[str, Any]: """ Translate a STIX Indicator dict to a GreyMatter observable payload. @@ -293,6 +343,65 @@ def link_investigation( } return self.post(f"/v1/incidents/{case_id}/linked_observables", json=payload) + # ── Evidence expansion ──────────────────────────────────────────────── + + def get_investigation_observables(self, case_id: str) -> list[dict[str, Any]]: + """ + Return all observables linked to a GreyMatter investigation/case. + + Calls ``GET /v1/incidents/{case_id}/linked_observables``. + + Parameters + ---------- + case_id : str + GreyMatter investigation / case UUID. + + Returns + ------- + list of dict + Raw GreyMatter observable records. + """ + resp = self.get(f"/v1/incidents/{self._to_gm_id(case_id)}/linked_observables") + return resp.get("data", []) if isinstance(resp, dict) else [] + + def get_investigation_tasks(self, case_id: str) -> list[dict[str, Any]]: + """ + Return tasks associated with a GreyMatter investigation/case. + + Calls ``GET /v1/incidents/{case_id}/tasks``. + + Parameters + ---------- + case_id : str + GreyMatter investigation / case UUID. + + Returns + ------- + list of dict + Raw GreyMatter task records. + """ + resp = self.get(f"/v1/incidents/{self._to_gm_id(case_id)}/tasks") + return resp.get("data", []) if isinstance(resp, dict) else [] + + def search_observables_by_value(self, value: str) -> list[dict[str, Any]]: + """ + Search GreyMatter observables by value (IP, domain, hash, email, …). + + Calls ``GET /v1/observables?value={value}``. + + Parameters + ---------- + value : str + Observable value to search for. + + Returns + ------- + list of dict + Raw GreyMatter observable records. 
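+
+        Example (record keys follow the observable shape consumed by
+        :meth:`to_stix`)::
+
+            for obs in client.search_observables_by_value("185.220.101.5"):
+                print(obs.get("type"), obs.get("value"))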
+        """
+        resp = self.get("/v1/observables", params={"value": value, "limit": 50})
+        return resp.get("data", []) if isinstance(resp, dict) else []
+
     # ── Helpers ──────────────────────────────────────────────────────────────

     def _resolve(self, stix_type: str) -> str:
@@ -312,21 +421,21 @@ def _to_gm_id(stix_or_plain_id: str) -> str:
     @staticmethod
     def _infer_gm_type(pattern: str) -> str:
         pattern = pattern.lower()
-        if "ipv4-addr" in pattern: 
+        if "ipv4-addr" in pattern:
             return "ipv4"
-        if "ipv6-addr" in pattern: 
+        if "ipv6-addr" in pattern:
             return "ipv6"
         if "domain-name" in pattern:
             return "domain"
-        if "url:" in pattern: 
+        if "url:" in pattern:
             return "url"
-        if "sha-256" in pattern: 
+        if "sha-256" in pattern:
             return "sha256"
-        if "sha-1" in pattern: 
+        if "sha-1" in pattern:
             return "sha1"
-        if "md5" in pattern: 
+        if "md5" in pattern:
             return "md5"
-        if "email-addr" in pattern: 
+        if "email-addr" in pattern:
             return "email"
         return "unknown"

diff --git a/gnat/connectors/servicenow_secops/client.py b/gnat/connectors/servicenow_secops/client.py
index 8a905436..cd66b34a 100644
--- a/gnat/connectors/servicenow_secops/client.py
+++ b/gnat/connectors/servicenow_secops/client.py
@@ -280,6 +280,67 @@ def annotate_incident(
         )
         return resp.get("result", {}) if isinstance(resp, dict) else {}

+    # ── Investigation sub-API ────────────────────────────────────────────
+
+    def get_incident_tasks(self, incident_sys_id: str) -> list[dict[str, Any]]:
+        """
+        Return security tasks linked to a SIR incident.
+
+        Queries the ``sn_si_task`` table filtered by parent ``sys_id``.
+
+        Parameters
+        ----------
+        incident_sys_id : str
+            ServiceNow ``sys_id`` of the parent SIR incident.
+        """
+        resp = self.get(
+            f"{_TABLE_BASE}/sn_si_task",
+            params={
+                "sysparm_query": f"parent={incident_sys_id}",
+                "sysparm_limit": 200,
+            },
+        )
+        return resp.get("result", []) if isinstance(resp, dict) else []
+
+    def get_incident_observables(self, incident_sys_id: str) -> list[dict[str, Any]]:
+        """
+        Return threat intelligence observables linked to a SIR incident.
+
+        Queries the ``sn_ti_observable`` table filtered by the incident
+        reference field.
+
+        Parameters
+        ----------
+        incident_sys_id : str
+            ServiceNow ``sys_id`` of the parent SIR incident.
+        """
+        resp = self.get(
+            f"{_TABLE_BASE}/sn_ti_observable",
+            params={
+                "sysparm_query": f"incident={incident_sys_id}",
+                "sysparm_limit": 200,
+            },
+        )
+        return resp.get("result", []) if isinstance(resp, dict) else []
+
+    def search_indicators_by_value(self, value: str) -> list[dict[str, Any]]:
+        """
+        Search ServiceNow Threat Intelligence observables by value (IP, domain, hash, URL, etc.).
+
+        Parameters
+        ----------
+        value : str
+            Observable value to search for.
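+
+        Example (``sysparm_query`` uses a ``LIKE`` match, so partial
+        values also hit)::
+
+            hits = client.search_indicators_by_value("185.220.101.")
+            sys_ids = [rec.get("sys_id") for rec in hits]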
+        """
+        resp = self.get(
+            f"{_TABLE_BASE}/sn_ti_observable",
+            params={
+                "sysparm_query": f"valueLIKE{value}",
+                "sysparm_limit": 100,
+            },
+        )
+        return resp.get("result", []) if isinstance(resp, dict) else []
+
     # ── Vulnerability Response (VR) helpers ──────────────────────────────

     def list_vulnerable_items(
diff --git a/gnat/connectors/thehive/client.py b/gnat/connectors/thehive/client.py
index 353beaa3..57f80613 100644
--- a/gnat/connectors/thehive/client.py
+++ b/gnat/connectors/thehive/client.py
@@ -153,6 +153,72 @@ def delete_object(self, stix_type: str, object_id: str) -> None:
     # ------------------------------------------------------------------
     # Extra helpers
     # ------------------------------------------------------------------

+    # ------------------------------------------------------------------
+    # Investigation sub-API
+    # ------------------------------------------------------------------
+
+    def get_case_observables(self, case_id: str) -> list[dict[str, Any]]:
+        """
+        Return all observables (IOCs) attached to a TheHive case.
+
+        Calls ``GET /api/v1/case/{case_id}/observable``.
+
+        Parameters
+        ----------
+        case_id : str
+            TheHive internal case ID (``_id`` field).
+        """
+        resp = self.get(f"{_API}/case/{case_id}/observable")
+        if isinstance(resp, list):
+            return resp
+        return resp.get("items", []) if isinstance(resp, dict) else []
+
+    def get_case_tasks(self, case_id: str) -> list[dict[str, Any]]:
+        """
+        Return all tasks linked to a TheHive case.
+
+        Uses ``POST /api/v1/query`` with a ``getCase`` → ``tasks`` query pipeline.
+
+        Parameters
+        ----------
+        case_id : str
+            TheHive internal case ID.
+        """
+        query: dict[str, Any] = {
+            "query": [
+                {"_name": "getCase", "idOrName": case_id},
+                {"_name": "tasks"},
+            ]
+        }
+        resp = self.post(f"{_API}/query", json=query)
+        if isinstance(resp, list):
+            return resp
+        return resp.get("items", []) if isinstance(resp, dict) else []
+
+    def search_observables_by_value(self, value: str) -> list[dict[str, Any]]:
+        """
+        Search observables across all cases by data value.
+
+        Uses ``POST /api/v1/query`` with a ``like`` filter on the
+        ``data`` field.
+
+        Parameters
+        ----------
+        value : str
+            Observable value to search for (IP, domain, hash, etc.).
+        """
+        query: dict[str, Any] = {
+            "query": [
+                {"_name": "listObservable"},
+                {"_name": "filter", "_like": {"_field": "data", "_value": value}},
+                {"_name": "page", "from": 0, "to": 100},
+            ]
+        }
+        resp = self.post(f"{_API}/query", json=query)
+        if isinstance(resp, list):
+            return resp
+        return resp.get("items", []) if isinstance(resp, dict) else []
+
     def add_observable(self, case_id: str, stix_obj: dict[str, Any]) -> dict[str, Any]:
         """
         Add an observable (IOC) to an existing TheHive case.
diff --git a/gnat/connectors/threatq/client.py b/gnat/connectors/threatq/client.py
index d7c139e3..f6187070 100644
--- a/gnat/connectors/threatq/client.py
+++ b/gnat/connectors/threatq/client.py
@@ -31,6 +31,23 @@
 +--------------------+---------------------------+
 | attack-pattern     | attack-pattern            |
 +--------------------+---------------------------+
+| observed-data      | event                     |
++--------------------+---------------------------+
+
+Investigation Linking
+---------------------
+ThreatQ *Events* are the investigation container. 
Use :meth:`link_event` +to associate a STIX indicator with an existing event, or pass ``event_id`` +to :meth:`upsert_object` to link automatically on write:: + + client.link_event("42", stix_indicator) + client.upsert_object("indicator", payload, event_id="42") + +Pass ``stix_type="observed-data"`` to standard CRUD methods to interact +with ThreatQ Events directly:: + + client.list_objects("observed-data") + client.get_object("observed-data", "42") Sector / Industry attributes ----------------------------- @@ -94,11 +111,12 @@ class ThreatQClient(BaseClient, ConnectorMixin): """ stix_type_map: dict[str, str] = { - "indicator": "indicator", - "threat-actor": "adversary", - "malware": "malware", + "indicator": "indicator", + "threat-actor": "adversary", + "malware": "malware", "vulnerability": "vulnerability", "attack-pattern": "attack-pattern", + "observed-data": "event", } def __init__( @@ -177,25 +195,54 @@ def list_objects( page: int = 1, page_size: int = 100, ) -> list[dict[str, Any]]: - """Return a paginated list of ThreatQ objects (includes attributes).""" + """ + Return a paginated list of ThreatQ objects. + + For indicator-like types, ``?with=attributes`` is automatically + appended so sector/industry data is included. Events (``observed-data``) + do not use this parameter. + """ resource = self._resolve_resource(stix_type) params: dict[str, Any] = { - "limit": page_size, + "limit": page_size, "offset": (page - 1) * page_size, - "with": "attributes", } + if stix_type != "observed-data": + params["with"] = "attributes" if filters: params.update(filters) resp = self.get(f"/api/{resource}", params=params) return resp.get("data", []) if isinstance(resp, dict) else [] - def upsert_object(self, stix_type: str, payload: dict[str, Any]) -> dict[str, Any]: - """Create or update a ThreatQ object.""" + def upsert_object( + self, + stix_type: str, + payload: dict[str, Any], + event_id: Optional[str] = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Create or update a ThreatQ object. + + Parameters + ---------- + stix_type : str + STIX type (used to resolve the API resource path). + payload : dict + Object fields. An ``"id"`` key triggers an update (PUT). + event_id : str, optional + For indicator writes only: if provided, the indicator is linked + to this ThreatQ Event after upsert via :meth:`link_event`. + """ resource = self._resolve_resource(stix_type) tq_id = payload.pop("id", None) if tq_id: - return self.put(f"/api/{resource}/{tq_id}", json=payload) - return self.post(f"/api/{resource}", json=payload) + result = self.put(f"/api/{resource}/{tq_id}", json=payload) + else: + result = self.post(f"/api/{resource}", json=payload) + if event_id and stix_type != "observed-data": + self.link_event(event_id, payload) + return result def delete_object(self, stix_type: str, object_id: str) -> None: """Delete a ThreatQ object.""" @@ -230,32 +277,35 @@ def get_attribute_types(self) -> list[str]: def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: """ - Translate a ThreatQ indicator dict to STIX 2.1 format. + Translate a ThreatQ native object to STIX 2.1. - Sector and industry context is extracted from the ``attributes`` array - when present (requires ``?with=attributes`` on the originating request, - which :meth:`get_object` and :meth:`list_objects` both include). - Matched attribute values are written to ``x_target_sectors``. 
+ Dispatches to :meth:`_event_to_stix` for ThreatQ Events (detected by + ``happened_at`` or ``event_type`` fields) and to the indicator path + for all other objects. Sector/industry attributes are extracted from + the ``attributes`` array when present. Parameters ---------- native : dict - Raw ThreatQ API indicator object. + Raw ThreatQ API response (indicator or event record). Returns ------- dict - Partial STIX Indicator dict. + STIX 2.1 SDO (``indicator`` or ``observed-data``). """ data = native.get("data", native) + # Events have happened_at or event_type; indicators have value+type + if "happened_at" in data or "event_type" in data: + return self._event_to_stix(data) stix: dict[str, Any] = { - "type": "indicator", - "id": f"indicator--{data.get('id', '')}", - "name": data.get("value", ""), - "pattern": f"[{data.get('type', 'unknown')}:value = '{data.get('value', '')}']", - "pattern_type": "stix", - "created": data.get("created_at", ""), - "modified": data.get("updated_at", ""), + "type": "indicator", + "id": f"indicator--{data.get('id', '')}", + "name": data.get("value", ""), + "pattern": f"[{data.get('type', 'unknown')}:value = '{data.get('value', '')}']", + "pattern_type": "stix", + "created": data.get("created_at", ""), + "modified": data.get("updated_at", ""), "indicator_types": [data.get("class", "unknown")], } sectors = self._extract_sectors(data.get("attributes", [])) @@ -263,6 +313,35 @@ def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: stix["x_target_sectors"] = sectors return stix + @staticmethod + def _event_to_stix(data: dict[str, Any]) -> dict[str, Any]: + """ + Map a ThreatQ Event record to a STIX ``observed-data`` SDO. + + Parameters + ---------- + data : dict + ThreatQ event record (with ``title``, ``happened_at``, + ``event_type``, etc.). + """ + created = data.get("created_at", "") + modified = data.get("updated_at", created) + happened = data.get("happened_at", created) + return { + "type": "observed-data", + "id": f"observed-data--{data.get('id', '')}", + "created": created, + "modified": modified, + "first_observed": happened, + "last_observed": happened, + "number_observed": 1, + "object_refs": [], + "name": data.get("title", ""), + "description": data.get("description", ""), + "x_tq_event_type": data.get("event_type", ""), + "x_tq_event_id": str(data.get("id", "")), + } + def from_stix(self, stix_dict: dict[str, Any]) -> dict[str, Any]: """ Translate a STIX Indicator dict to a ThreatQ API payload. @@ -283,6 +362,110 @@ def from_stix(self, stix_dict: dict[str, Any]) -> dict[str, Any]: "status": {"name": "Active"}, } + # ------------------------------------------------------------------ + # Investigation linking + # ------------------------------------------------------------------ + + def link_event( + self, + event_id: str, + stix_obj: dict[str, Any], + ) -> dict[str, Any]: + """ + Link a STIX indicator to an existing ThreatQ Event. + + Calls ``POST /api/events/{event_id}/indicators`` with an indicator + payload derived from *stix_obj*, associating the threat intelligence + with the event (investigation) record. + + Parameters + ---------- + event_id : str + ThreatQ Event numeric ID (or STIX id — the numeric portion is + extracted automatically). + stix_obj : dict + STIX 2.1 indicator SDO (or any dict with ``name``/``pattern``). + + Returns + ------- + dict + Raw ThreatQ API response. 
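+
+        Example (illustrative values)::
+
+            >>> ThreatQClient._event_to_stix(
+            ...     {"id": 42, "title": "Spearphish wave",
+            ...      "happened_at": "2026-04-05 10:00:00"}
+            ... )["first_observed"]
+            '2026-04-05 10:00:00'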
+        """
+        tq_id = self._extract_numeric_id(event_id)
+        # stix_obj may be a STIX SDO ("name"/"pattern") or a native ThreatQ
+        # indicator dict ("value"/"type"); upsert_object passes the latter.
+        if stix_obj.get("pattern"):
+            ioc_type = self._infer_tq_type(stix_obj["pattern"])
+        else:
+            ioc_type = stix_obj.get("type", "")
+        payload = {
+            "value": stix_obj.get("name") or stix_obj.get("value", ""),
+            "type": ioc_type,
+            "status": {"name": "Active"},
+        }
+        return self.post(f"/api/events/{tq_id}/indicators", json=payload)
+
+    # ------------------------------------------------------------------
+    # Evidence expansion
+    # ------------------------------------------------------------------
+
+    def get_event_indicators(self, event_id: str) -> list[dict[str, Any]]:
+        """
+        Return all indicators linked to a ThreatQ Event.
+
+        Calls ``GET /api/events/{event_id}/indicators``.
+
+        Parameters
+        ----------
+        event_id : str
+            ThreatQ Event numeric ID (or STIX id — the numeric portion is
+            extracted automatically).
+
+        Returns
+        -------
+        list of dict
+            Raw ThreatQ indicator records.
+        """
+        tq_id = self._extract_numeric_id(event_id)
+        resp = self.get(f"/api/events/{tq_id}/indicators", params={"with": "attributes"})
+        return resp.get("data", []) if isinstance(resp, dict) else []
+
+    def get_event_adversaries(self, event_id: str) -> list[dict[str, Any]]:
+        """
+        Return all adversaries (threat actors) linked to a ThreatQ Event.
+
+        Calls ``GET /api/events/{event_id}/adversaries``.
+
+        Parameters
+        ----------
+        event_id : str
+            ThreatQ Event numeric ID (or STIX id).
+
+        Returns
+        -------
+        list of dict
+            Raw ThreatQ adversary records.
+        """
+        tq_id = self._extract_numeric_id(event_id)
+        resp = self.get(f"/api/events/{tq_id}/adversaries")
+        return resp.get("data", []) if isinstance(resp, dict) else []
+
+    def search_indicators_by_value(self, value: str) -> list[dict[str, Any]]:
+        """
+        Search ThreatQ indicators by value.
+
+        Calls ``GET /api/indicators?search={value}&with=attributes``.
+
+        Parameters
+        ----------
+        value : str
+            Indicator value to search for (IP, domain, hash, URL, …).
+
+        Returns
+        -------
+        list of dict
+            Raw ThreatQ indicator records (includes attributes array).
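+
+        Example::
+
+            for ioc in client.search_indicators_by_value("evil.example.com"):
+                print(ioc.get("id"), ioc.get("value"))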
+ """ + resp = self.get( + "/api/indicators", + params={"search": value, "with": "attributes", "limit": 50}, + ) + return resp.get("data", []) if isinstance(resp, dict) else [] + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ @@ -291,6 +474,8 @@ def _resolve_resource(self, stix_type: str) -> str: resource = self.stix_type_map.get(stix_type) if not resource: raise GNATClientError(f"ThreatQ: unsupported STIX type '{stix_type}'") + if stix_type == "observed-data": + return "events" # already the plural form used by the ThreatQ API return resource + "s" # ThreatQ uses plural endpoints @staticmethod diff --git a/gnat/connectors/xsoar/client.py b/gnat/connectors/xsoar/client.py index 2025bf2d..37e51c8b 100644 --- a/gnat/connectors/xsoar/client.py +++ b/gnat/connectors/xsoar/client.py @@ -11,75 +11,282 @@ api_key = auth_id = ; optional for multi-tenant auth_type = api_key + +STIX Type Mapping +----------------- ++--------------------+-----------------------------------+ +| STIX Type | XSOAR Resource | ++====================+===================================+ +| indicator | indicator (IOC) | ++--------------------+-----------------------------------+ +| malware | indicator | ++--------------------+-----------------------------------+ +| threat-actor | indicator | ++--------------------+-----------------------------------+ +| vulnerability | indicator | ++--------------------+-----------------------------------+ +| observed-data | incident | ++--------------------+-----------------------------------+ + +Investigation Linking +--------------------- +Use :meth:`link_incident` to associate a STIX indicator with an existing +XSOAR incident, or pass ``incident_id`` to :meth:`upsert_object` to link +automatically on write:: + + client.link_incident("1234", stix_indicator) + client.upsert_object("indicator", payload, incident_id="1234") """ -from typing import Any, Optional +from __future__ import annotations + +from typing import Any from gnat.clients.base import BaseClient from gnat.connectors.base_connector import ConnectorMixin +# Map XSOAR native indicator_type (lowercase) → STIX object-path prefix +_XSOAR_TYPE_TO_STIX: dict[str, str] = { + "ip": "ipv4-addr:value", + "ipv4": "ipv4-addr:value", + "ip address": "ipv4-addr:value", + "ipv6": "ipv6-addr:value", + "ipv6 address": "ipv6-addr:value", + "domain": "domain-name:value", + "hostname": "domain-name:value", + "fqdn": "domain-name:value", + "url": "url:value", + "file": "file:hashes.MD5", + "file sha-256": "file:hashes.SHA-256", + "file sha-1": "file:hashes.SHA-1", + "file sha1": "file:hashes.SHA-1", + "file md5": "file:hashes.MD5", + "md5": "file:hashes.MD5", + "sha256": "file:hashes.SHA-256", + "sha-256": "file:hashes.SHA-256", + "sha1": "file:hashes.SHA-1", + "sha-1": "file:hashes.SHA-1", + "email": "email-addr:value", + "email address": "email-addr:value", +} + class XSOARClient(BaseClient, ConnectorMixin): - """HTTP client for the XSOAR 6 REST API.""" + """ + HTTP client for the XSOAR 6 REST API. + + Supports both indicator (threat intel) and incident (investigation) + resources. Pass ``stix_type="observed-data"`` to CRUD methods to + interact with the incident sub-API. + + Parameters + ---------- + host : str + XSOAR base URL, e.g. ``"https://xsoar.example.com"``. + api_key : str + XSOAR API key. + auth_id : str + Multi-tenant auth ID header (optional). + verify_ssl : bool + TLS certificate verification. Default ``True``. 
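+
+    A minimal usage sketch (host and key are placeholders)::
+
+        client = XSOARClient("https://xsoar.example.com", api_key="...")
+        client.authenticate()
+        open_incidents = client.list_objects(
+            "observed-data", filters={"query": "status:0"}
+        )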
+ """ stix_type_map: dict[str, str] = { - "indicator": "indicator", - "malware": "indicator", - "threat-actor": "indicator", + "indicator": "indicator", + "malware": "indicator", + "threat-actor": "indicator", "vulnerability": "indicator", + "observed-data": "incident", } - def __init__(self, host: str, api_key: str = "", auth_id: str = "", **kwargs: Any): + def __init__( + self, + host: str, + api_key: str = "", + auth_id: str = "", + **kwargs: Any, + ) -> None: super().__init__(host=host, **kwargs) self._api_key = api_key self._auth_id = auth_id + # ── Authentication ───────────────────────────────────────────────────── + def authenticate(self) -> None: - """Inject the XSOAR API key header.""" + """Inject the XSOAR API key (and optional multi-tenant auth-id) headers.""" self._auth_headers["Authorization"] = self._api_key if self._auth_id: self._auth_headers["x-xdr-auth-id"] = self._auth_id + # ── ConnectorMixin — CRUD ───────────────────────────────────────────── + def health_check(self) -> bool: + """Return True if the XSOAR instance is reachable.""" self.get("/health") return True def get_object(self, stix_type: str, object_id: str) -> dict[str, Any]: - resp = self.post("/indicators/search", json={"query": f"id:{object_id}", "size": 1}) + """ + Fetch a single XSOAR object by type and id. + + Parameters + ---------- + stix_type : str + ``"observed-data"`` to fetch an incident, any other supported + type to search indicators. + object_id : str + XSOAR incident id (numeric string) or indicator id. + """ + if stix_type == "observed-data": + resp = self.get(f"/incident/{object_id}") + return resp if isinstance(resp, dict) else {} + # Indicator path + resp = self.post("/indicators/search", json={ + "query": f"id:{object_id}", "size": 1 + }) items = resp.get("iocObjects", []) if isinstance(resp, dict) else [] return items[0] if items else {} def list_objects( self, stix_type: str, - filters: Optional[dict[str, Any]] = None, + filters: dict[str, Any] | None = None, page: int = 1, page_size: int = 100, ) -> list[dict[str, Any]]: - query = filters.get("query", "") if filters else "" - resp = self.post( - "/indicators/search", json={"query": query, "size": page_size, "page": page - 1} - ) + """ + List XSOAR objects of a given STIX type. + + Parameters + ---------- + stix_type : str + ``"observed-data"`` to list incidents, other types to list + indicators. + filters : dict, optional + For indicators: ``{"query": "type:IP"}`` free-text query. + For incidents: ``{"query": "status:0"}`` XSOAR query string. + page : int + 1-based page number. + page_size : int + Records per page. + """ + query = (filters or {}).get("query", "") + if stix_type == "observed-data": + resp = self.post("/incidents/search", json={ + "query": query, + "size": page_size, + "page": page - 1, + }) + return resp.get("data", []) if isinstance(resp, dict) else [] + # Indicator path + resp = self.post("/indicators/search", json={ + "query": query, "size": page_size, "page": page - 1 + }) return resp.get("iocObjects", []) if isinstance(resp, dict) else [] def upsert_object( self, stix_type: str, payload: dict[str, Any], - incident_id: Optional[str] = None, + incident_id: str | None = None, **kwargs: Any, ) -> dict[str, Any]: - """Create or update an indicator. If *incident_id* is given, the - indicator is linked to that incident after upsert.""" + """ + Create or update an XSOAR object. + + Parameters + ---------- + stix_type : str + ``"observed-data"`` to create/update an incident, other types + for indicators. 
+ payload : dict + Object fields. An ``"id"`` key in the payload triggers an update + (PUT) for incidents. + incident_id : str, optional + For indicator writes only: if provided, the indicator is linked + to this incident after upsert via :meth:`link_incident`. + """ + if stix_type == "observed-data": + inc_id = payload.pop("id", None) + if inc_id: + return self.put(f"/incident/{inc_id}", json=payload) + return self.post("/incident", json=payload) + # Indicator path result = self.post("/indicators/edit", json=payload) if incident_id: self.link_incident(incident_id, payload) return result def delete_object(self, stix_type: str, object_id: str) -> None: - self.post("/indicators/delete", json={"id": object_id, "doNotWhitelist": False}) + """ + Delete an XSOAR object. + + Parameters + ---------- + stix_type : str + ``"observed-data"`` to delete an incident, other types for + indicators. + object_id : str + XSOAR id of the object to remove. + """ + if stix_type == "observed-data": + self.delete(f"/incident/{object_id}") + else: + self.post("/indicators/delete", json={ + "id": object_id, "doNotWhitelist": False + }) + + # ── ConnectorMixin — STIX translation ───────────────────────────────── + + def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: + """ + Translate an XSOAR native object to STIX 2.1. + + Dispatches to :meth:`_indicator_to_stix` for indicator records and + :meth:`_incident_to_stix` for incident records (detected by the + presence of ``"CustomFields"`` or ``"type"`` == ``"incident"`` in + the payload). + + Parameters + ---------- + native : dict + Raw XSOAR API response (indicator or incident record). + """ + if native.get("type") == "incident" or "CustomFields" in native: + return self._incident_to_stix(native) + return self._indicator_to_stix(native) + + def from_stix(self, stix_dict: dict[str, Any]) -> dict[str, Any]: + """ + Translate a STIX object to an XSOAR API payload. - def link_incident(self, incident_id: str, stix_obj: dict[str, Any]) -> dict[str, Any]: + Dispatches on STIX type: ``observed-data`` → incident payload, + everything else → indicator payload. + + Parameters + ---------- + stix_dict : dict + STIX 2.1 SDO. + """ + if stix_dict.get("type") == "observed-data": + return { + "name": stix_dict.get("name", ""), + "description": stix_dict.get("description", ""), + "type": "incident", + } + return { + "value": stix_dict.get("name", ""), + "indicator_type": self._infer_xsoar_type(stix_dict.get("pattern", "")), + "score": self._confidence_to_score(stix_dict.get("confidence", 50)), + } + + # ── Investigation linking ───────────────────────────────────────────── + + def link_incident( + self, + incident_id: str, + stix_obj: dict[str, Any], + ) -> dict[str, Any]: """ Link a STIX indicator to an existing XSOAR incident. @@ -91,7 +298,7 @@ def link_incident(self, incident_id: str, stix_obj: dict[str, Any]) -> dict[str, incident_id : str XSOAR incident ID (numeric string). stix_obj : dict - STIX indicator dict (or any dict with a ``name`` / ``pattern`` field). + STIX indicator dict (or any dict with a ``name`` / ``value`` field). 
Returns ------- @@ -105,21 +312,183 @@ def link_incident(self, incident_id: str, stix_obj: dict[str, Any]) -> dict[str, } return self.post(f"/incident/{incident_id}/linkedIncidents", json=payload) - def to_stix(self, native: dict[str, Any]) -> dict[str, Any]: + # ── Evidence expansion ──────────────────────────────────────────────── + + def get_incident_alerts(self, incident_id: str) -> list[dict[str, Any]]: + """ + Return alerts linked to an XSOAR incident. + + Calls ``POST /alerts/search`` with ``incidentId`` filter. + + Parameters + ---------- + incident_id : str + XSOAR incident ID (numeric string). + + Returns + ------- + list of dict + Raw XSOAR alert records. + """ + resp = self.post("/alerts/search", json={ + "filter": {"incidentId": incident_id}, + "size": 100, + }) + return resp.get("data", []) if isinstance(resp, dict) else [] + + def get_incident_tasks(self, incident_id: str) -> list[dict[str, Any]]: + """ + Return tasks associated with an XSOAR incident. + + Calls ``GET /tasks`` with ``incidentId`` query parameter. + + Parameters + ---------- + incident_id : str + XSOAR incident ID. + + Returns + ------- + list of dict + Raw XSOAR task records. + """ + resp = self.get("/tasks", params={"incidentId": incident_id, "size": 100}) + # Response shape varies; handle both list and dict-with-data + if isinstance(resp, list): + return resp + if isinstance(resp, dict): + return resp.get("data", resp.get("tasks", [])) + return [] + + def get_incident_timeline(self, incident_id: str) -> list[dict[str, Any]]: + """ + Return timeline entries (war-room entries) for an XSOAR incident. + + Calls ``POST /entry/search`` filtered by incident ID. + + Parameters + ---------- + incident_id : str + XSOAR incident ID. + + Returns + ------- + list of dict + Raw XSOAR war-room entry records. + """ + resp = self.post("/entry/search", json={ + "filter": {"id": incident_id}, + "size": 200, + }) + if isinstance(resp, list): + return resp + if isinstance(resp, dict): + return resp.get("data", resp.get("entries", [])) + return [] + + def search_indicators_by_value(self, value: str) -> list[dict[str, Any]]: + """ + Search XSOAR indicators by exact or partial value. + + Parameters + ---------- + value : str + Indicator value to search for (IP, domain, hash, …). + + Returns + ------- + list of dict + Raw XSOAR indicator (iocObject) records. 
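+
+        Example::
+
+            iocs = client.search_indicators_by_value("185.220.101.5")
+            stix = [client.to_stix(ioc) for ioc in iocs]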
+        """
+        resp = self.post("/indicators/search", json={"query": f'value:"{value}"', "size": 50})
+        return resp.get("iocObjects", []) if isinstance(resp, dict) else []
+
+    # ── Private helpers ──────────────────────────────────────────────────────
+
+    def _indicator_to_stix(self, native: dict[str, Any]) -> dict[str, Any]:
+        """Map an XSOAR indicator dict to a STIX Indicator SDO."""
+        xsoar_type = str(native.get("indicator_type", "")).lower()
+        stix_path = _XSOAR_TYPE_TO_STIX.get(xsoar_type, "unknown:value")
+        value = native.get("value", "")
+        pattern = f"[{stix_path} = '{value}']" if value else ""
         return {
-            "type": "indicator",
-            "id": f"indicator--{native.get('id', '')}",
-            "name": native.get("value", ""),
-            "pattern": f"[ipv4-addr:value = '{native.get('value', '')}']",
-            "pattern_type": "stix",
-            "created": native.get("timestamp", ""),
-            "modified": native.get("modified", ""),
+            "type": "indicator",
+            "id": f"indicator--{native.get('id', '')}",
+            "name": value,
+            "pattern": pattern,
+            "pattern_type": "stix",
+            "created": native.get("timestamp", ""),
+            "modified": native.get("modified", ""),
             "indicator_types": [native.get("indicator_type", "unknown")],
+            "confidence": self._score_to_confidence(native.get("score", 0)),
         }

-    def from_stix(self, stix_dict: dict[str, Any]) -> dict[str, Any]:
+    @staticmethod
+    def _incident_to_stix(native: dict[str, Any]) -> dict[str, Any]:
+        """Map an XSOAR incident dict to a STIX ``observed-data`` SDO."""
+        inc_id = str(native.get("id", ""))
+        opened_at = native.get("occurred", native.get("created", ""))
+        modified = native.get("modified", opened_at)
+        custom = native.get("CustomFields", {})
+        if not isinstance(custom, dict):
+            custom = {}
         return {
-            "value": stix_dict.get("name", ""),
-            "indicator_type": "IP",
-            "score": 2,
+            "type": "observed-data",
+            "id": f"observed-data--{inc_id}",
+            "created": opened_at,
+            "modified": modified,
+            "first_observed": opened_at,
+            "last_observed": modified,
+            "number_observed": 1,
+            "object_refs": [],
+            "name": native.get("name", ""),
+            "description": native.get("details", ""),
+            "x_xsoar_incident_id": inc_id,
+            "x_xsoar_severity": native.get("severity", 0),
+            "x_xsoar_status": native.get("status", 0),
+            "x_xsoar_owner": native.get("owner", ""),
+            "x_xsoar_type": native.get("type", ""),
+            "x_xsoar_labels": [
+                lbl.get("value", "") for lbl in (native.get("labels") or [])
+                if isinstance(lbl, dict)
+            ],
+            "x_xsoar_custom": custom,
         }
+
+    @staticmethod
+    def _infer_xsoar_type(pattern: str) -> str:
+        """Infer the XSOAR indicator type from a STIX pattern string."""
+        p = pattern.lower()
+        if "ipv4-addr" in p:
+            return "IP"
+        if "ipv6-addr" in p:
+            return "IPv6"
+        if "domain-name" in p:
+            return "Domain"
+        if "url:" in p:
+            return "URL"
+        if "sha-256" in p:
+            return "File SHA-256"
+        if "sha-1" in p:
+            return "File SHA-1"
+        if "md5" in p:
+            return "File MD5"
+        if "email-addr" in p:
+            return "Email"
+        return "Unclassified"
+
+    @staticmethod
+    def _score_to_confidence(score: int) -> int:
+        """Convert an XSOAR indicator score (0-3) to STIX confidence (0-100)."""
+        return min(100, max(0, score * 33))
+
+    @staticmethod
+    def _confidence_to_score(confidence: int) -> int:
+        """Convert STIX confidence (0-100) to an XSOAR indicator score (0-3)."""
+        if confidence >= 67:
+            return 3
+        if confidence >= 34:
+            return 2
+        if confidence >= 1:
+            return 1
+        return 0
diff --git a/gnat/investigations/__init__.py b/gnat/investigations/__init__.py
new file mode 100644
index 00000000..3e2d27a9
--- /dev/null
+++ b/gnat/investigations/__init__.py
@@ -0,0 +1,71 @@
+"""
+gnat.investigations +==================== + +Incident-centric evidence graph builder. + +Collects evidence from connected platforms (XSOAR, GreyMatter, ThreatQ, …), +normalises it into a common model, correlates cross-system matches, and +materialises the result into a GNAT workspace. + +Five-step pipeline +------------------ +1. **Seed expansion** — query each connected system for incidents, alerts, + observables, and indicators matching the seed values. +2. **Incident expansion** — for each discovered incident/case/event, fetch + its constituent evidence (alerts, tasks, linked observables, timeline). +3. **Normalisation** — translate every raw record into a common + :class:`~.model.EvidenceNode`. +4. **Correlation** — add ``same-ioc``, ``same-host``, ``same-user``, + ``same-campaign``, and ``same-ticket`` edges between nodes from + different platforms that share correlation attributes. +5. **Materialisation** — write nodes and edges to a GNAT workspace as + STIX objects and Relationship SROs. + +Quick start:: + + from gnat.investigations import InvestigationBuilder, Seed, SeedType, materialize + + builder = InvestigationBuilder({ + "xsoar": xsoar_client, + "greymatter": gm_client, + "threatq": tq_client, + }) + + graph = builder.build( + seeds=[ + Seed("185.220.101.5", SeedType.IP), + Seed("INC-4892", SeedType.CASE_ID, hint_platform="xsoar"), + ], + title="Ransomware triage – 2026-04-05", + ) + + print(graph.summary()) + ws = materialize(graph, workspace_manager) +""" + +from gnat.investigations.builder import InvestigationBuilder +from gnat.investigations.correlator import correlate +from gnat.investigations.model import ( + EvidenceEdge, + EvidenceGraph, + EvidenceNode, + NodeType, + Seed, + SeedType, +) +from gnat.investigations.normalizer import normalize +from gnat.investigations.workspace import materialize + +__all__ = [ + "InvestigationBuilder", + "EvidenceGraph", + "EvidenceNode", + "EvidenceEdge", + "NodeType", + "Seed", + "SeedType", + "normalize", + "correlate", + "materialize", +] diff --git a/gnat/investigations/builder.py b/gnat/investigations/builder.py new file mode 100644 index 00000000..3260edb9 --- /dev/null +++ b/gnat/investigations/builder.py @@ -0,0 +1,428 @@ +""" +gnat.investigations.builder +============================== + +:class:`InvestigationBuilder` orchestrates the five-step evidence graph +pipeline: + +1. **Seed expansion** — translate each seed into platform queries and collect + the initial node set (indicators, incidents, cases, events). +2. **Incident expansion** — for each collected incident/case/event, fetch its + constituent evidence (alerts, tasks, linked observables, timeline entries, + adversaries). +3. **Normalisation** — every raw platform record is translated into a common + :class:`~.model.EvidenceNode` via :mod:`.normalizer`. +4. **Correlation** — cross-system edges are added for any two nodes from + different platforms that share an IOC value, hostname, username, campaign + label, or ticket reference. +5. **Materialisation** — the completed graph can be persisted into a GNAT + workspace via :func:`.workspace.materialize`. 
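+
+Collection is best-effort: every connector call goes through an internal
+``_safe_call`` wrapper that swallows exceptions and logs them at DEBUG level,
+and any expansion method a connector does not implement is skipped, so a
+partial connector mapping is fine::
+
+    builder = InvestigationBuilder({"xsoar": xsoar_client})  # one platform only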
+ +Usage:: + + from gnat.investigations.builder import InvestigationBuilder + from gnat.investigations.model import Seed, SeedType + from gnat.investigations.workspace import materialize + + builder = InvestigationBuilder({ + "xsoar": xsoar_client, + "greymatter": gm_client, + "threatq": tq_client, + "thehive": hive_client, + "servicenow_secops": sn_client, + "cortex_xdr": xdr_client, + }) + + graph = builder.build( + seeds=[ + Seed("185.220.101.5", SeedType.IP), + Seed("INC-4892", SeedType.CASE_ID, hint_platform="xsoar"), + ], + title="Ransomware triage – 2026-04-05", + ) + + print(graph.summary()) + ws = materialize(graph, workspace_manager, "ransomware-apr-2026") +""" + +from __future__ import annotations + +import logging +from typing import Any + +from gnat.investigations.correlator import correlate +from gnat.investigations.model import ( + EvidenceEdge, + EvidenceGraph, + EvidenceNode, + NodeType, + Seed, + SeedType, +) +from gnat.investigations.normalizer import normalize + +logger = logging.getLogger(__name__) + +# Seed types that should trigger a value-search on indicators/observables +_IOC_SEED_TYPES = frozenset({ + SeedType.IOC_VALUE, + SeedType.IP, + SeedType.DOMAIN, + SeedType.HASH, + SeedType.EMAIL, + SeedType.URL, +}) + +# Seed types that should also search incidents by free-text query +_INCIDENT_SEARCH_TYPES = frozenset({ + SeedType.IOC_VALUE, + SeedType.IP, + SeedType.DOMAIN, + SeedType.HOSTNAME, + SeedType.HASH, + SeedType.USERNAME, +}) + + +class InvestigationBuilder: + """ + Build an :class:`~.model.EvidenceGraph` by querying multiple connectors. + + Parameters + ---------- + connectors : dict + Mapping of platform name → connector instance. Any connector that + implements the :class:`~gnat.connectors.base_connector.ConnectorMixin` + interface (or a subset of it) is accepted. Platform names should + match those expected by the normaliser: ``"xsoar"``, ``"greymatter"``, + ``"threatq"``, ``"thehive"``, ``"servicenow_secops"``, + ``"cortex_xdr"``. + """ + + def __init__(self, connectors: dict[str, Any]) -> None: + self._connectors = connectors + + # ── Public API ──────────────────────────────────────────────────────── + + def build( + self, + seeds: list[Seed], + title: str = "Investigation", + expand_depth: int = 1, + ) -> EvidenceGraph: + """ + Run the full five-step evidence graph pipeline. + + Parameters + ---------- + seeds : list of Seed + Starting points for evidence collection. + title : str + Human-readable investigation title stored in the graph. + expand_depth : int + Number of expansion rounds. ``1`` (default) collects direct + children of each seed-identified incident. Higher values would + continue expanding newly discovered incidents (reserved for future + use; currently only ``1`` is applied). 
+
+        Returns
+        -------
+        EvidenceGraph
+        """
+        graph = EvidenceGraph(title=title, seeds=seeds)
+
+        # Step 1: Expand seeds → initial nodes
+        logger.debug("InvestigationBuilder: expanding %d seeds", len(seeds))
+        for seed in seeds:
+            self._expand_seed(graph, seed)
+
+        # Step 2: Expand each incident → constituent evidence
+        incident_nodes = [
+            n for n in list(graph.nodes.values())
+            if n.node_type == NodeType.INCIDENT
+        ]
+        logger.debug(
+            "InvestigationBuilder: expanding %d incident nodes", len(incident_nodes)
+        )
+        for node in incident_nodes:
+            self._expand_incident(graph, node)
+
+        # Steps 3-4: normalisation happened inline during collection;
+        # correlate() builds the indexes and adds cross-platform edges.
+        correlate(graph)
+
+        logger.info(
+            "InvestigationBuilder: finished — %s",
+            graph.summary(),
+        )
+        return graph
+
+    # ── Step 1: seed expansion ────────────────────────────────────────────
+
+    def _expand_seed(self, graph: EvidenceGraph, seed: Seed) -> None:
+        for platform, connector in self._connectors.items():
+            if seed.hint_platform and seed.hint_platform != platform:
+                continue
+            self._query_seed_on_platform(graph, seed, platform, connector)
+
+    def _query_seed_on_platform(
+        self,
+        graph: EvidenceGraph,
+        seed: Seed,
+        platform: str,
+        connector: Any,
+    ) -> None:
+        # ── IOC / indicator value search ──────────────────────────────────
+        if seed.seed_type in _IOC_SEED_TYPES:
+            # Prefer platform-specific value-search methods
+            if hasattr(connector, "search_indicators_by_value"):
+                self._collect(graph, platform, "indicator",
+                              _safe_call(connector.search_indicators_by_value, seed.value))
+            elif hasattr(connector, "search_observables_by_value"):
+                self._collect(graph, platform, "observable",
+                              _safe_call(connector.search_observables_by_value, seed.value))
+            # Otherwise fall back to generic list_objects with a query filter
+            else:
+                results = _safe_call(
+                    connector.list_objects, "indicator",
+                    filters={"query": seed.value},
+                )
+                self._collect(graph, platform, "indicator", results)
+
+        # ── Direct case / alert / ticket lookup ───────────────────────────
+        if seed.seed_type in (SeedType.CASE_ID, SeedType.ALERT_ID, SeedType.TICKET_REF):
+            result = _safe_call(connector.get_object, "observed-data", seed.value)
+            if result:
+                # _safe_call wraps a single dict response in a one-item list
+                node = normalize(platform, "incident", result[0])
+                if node:
+                    _add_node(graph, node)
+
+        # ── Incident text search for IOC / hostname / username seeds ──────
+        if seed.seed_type in _INCIDENT_SEARCH_TYPES:
+            results = _safe_call(
+                connector.list_objects, "observed-data",
+                filters={"query": seed.value},
+            )
+            self._collect(graph, platform, "incident", results)
+
+        # ── Hostname / username: search indicators too ────────────────────
+        if seed.seed_type in (SeedType.HOSTNAME, SeedType.USERNAME):
+            results = _safe_call(
+                connector.list_objects, "indicator",
+                filters={"query": seed.value},
+            )
+            self._collect(graph, platform, "indicator", results)
+
+    # ── Step 2: incident expansion ────────────────────────────────────────
+
+    def _expand_incident(self, graph: EvidenceGraph, node: EvidenceNode) -> None:
+        connector = self._connectors.get(node.platform)
+        if connector is None:
+            return
+
+        platform = node.platform
+        inc_id = node.source_id
+
+        if platform == "xsoar":
+            self._expand_xsoar_incident(graph, node, connector, inc_id)
+        elif platform == "greymatter":
+            self._expand_gm_incident(graph, node, connector, inc_id)
+        elif platform == "threatq":
+            self._expand_tq_event(graph, node, connector, inc_id)
+        elif platform == "thehive":
+            self._expand_hive_case(graph, node, connector, inc_id)
+        elif platform == "servicenow_secops":
+            
self._expand_sn_secops_incident(graph, node, connector, inc_id) + elif platform == "cortex_xdr": + self._expand_xdr_incident(graph, node, connector, inc_id) + + def _expand_xsoar_incident( + self, + graph: EvidenceGraph, + parent: EvidenceNode, + connector: Any, + inc_id: str, + ) -> None: + if hasattr(connector, "get_incident_alerts"): + for r in _safe_call(connector.get_incident_alerts, inc_id): + child = normalize(parent.platform, "alert", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_incident_tasks"): + for r in _safe_call(connector.get_incident_tasks, inc_id): + child = normalize(parent.platform, "task", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_incident_timeline"): + for r in _safe_call(connector.get_incident_timeline, inc_id): + child = normalize(parent.platform, "timeline", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + def _expand_gm_incident( + self, + graph: EvidenceGraph, + parent: EvidenceNode, + connector: Any, + case_id: str, + ) -> None: + if hasattr(connector, "get_investigation_observables"): + for r in _safe_call(connector.get_investigation_observables, case_id): + child = normalize(parent.platform, "observable", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_investigation_tasks"): + for r in _safe_call(connector.get_investigation_tasks, case_id): + child = normalize(parent.platform, "task", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + def _expand_tq_event( + self, + graph: EvidenceGraph, + parent: EvidenceNode, + connector: Any, + event_id: str, + ) -> None: + if hasattr(connector, "get_event_indicators"): + for r in _safe_call(connector.get_event_indicators, event_id): + child = normalize(parent.platform, "indicator", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_event_adversaries"): + for r in _safe_call(connector.get_event_adversaries, event_id): + child = normalize(parent.platform, "adversary", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + def _expand_hive_case( + self, + graph: EvidenceGraph, + parent: EvidenceNode, + connector: Any, + case_id: str, + ) -> None: + if hasattr(connector, "get_case_observables"): + for r in _safe_call(connector.get_case_observables, case_id): + child = normalize(parent.platform, "observable", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_case_tasks"): + for r in _safe_call(connector.get_case_tasks, case_id): + child = normalize(parent.platform, "task", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + def _expand_sn_secops_incident( + self, + graph: EvidenceGraph, + parent: EvidenceNode, + connector: Any, + incident_sys_id: str, + ) -> None: + if hasattr(connector, "get_incident_tasks"): + for r in _safe_call(connector.get_incident_tasks, incident_sys_id): + child = normalize(parent.platform, "task", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_incident_observables"): + for r in _safe_call(connector.get_incident_observables, incident_sys_id): + child = normalize(parent.platform, "observable", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + def _expand_xdr_incident( + self, + graph: 
EvidenceGraph, + parent: EvidenceNode, + connector: Any, + incident_id: str, + ) -> None: + if hasattr(connector, "get_incident_alerts"): + for r in _safe_call(connector.get_incident_alerts, incident_id): + child = normalize(parent.platform, "alert", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + if hasattr(connector, "get_incident_artifacts"): + for r in _safe_call(connector.get_incident_artifacts, incident_id): + child = normalize(parent.platform, "artifact", r) + if child: + _add_node(graph, child) + _add_part_of(graph, child, parent) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _collect( + self, + graph: EvidenceGraph, + platform: str, + record_type: str, + results: list[dict[str, Any]] | None, + ) -> None: + if not results: + return + for raw in results: + node = normalize(platform, record_type, raw) + if node: + _add_node(graph, node) + + +# ── Module-level helpers ─────────────────────────────────────────────────── + +def _add_node(graph: EvidenceGraph, node: EvidenceNode) -> None: + """Add *node* to *graph*, skipping duplicates (by node_id).""" + if node.node_id not in graph.nodes: + graph.nodes[node.node_id] = node + + +def _add_part_of( + graph: EvidenceGraph, + child: EvidenceNode, + parent: EvidenceNode, +) -> None: + """Add a structural ``part-of`` edge from *child* to *parent*.""" + graph.edges.append(EvidenceEdge( + source_id = child.node_id, + target_id = parent.node_id, + relationship_type = "part-of", + confidence = 1.0, + source_platform = child.platform, + )) + + +def _safe_call(fn: Any, *args: Any, **kwargs: Any) -> list[dict[str, Any]]: + """ + Call *fn* with *args* / *kwargs*, returning an empty list on any exception. + + Connectors may not support every method or a platform may be unreachable. + Evidence collection should be best-effort — a single failure must not stop + the whole graph build. + """ + try: + result = fn(*args, **kwargs) + if result is None: + return [] + if isinstance(result, dict): + return [result] + return list(result) + except Exception as exc: # noqa: BLE001 + logger.debug("Evidence expansion skipped (%s): %s", fn, exc) + return [] diff --git a/gnat/investigations/correlator.py b/gnat/investigations/correlator.py new file mode 100644 index 00000000..79920413 --- /dev/null +++ b/gnat/investigations/correlator.py @@ -0,0 +1,128 @@ +""" +gnat.investigations.correlator +================================= + +Builds cross-system correlation edges in an :class:`~.model.EvidenceGraph`. + +After the :class:`~.builder.InvestigationBuilder` has collected and +normalised all evidence nodes, the correlator: + +1. Indexes each node by its extracted correlation attributes (IOC values, + hostnames, usernames, campaign labels, ticket references). +2. Adds :class:`~.model.EvidenceEdge` objects between any two nodes from + **different platforms** that share one or more attributes. + +Only cross-platform matches generate edges — same-platform links are +uninteresting for correlation purposes because a single platform already +knows its own relationships. + +Usage:: + + from gnat.investigations.correlator import correlate + + correlate(graph) # mutates graph in-place +""" + +from __future__ import annotations + +from gnat.investigations.model import EvidenceEdge, EvidenceGraph + + +def correlate(graph: EvidenceGraph) -> None: + """ + Add cross-system correlation edges to *graph* in-place. 
+ + Builds five index maps (by IOC value, hostname, username, campaign label, + ticket reference) and then emits ``same-*`` edges for any cross-platform + matches found in those maps. + + Parameters + ---------- + graph : EvidenceGraph + The graph to correlate. Modified in-place. + """ + # ── Build indexes ────────────────────────────────────────────────────── + for node in graph.nodes.values(): + for val in node.ioc_values: + key = val.lower().strip() + if key: + graph.by_ioc.setdefault(key, []) + if node.node_id not in graph.by_ioc[key]: + graph.by_ioc[key].append(node.node_id) + + for h in node.hostnames: + key = h.lower().strip() + if key: + graph.by_hostname.setdefault(key, []) + if node.node_id not in graph.by_hostname[key]: + graph.by_hostname[key].append(node.node_id) + + for u in node.usernames: + key = u.lower().strip() + if key: + graph.by_username.setdefault(key, []) + if node.node_id not in graph.by_username[key]: + graph.by_username[key].append(node.node_id) + + for c in node.campaign_labels: + key = c.lower().strip() + if key: + graph.by_campaign.setdefault(key, []) + if node.node_id not in graph.by_campaign[key]: + graph.by_campaign[key].append(node.node_id) + + for t in node.ticket_refs: + key = t.strip() + if key: + graph.by_ticket.setdefault(key, []) + if node.node_id not in graph.by_ticket[key]: + graph.by_ticket[key].append(node.node_id) + + # ── Emit cross-platform edges ────────────────────────────────────────── + _add_edges(graph, graph.by_ioc, "same-ioc", "IOC") + _add_edges(graph, graph.by_hostname, "same-host", "hostname") + _add_edges(graph, graph.by_username, "same-user", "username") + _add_edges(graph, graph.by_campaign, "same-campaign", "campaign label") + _add_edges(graph, graph.by_ticket, "same-ticket", "ticket") + + +def _add_edges( + graph: EvidenceGraph, + index: dict[str, list[str]], + relationship_type: str, + label: str, +) -> None: + """Emit cross-platform edges for every multi-node entry in *index*.""" + existing: set[tuple[str, str, str]] = { + (e.source_id, e.target_id, e.relationship_type) + for e in graph.edges + } + + for key, node_ids in index.items(): + if len(node_ids) < 2: + continue + + # Only create edges when at least two different platforms are involved + platforms = {graph.nodes[nid].platform for nid in node_ids if nid in graph.nodes} + if len(platforms) < 2: + continue + + for i, a in enumerate(node_ids): + for b in node_ids[i + 1:]: + if a not in graph.nodes or b not in graph.nodes: + continue + if graph.nodes[a].platform == graph.nodes[b].platform: + continue # same-platform pair — skip + # Canonical order to avoid duplicate reverse-direction edges + src, tgt = (a, b) if a < b else (b, a) + sig = (src, tgt, relationship_type) + if sig in existing: + continue + existing.add(sig) + graph.edges.append(EvidenceEdge( + source_id = src, + target_id = tgt, + relationship_type = relationship_type, + confidence = 0.9, + reasoning = f"Shared {label}: {key}", + )) diff --git a/gnat/investigations/model.py b/gnat/investigations/model.py new file mode 100644 index 00000000..7b1881f8 --- /dev/null +++ b/gnat/investigations/model.py @@ -0,0 +1,232 @@ +""" +gnat.investigations.model +============================ + +Core data model for the evidence graph built by :class:`InvestigationBuilder`. + +The graph has two primitive types: + +* :class:`EvidenceNode` — a normalised record from any connected platform + (incident, observable, asset, identity, finding, task, artifact, …). 
+* :class:`EvidenceEdge` — a directed relationship between two nodes
+  (source attribution, cross-platform correlation, parent–child containment).
+
+:class:`EvidenceGraph` is the container that holds nodes + edges plus
+correlation indexes (by IOC value, hostname, username, campaign label, and
+ticket reference). The correlator fills these indexes in a single pass over
+the nodes and then emits edges per shared key, so it never needs to compare
+every node against every other node.
+
+Seeds
+-----
+An investigation starts from one or more :class:`Seed` values supplied by
+the analyst. Each seed has a :class:`SeedType` that controls which platform
+APIs are queried::
+
+    seeds = [
+        Seed("185.220.101.0", SeedType.IP),
+        Seed("INC-4892", SeedType.CASE_ID, hint_platform="xsoar"),
+    ]
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+
+class SeedType(str, Enum):
+    """Classification of an investigation seed value."""
+
+    IOC_VALUE = "ioc_value"   # Generic indicator (IP, domain, URL, hash)
+    IP = "ip"
+    DOMAIN = "domain"
+    HASH = "hash"
+    EMAIL = "email"
+    URL = "url"
+    HOSTNAME = "hostname"
+    USERNAME = "username"
+    ALERT_ID = "alert_id"
+    CASE_ID = "case_id"
+    TICKET_REF = "ticket_ref"
+    EMAIL_SUBJ = "email_subject"
+
+
+class NodeType(str, Enum):
+    """Normalised record category regardless of source platform."""
+
+    INCIDENT = "incident"
+    OBSERVABLE = "observable"
+    ASSET = "asset"
+    IDENTITY = "identity"
+    FINDING = "finding"
+    TASK = "task"
+    DECISION = "decision"
+    ARTIFACT = "artifact"
+    TIMELINE_EVENT = "timeline_event"
+
+
+@dataclass
+class Seed:
+    """
+    A single investigation seed value.
+
+    Parameters
+    ----------
+    value : str
+        The seed string (IP address, case ID, hostname, hash, …).
+    seed_type : SeedType
+        Tells the builder how to query each connector.
+    hint_platform : str or None
+        Restrict expansion to a single platform name (e.g. ``"xsoar"``).
+        When ``None``, all connected platforms are queried.
+    """
+
+    value: str
+    seed_type: SeedType
+    hint_platform: str | None = None
+
+
+@dataclass
+class EvidenceNode:
+    """
+    A normalised record from any connected platform.
+
+    Parameters
+    ----------
+    node_id : str
+        Stable deduplication key: ``"{platform}::{node_type}::{source_id}"``.
+    node_type : NodeType
+        Normalised category.
+    platform : str
+        Source connector name (``"xsoar"``, ``"greymatter"``, ``"threatq"``).
+    source_id : str
+        Native platform identifier (incident id, case UUID, event id, …).
+    stix : dict
+        Normalised STIX 2.1 SDO built from the native record.
+    raw : dict
+        Unmodified platform API response — preserved for traceability.
+    ioc_values : list of str
+        Extracted indicator values (IPs, domains, hashes, URLs).
+    hostnames : list of str
+        Extracted hostnames / asset names.
+    usernames : list of str
+        Extracted usernames / identity references.
+    campaign_labels : list of str
+        Campaign or actor labels found in tags, names, or custom fields.
+    ticket_refs : list of str
+        External ticket references (Jira, ServiceNow, …).
+    time_window : (str, str) or None
+        Earliest and latest timestamps found in the record.
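+
+    Example (illustrative values only)::
+
+        node = EvidenceNode(
+            node_id="xsoar::incident::4892",
+            node_type=NodeType.INCIDENT,
+            platform="xsoar",
+            source_id="4892",
+            stix={"type": "observed-data", "id": "observed-data--4892"},
+            raw={},
+            ioc_values=["185.220.101.5"],
+        )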
+ """ + + node_id: str + node_type: NodeType + platform: str + source_id: str + stix: dict[str, Any] + raw: dict[str, Any] + ioc_values: list[str] = field(default_factory=list) + hostnames: list[str] = field(default_factory=list) + usernames: list[str] = field(default_factory=list) + campaign_labels: list[str] = field(default_factory=list) + ticket_refs: list[str] = field(default_factory=list) + time_window: tuple[str, str] | None = None + + +@dataclass +class EvidenceEdge: + """ + A directed relationship between two :class:`EvidenceNode` objects. + + Parameters + ---------- + source_id : str + ``node_id`` of the source node. + target_id : str + ``node_id`` of the target node. + relationship_type : str + Relationship verb: ``"part-of"``, ``"same-ioc"``, ``"same-host"``, + ``"same-user"``, ``"same-campaign"``, ``"same-ticket"``, + ``"indicates"``, ``"related-to"``. + confidence : float + 0–1 confidence score. Auto-correlation edges default to 0.9. + Structural (part-of) edges are 1.0. + source_platform : str + Which platform produced this edge (empty for inferred edges). + reasoning : str + Human-readable justification (e.g. ``"Shared IOC: 185.220.101.5"``). + """ + + source_id: str + target_id: str + relationship_type: str + confidence: float = 1.0 + source_platform: str = "" + reasoning: str = "" + + +@dataclass +class EvidenceGraph: + """ + Container for the full evidence graph produced by :class:`InvestigationBuilder`. + + Attributes + ---------- + title : str + Human-readable investigation title. + seeds : list of Seed + The seeds that started this investigation. + nodes : dict + ``{node_id: EvidenceNode}`` — all collected evidence. + edges : list of EvidenceEdge + All structural and correlation edges. + by_ioc : dict + ``{ioc_value_lower: [node_id, …]}`` — correlation index. + by_hostname : dict + ``{hostname_lower: [node_id, …]}`` — correlation index. + by_username : dict + ``{username_lower: [node_id, …]}`` — correlation index. + by_campaign : dict + ``{label_lower: [node_id, …]}`` — correlation index. + by_ticket : dict + ``{ticket_ref: [node_id, …]}`` — correlation index. 
+ """ + + title: str + seeds: list[Seed] + nodes: dict[str, EvidenceNode] = field(default_factory=dict) + edges: list[EvidenceEdge] = field(default_factory=list) + # Correlation indexes populated by correlator + by_ioc: dict[str, list[str]] = field(default_factory=dict) + by_hostname: dict[str, list[str]] = field(default_factory=dict) + by_username: dict[str, list[str]] = field(default_factory=dict) + by_campaign: dict[str, list[str]] = field(default_factory=dict) + by_ticket: dict[str, list[str]] = field(default_factory=dict) + + # ── Convenience helpers ─────────────────────────────────────────────── + + def summary(self) -> dict[str, Any]: + """Return a compact summary dict suitable for logging or display.""" + platform_counts: dict[str, int] = {} + type_counts: dict[str, int] = {} + for node in self.nodes.values(): + platform_counts[node.platform] = platform_counts.get(node.platform, 0) + 1 + type_counts[node.node_type] = type_counts.get(node.node_type, 0) + 1 + cross = sum( + 1 for e in self.edges + if e.relationship_type.startswith("same-") + and self.nodes.get(e.source_id, EvidenceNode("", NodeType.OBSERVABLE, "", "", {}, {})).platform + != self.nodes.get(e.target_id, EvidenceNode("", NodeType.OBSERVABLE, "", "", {}, {})).platform + ) + return { + "title": self.title, + "seeds": len(self.seeds), + "nodes": len(self.nodes), + "edges": len(self.edges), + "cross_platform_hits": cross, + "by_platform": platform_counts, + "by_type": type_counts, + "shared_iocs": sum(1 for v in self.by_ioc.values() if len(v) > 1), + "shared_hosts": sum(1 for v in self.by_hostname.values() if len(v) > 1), + "shared_campaigns": sum(1 for v in self.by_campaign.values() if len(v) > 1), + } diff --git a/gnat/investigations/normalizer.py b/gnat/investigations/normalizer.py new file mode 100644 index 00000000..399bd294 --- /dev/null +++ b/gnat/investigations/normalizer.py @@ -0,0 +1,872 @@ +""" +gnat.investigations.normalizer +================================= + +Translate raw platform API records into :class:`~.model.EvidenceNode` objects +with extracted correlation attributes. + +Each platform stores the same underlying concepts (incident, indicator, +observable, task, timeline entry) in different schemas. The normaliser's job +is to produce a common :class:`EvidenceNode` regardless of source so the +correlator and builder can work in a uniform model. + +Normaliser functions are intentionally defensive — they never raise on +missing fields; they produce partial nodes instead. 
+
+Usage::
+
+    from gnat.investigations.normalizer import normalize
+
+    node = normalize("xsoar", "incident", raw_incident_dict)
+    node = normalize("greymatter", "observable", raw_observable_dict)
+    node = normalize("threatq", "event", raw_event_dict)
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+
+from gnat.investigations.model import EvidenceNode, NodeType
+
+# ── Regex helpers ──────────────────────────────────────────────────────────
+
+_IP_RE = re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b')
+# NOTE: currently unused by _extract_iocs; bare-domain matching on free text
+# over-matches badly (file names, sentence fragments), so domain IOCs are
+# only taken from structured value fields by the per-platform normalisers.
+_DOMAIN_RE = re.compile(r'\b(?:[a-z0-9\-]+\.)+[a-z]{2,}\b', re.IGNORECASE)
+_HASH_RE = re.compile(r'\b[0-9a-fA-F]{32,64}\b')
+_EMAIL_RE = re.compile(r'\b[^\s@]+@[^\s@]+\.[^\s@]+\b')
+
+# Ticket patterns: JIRA-123, INC-4892, CHG-0001, ticket#12345
+# ([#-]? rather than -? so the "ticket#12345" form in the examples matches)
+_TICKET_RE = re.compile(r'\b(?:[A-Z]+-\d+|(?:INC|CHG|REQ|TICKET)[#-]?\d+)\b', re.IGNORECASE)
+
+
+def _extract_iocs(text: str) -> list[str]:
+    """Pull IP addresses, hashes, and email addresses from a free-text string."""
+    found: list[str] = []
+    found.extend(_IP_RE.findall(text))
+    found.extend(_HASH_RE.findall(text))
+    found.extend(_EMAIL_RE.findall(text))
+    return list(dict.fromkeys(found))  # dedup, preserve order
+
+
+def _extract_tickets(text: str) -> list[str]:
+    return list(dict.fromkeys(_TICKET_RE.findall(text)))
+
+
+def _node_id(platform: str, node_type: NodeType, source_id: str) -> str:
+    # Use .value explicitly: on Python 3.11+ f-string formatting of str-mixin
+    # enums includes the class name ("NodeType.INCIDENT"), which would break
+    # the documented "{platform}::{node_type}::{source_id}" key format.
+    return f"{platform}::{node_type.value}::{source_id}"
+
+
+# ── XSOAR ──────────────────────────────────────────────────────────────────
+
+def _xsoar_incident(platform: str, raw: dict[str, Any]) -> EvidenceNode:
+    inc_id = str(raw.get("id", ""))
+    opened_at = raw.get("occurred", raw.get("created", ""))
+    modified = raw.get("modified", opened_at)
+    name = raw.get("name", "")
+    details = raw.get("details", "")
+    custom = raw.get("CustomFields", {}) if isinstance(raw.get("CustomFields"), dict) else {}
+
+    # Extract correlation attributes
+    blob = f"{name} {details} {' '.join(str(v) for v in custom.values())}"
+    ioc_values = _extract_iocs(blob)
+    ticket_refs = _extract_tickets(blob)
+    # Hostnames from CustomFields
+    hostnames = [
+        str(custom[k]) for k in ("src_hostname", "dest_hostname", "hostname")
+        if k in custom and custom[k]
+    ]
+    usernames = [
+        str(custom[k]) for k in ("src_user", "dest_user", "username", "account_id")
+        if k in custom and custom[k]
+    ]
+    # Campaign / actor from labels
+    campaign_labels = [
+        lbl.get("value", "") for lbl in raw.get("labels", [])
+        if isinstance(lbl, dict) and lbl.get("type", "").lower() in ("campaign", "actor", "malware")
+        and lbl.get("value")
+    ]
+
+    stix: dict[str, Any] = {
+        "type": "observed-data",
+        "id": f"observed-data--{inc_id}",
+        "created": opened_at,
+        "modified": modified,
+        "first_observed": opened_at,
+        "last_observed": modified,
+        "number_observed": 1,
+        "object_refs": [],
+        "name": name,
+        "description": details,
+        "x_xsoar_incident_id": inc_id,
+        "x_xsoar_severity": raw.get("severity", 0),
+        "x_xsoar_status": raw.get("status", 0),
+        "x_xsoar_owner": raw.get("owner", ""),
+        "x_xsoar_type": raw.get("type", ""),
+        "x_source_platform": platform,
+    }
+    return EvidenceNode(
+        node_id         = _node_id(platform, NodeType.INCIDENT, inc_id),
+        node_type       = NodeType.INCIDENT,
+        platform        = platform,
+        source_id       = inc_id,
+        stix            = stix,
+        raw             = raw,
+        ioc_values      = ioc_values,
+        hostnames       = hostnames,
+        usernames       = usernames,
+        campaign_labels = campaign_labels,
+        ticket_refs     = ticket_refs,
+        time_window     = (opened_at, modified) if opened_at else None,
+    )
+
+
+def _xsoar_indicator(platform: str, raw: dict[str, Any])
-> EvidenceNode: + src_id = str(raw.get("id", "")) + value = raw.get("value", "") + i_type = str(raw.get("indicator_type", "")).lower() + created = raw.get("timestamp", "") + modified = raw.get("modified", created) + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--{src_id}", + "name": value, + "pattern": f"[{i_type}:value = '{value}']", + "pattern_type": "stix", + "created": created, + "modified": modified, + "indicator_types": [raw.get("indicator_type", "unknown")], + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.OBSERVABLE, src_id), + node_type = NodeType.OBSERVABLE, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + time_window = (created, modified) if created else None, + ) + + +def _xsoar_alert(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("id", raw.get("alertId", ""))) + name = raw.get("name", raw.get("message", "")) + created = raw.get("startDate", raw.get("occurred", "")) + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--alert-{src_id}", + "created": created, + "modified": raw.get("closeDate", created), + "first_observed": created, + "last_observed": created, + "number_observed": 1, + "object_refs": [], + "name": name, + "x_xsoar_alert_id": src_id, + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.FINDING, f"alert-{src_id}"), + node_type = NodeType.FINDING, + platform = platform, + source_id = f"alert-{src_id}", + stix = stix, + raw = raw, + ioc_values = _extract_iocs(name), + time_window = (created, created) if created else None, + ) + + +def _xsoar_task(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("id", "")) + name = raw.get("name", raw.get("title", "")) + created = raw.get("startDate", "") + stix: dict[str, Any] = { + "type": "note", + "id": f"note--task-{src_id}", + "created": created, + "modified": raw.get("dueDate", created), + "abstract": name, + "content": raw.get("description", ""), + "x_xsoar_task_id": src_id, + "x_xsoar_status": raw.get("state", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.TASK, src_id), + node_type = NodeType.TASK, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + time_window = (created, created) if created else None, + ) + + +def _xsoar_timeline(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("id", "")) + content = raw.get("contents", raw.get("message", "")) + created = raw.get("created", "") + stix: dict[str, Any] = { + "type": "note", + "id": f"note--timeline-{src_id}", + "created": created, + "modified": created, + "abstract": f"Timeline: {raw.get('type', 'entry')}", + "content": content, + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.TIMELINE_EVENT, src_id), + node_type = NodeType.TIMELINE_EVENT, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(content), + time_window = (created, created) if created else None, + ) + + +# ── GreyMatter ───────────────────────────────────────────────────────────── + +def _gm_incident(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + created = data.get("created_at", "") + modified = data.get("updated_at", created) + title = data.get("title", data.get("name", 
"")) + desc = data.get("description", "") + blob = f"{title} {desc}" + ticket_refs = _extract_tickets(blob) + campaign_labels = [ + t for t in data.get("tags", []) + if isinstance(t, str) and len(t) > 2 + ] + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--{src_id}", + "created": created, + "modified": modified, + "first_observed": created, + "last_observed": modified, + "number_observed": 1, + "object_refs": [], + "name": title, + "description": desc, + "x_gm_case_number": data.get("case_number", ""), + "x_gm_status": data.get("status", ""), + "x_gm_severity": data.get("severity", ""), + "x_gm_assigned_to": data.get("assigned_to", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.INCIDENT, src_id), + node_type = NodeType.INCIDENT, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(blob), + ticket_refs = ticket_refs, + campaign_labels = campaign_labels, + time_window = (created, modified) if created else None, + ) + + +def _gm_observable(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + value = data.get("value", data.get("name", "")) + gm_type = data.get("type", "unknown") + created = data.get("created_at", "") + modified = data.get("updated_at", created) + + # Build STIX pattern + _pattern_map = { + "ipv4": f"[ipv4-addr:value = '{value}']", + "ipv6": f"[ipv6-addr:value = '{value}']", + "domain": f"[domain-name:value = '{value}']", + "url": f"[url:value = '{value}']", + "md5": f"[file:hashes.MD5 = '{value}']", + "sha1": f"[file:hashes.SHA-1 = '{value}']", + "sha256": f"[file:hashes.SHA-256 = '{value}']", + "email": f"[email-addr:value = '{value}']", + } + pattern = _pattern_map.get(gm_type, f"[unknown:value = '{value}']") + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--{src_id}", + "name": value, + "pattern": pattern, + "pattern_type": "stix", + "created": created, + "modified": modified, + "indicator_types": [data.get("classification", "unknown")], + "confidence": data.get("confidence", 50), + "x_gm_type": gm_type, + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.OBSERVABLE, src_id), + node_type = NodeType.OBSERVABLE, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + time_window = (created, modified) if created else None, + ) + + +def _gm_task(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + title = data.get("title", data.get("name", "")) + created = data.get("created_at", "") + stix: dict[str, Any] = { + "type": "note", + "id": f"note--gm-task-{src_id}", + "created": created, + "modified": data.get("updated_at", created), + "abstract": title, + "content": data.get("description", ""), + "x_gm_task_status": data.get("status", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.TASK, src_id), + node_type = NodeType.TASK, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + time_window = (created, created) if created else None, + ) + + +# ── ThreatQ ──────────────────────────────────────────────────────────────── + +def _tq_event(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + title = data.get("title", "") + desc = 
data.get("description", "") + created = data.get("created_at", "") + happened = data.get("happened_at", created) + modified = data.get("updated_at", created) + blob = f"{title} {desc}" + ticket_refs = _extract_tickets(blob) + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--{src_id}", + "created": created, + "modified": modified, + "first_observed": happened, + "last_observed": happened, + "number_observed": 1, + "object_refs": [], + "name": title, + "description": desc, + "x_tq_event_type": data.get("event_type", ""), + "x_tq_event_id": src_id, + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.INCIDENT, src_id), + node_type = NodeType.INCIDENT, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(blob), + ticket_refs = ticket_refs, + time_window = (happened, happened) if happened else None, + ) + + +def _tq_indicator(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + value = data.get("value", "") + tq_type = data.get("type", "unknown") + created = data.get("created_at", "") + modified = data.get("updated_at", created) + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--{src_id}", + "name": value, + "pattern": f"[{tq_type}:value = '{value}']", + "pattern_type": "stix", + "created": created, + "modified": modified, + "indicator_types": [data.get("class", "unknown")], + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.OBSERVABLE, src_id), + node_type = NodeType.OBSERVABLE, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + time_window = (created, modified) if created else None, + ) + + +def _tq_adversary(platform: str, raw: dict[str, Any]) -> EvidenceNode: + data = raw.get("data", raw) + src_id = str(data.get("id", "")) + name = data.get("name", data.get("value", "")) + created = data.get("created_at", "") + stix: dict[str, Any] = { + "type": "threat-actor", + "id": f"threat-actor--{src_id}", + "name": name, + "created": created, + "modified": data.get("updated_at", created), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.IDENTITY, src_id), + node_type = NodeType.IDENTITY, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + campaign_labels = [name] if name else [], + time_window = (created, created) if created else None, + ) + + +# ── TheHive ──────────────────────────────────────────────────────────────── + +def _hive_case(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("_id", raw.get("id", ""))) + title = raw.get("title", "") + desc = raw.get("description", "") + created = raw.get("_createdAt", raw.get("startDate", "")) + modified = raw.get("_updatedAt", raw.get("endDate", created)) + blob = f"{title} {desc}" + campaign_labels = [t for t in raw.get("tags", []) if isinstance(t, str)] + ticket_refs = _extract_tickets(blob) + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--hive-{src_id}", + "created": created, + "modified": modified, + "first_observed": created, + "last_observed": modified, + "number_observed": 1, + "object_refs": [], + "name": title, + "description": desc, + "x_hive_case_id": src_id, + "x_hive_status": raw.get("status", ""), + "x_hive_severity": raw.get("severity", 2), + "x_hive_assigned_to": raw.get("assignee", 
""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.INCIDENT, src_id), + node_type = NodeType.INCIDENT, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(blob), + campaign_labels = campaign_labels, + ticket_refs = ticket_refs, + time_window = (created, modified) if created else None, + ) + + +def _hive_observable(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("_id", raw.get("id", ""))) + value = raw.get("data", "") + data_type = raw.get("dataType", "unknown") + created = raw.get("_createdAt", "") + modified = raw.get("_updatedAt", created) + + _pattern_map = { + "ip": f"[ipv4-addr:value = '{value}']", + "domain": f"[domain-name:value = '{value}']", + "url": f"[url:value = '{value}']", + "hash": f"[file:hashes.MD5 = '{value}']", + "mail": f"[email-addr:value = '{value}']", + "hostname": f"[domain-name:value = '{value}']", + "filename": f"[file:name = '{value}']", + } + pattern = _pattern_map.get(data_type, f"[unknown:value = '{value}']") + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--hive-obs-{src_id}", + "name": value, + "pattern": pattern, + "pattern_type": "stix", + "created": created, + "modified": modified, + "indicator_types": [data_type], + "x_hive_ioc": raw.get("ioc", False), + "x_hive_tlp": raw.get("tlp", 1), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.OBSERVABLE, src_id), + node_type = NodeType.OBSERVABLE, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + time_window = (created, modified) if created else None, + ) + + +def _hive_task(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("_id", raw.get("id", ""))) + title = raw.get("title", raw.get("name", "")) + created = raw.get("_createdAt", "") + modified = raw.get("_updatedAt", created) + + stix: dict[str, Any] = { + "type": "note", + "id": f"note--hive-task-{src_id}", + "created": created, + "modified": modified, + "abstract": title, + "content": raw.get("description", ""), + "x_hive_task_status": raw.get("status", ""), + "x_hive_assignee": raw.get("assignee", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.TASK, src_id), + node_type = NodeType.TASK, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + time_window = (created, modified) if created else None, + ) + + +# ── ServiceNow SecOps ───────────────────────────────────────────────────── + +def _sn_secops_incident(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("sys_id", "")) + title = raw.get("short_description", "") + desc = raw.get("description", "") + created = raw.get("sys_created_on", "") + modified = raw.get("sys_updated_on", created) + blob = f"{title} {desc} {raw.get('work_notes', '')}" + campaign_labels = [ + raw.get("category", ""), + raw.get("subcategory", ""), + ] + campaign_labels = [c for c in campaign_labels if c] + ticket_refs = _extract_tickets(blob) + # Extract linked Jira/ticket from correlation_id or correlation_display + corr = raw.get("correlation_id", "") or raw.get("correlation_display", "") + if corr: + ticket_refs = list(dict.fromkeys(ticket_refs + [corr])) + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--sn-{src_id}", + "created": created, + "modified": modified, + "first_observed": created, + "last_observed": 
modified, + "number_observed": 1, + "object_refs": [], + "name": title, + "description": desc, + "x_sn_sys_id": src_id, + "x_sn_number": raw.get("number", ""), + "x_sn_state": raw.get("state", {}).get("value", raw.get("state", "")), + "x_sn_priority": raw.get("priority", {}).get("value", raw.get("priority", "")), + "x_sn_assigned_to": raw.get("assigned_to", {}).get("display_value", ""), + "x_sn_category": raw.get("category", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.INCIDENT, src_id), + node_type = NodeType.INCIDENT, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(blob), + campaign_labels = campaign_labels, + ticket_refs = ticket_refs, + time_window = (created, modified) if created else None, + ) + + +def _sn_secops_task(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("sys_id", "")) + title = raw.get("short_description", raw.get("name", "")) + created = raw.get("sys_created_on", "") + modified = raw.get("sys_updated_on", created) + + stix: dict[str, Any] = { + "type": "note", + "id": f"note--sn-task-{src_id}", + "created": created, + "modified": modified, + "abstract": title, + "content": raw.get("description", raw.get("work_notes", "")), + "x_sn_task_state": raw.get("state", {}).get("value", raw.get("state", "")), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.TASK, src_id), + node_type = NodeType.TASK, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + time_window = (created, modified) if created else None, + ) + + +def _sn_secops_observable(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("sys_id", "")) + value = raw.get("value", "") + obs_type = raw.get("type", {}).get("display_value", raw.get("type", "unknown")) + created = raw.get("sys_created_on", "") + modified = raw.get("sys_updated_on", created) + + _pattern_map: dict[str, str] = { + "IP Address": f"[ipv4-addr:value = '{value}']", + "Domain": f"[domain-name:value = '{value}']", + "URL": f"[url:value = '{value}']", + "File Hash": f"[file:hashes.MD5 = '{value}']", + "Email": f"[email-addr:value = '{value}']", + } + pattern = _pattern_map.get(obs_type, f"[unknown:value = '{value}']") + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--sn-obs-{src_id}", + "name": value, + "pattern": pattern, + "pattern_type": "stix", + "created": created, + "modified": modified, + "indicator_types": [obs_type], + "x_sn_obs_type": obs_type, + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.OBSERVABLE, src_id), + node_type = NodeType.OBSERVABLE, + platform = platform, + source_id = src_id, + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + time_window = (created, modified) if created else None, + ) + + +# ── Cortex XDR ──────────────────────────────────────────────────────────── + +def _xdr_incident(platform: str, raw: dict[str, Any]) -> EvidenceNode: + inc_id = str(raw.get("incident_id", "")) + name = raw.get("incident_name", f"XDR Incident {inc_id}") + desc = raw.get("description", "") + ts = raw.get("creation_time", "") + mod_ts = raw.get("modification_time", ts) + hosts = raw.get("hosts", []) + users = raw.get("users", []) + blob = f"{name} {desc}" + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--xdr-{inc_id}", + "created": str(ts), + "modified": str(mod_ts), + "first_observed": 
str(ts), + "last_observed": str(mod_ts), + "number_observed": 1, + "object_refs": [], + "name": name, + "description": desc, + "x_xdr_incident_id": inc_id, + "x_xdr_severity": raw.get("severity", ""), + "x_xdr_status": raw.get("status", ""), + "x_xdr_alert_count": raw.get("alert_count", 0), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.INCIDENT, inc_id), + node_type = NodeType.INCIDENT, + platform = platform, + source_id = inc_id, + stix = stix, + raw = raw, + ioc_values = _extract_iocs(blob), + hostnames = [str(h) for h in hosts if h], + usernames = [str(u) for u in users if u], + time_window = (str(ts), str(mod_ts)) if ts else None, + ) + + +def _xdr_alert(platform: str, raw: dict[str, Any]) -> EvidenceNode: + alert_id = str(raw.get("alert_id", "")) + name = raw.get("name", raw.get("alert_name", f"XDR Alert {alert_id}")) + ts = raw.get("detection_timestamp", "") + host = raw.get("host_name", "") + remote_ip = raw.get("remote_ip", "") + + stix: dict[str, Any] = { + "type": "observed-data", + "id": f"observed-data--xdr-alert-{alert_id}", + "created": str(ts), + "modified": str(ts), + "first_observed": str(ts), + "last_observed": str(ts), + "number_observed": 1, + "object_refs": [], + "name": name, + "x_xdr_alert_id": alert_id, + "x_xdr_severity": raw.get("severity", ""), + "x_xdr_category": raw.get("category", ""), + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.FINDING, f"alert-{alert_id}"), + node_type = NodeType.FINDING, + platform = platform, + source_id = f"alert-{alert_id}", + stix = stix, + raw = raw, + ioc_values = [remote_ip] if remote_ip else [], + hostnames = [host] if host else [], + time_window = (str(ts), str(ts)) if ts else None, + ) + + +def _xdr_artifact(platform: str, raw: dict[str, Any]) -> EvidenceNode: + src_id = str(raw.get("alert_id", raw.get("file_sha256", raw.get("network_remote_ip", "")))) + artifact_type = "network" if "network_remote_ip" in raw else "file" + value = raw.get("network_remote_ip", "") or raw.get("file_sha256", "") + name = raw.get("file_name", "") or raw.get("network_remote_domain", value) + + stix: dict[str, Any] = { + "type": "indicator", + "id": f"indicator--xdr-artifact-{src_id[:40]}", + "name": name or value, + "pattern": ( + f"[ipv4-addr:value = '{value}']" if artifact_type == "network" + else f"[file:hashes.'SHA-256' = '{value}']" + ), + "pattern_type": "stix", + "created": "", + "modified": "", + "indicator_types": ["malicious-activity"], + "x_source_platform": platform, + } + return EvidenceNode( + node_id = _node_id(platform, NodeType.ARTIFACT, f"artifact-{src_id[:40]}"), + node_type = NodeType.ARTIFACT, + platform = platform, + source_id = f"artifact-{src_id[:40]}", + stix = stix, + raw = raw, + ioc_values = [value] if value else [], + ) + + +# ── Public dispatcher ────────────────────────────────────────────────────── + +_DISPATCH: dict[tuple[str, str], Any] = { + # XSOAR + ("xsoar", "incident"): _xsoar_incident, + ("xsoar", "indicator"): _xsoar_indicator, + ("xsoar", "alert"): _xsoar_alert, + ("xsoar", "task"): _xsoar_task, + ("xsoar", "timeline"): _xsoar_timeline, + # GreyMatter + ("greymatter", "incident"): _gm_incident, + ("greymatter", "observable"): _gm_observable, + ("greymatter", "task"): _gm_task, + # ThreatQ + ("threatq", "event"): _tq_event, + ("threatq", "indicator"): _tq_indicator, + ("threatq", "adversary"): _tq_adversary, + # TheHive + ("thehive", "case"): _hive_case, + ("thehive", "incident"): _hive_case, + 
("thehive", "observable"): _hive_observable, + ("thehive", "task"): _hive_task, + # ServiceNow SecOps + ("servicenow_secops", "incident"): _sn_secops_incident, + ("servicenow_secops", "task"): _sn_secops_task, + ("servicenow_secops", "observable"): _sn_secops_observable, + # Cortex XDR + ("cortex_xdr", "incident"): _xdr_incident, + ("cortex_xdr", "alert"): _xdr_alert, + ("cortex_xdr", "artifact"): _xdr_artifact, + # Aliases — "observed-data" → incident normaliser for each platform + ("xsoar", "observed-data"): _xsoar_incident, + ("greymatter", "observed-data"): _gm_incident, + ("threatq", "observed-data"): _tq_event, + ("thehive", "observed-data"): _hive_case, + ("servicenow_secops", "observed-data"): _sn_secops_incident, + ("cortex_xdr", "observed-data"): _xdr_incident, + # indicator alias for platforms that call it differently + ("thehive", "indicator"): _hive_observable, + ("servicenow_secops", "indicator"): _sn_secops_observable, + ("cortex_xdr", "indicator"): _xdr_alert, +} + + +def normalize( + platform: str, + record_type: str, + raw: dict[str, Any], +) -> EvidenceNode | None: + """ + Translate a raw platform record into an :class:`EvidenceNode`. + + Returns ``None`` if the platform/record_type combination is unknown + or if *raw* is empty. + + Parameters + ---------- + platform : str + Connector name (``"xsoar"``, ``"greymatter"``, ``"threatq"``, + ``"thehive"``, ``"servicenow_secops"``, ``"cortex_xdr"``). + record_type : str + Platform record category (``"incident"``, ``"indicator"``, + ``"observable"``, ``"alert"``, ``"task"``, ``"event"``, …). + raw : dict + Raw API response from the connector. + + Returns + ------- + EvidenceNode or None + """ + if not raw: + return None + key = (platform.lower(), record_type.lower()) + fn = _DISPATCH.get(key) + if fn is None and record_type in ("observable", "indicator"): + fn = _DISPATCH.get((platform.lower(), "observable")) \ + or _DISPATCH.get((platform.lower(), "indicator")) + if fn is None: + return None + return fn(platform, raw) diff --git a/gnat/investigations/workspace.py b/gnat/investigations/workspace.py new file mode 100644 index 00000000..3a5f1117 --- /dev/null +++ b/gnat/investigations/workspace.py @@ -0,0 +1,173 @@ +""" +gnat.investigations.workspace +================================ + +Materialise a completed :class:`~.model.EvidenceGraph` into a GNAT +:class:`~gnat.context.workspace.Workspace`. + +Each :class:`~.model.EvidenceNode` becomes a STIX object in the workspace. +Each :class:`~.model.EvidenceEdge` becomes a STIX Relationship with +confidence and reasoning stored in ``x_*`` extension fields. + +The workspace ``metadata`` dict stores the full investigation summary, +seed list, and correlation indexes so the graph can be reconstructed +or reviewed without re-querying the connected platforms. 
+ +Usage:: + + from gnat.investigations.workspace import materialize + + ws = materialize( + graph, + workspace_manager, + name="ransomware-apr-2026", + ) + print(f"Workspace '{ws.name}' — {len(ws.objects)} objects") +""" + +from __future__ import annotations + +import logging +from typing import Any + +from gnat.investigations.model import EvidenceGraph, EvidenceNode +from gnat.orm.base import STIXBase +from gnat.orm.relationship import Relationship + +logger = logging.getLogger(__name__) + +# STIX types that the ORM can represent as proper Relationship SROs +_RELATIONSHIP_TYPES = frozenset({ + "part-of", + "same-ioc", + "same-host", + "same-user", + "same-campaign", + "same-ticket", + "indicates", + "related-to", +}) + + +def materialize( + graph: EvidenceGraph, + workspace_manager: Any, + name: str | None = None, + description: str = "", +) -> Any: + """ + Persist an :class:`~.model.EvidenceGraph` into a GNAT workspace. + + Parameters + ---------- + graph : EvidenceGraph + The completed evidence graph produced by + :class:`~.builder.InvestigationBuilder`. + workspace_manager : WorkspaceManager + A :class:`~gnat.context.workspace.WorkspaceManager` instance used to + create the workspace. + name : str, optional + Workspace name. Defaults to a slug derived from *graph.title*. + description : str, optional + Human-readable workspace description. + + Returns + ------- + Workspace + The newly created (or updated) workspace containing all graph nodes + and edges. + """ + ws_name = name or _title_to_slug(graph.title) + ws_desc = description or f"Evidence graph: {graph.title}" + + ws = workspace_manager.create(ws_name, description=ws_desc) + + # ── Add nodes ────────────────────────────────────────────────────────── + added = 0 + for node in graph.nodes.values(): + try: + stix_obj = _node_to_stix_base(node) + ws.add(stix_obj) + added += 1 + except Exception as exc: # noqa: BLE001 + logger.warning("Could not add node %s to workspace: %s", node.node_id, exc) + + # ── Add edges as Relationship SROs ───────────────────────────────────── + for edge in graph.edges: + src_node = graph.nodes.get(edge.source_id) + tgt_node = graph.nodes.get(edge.target_id) + if not src_node or not tgt_node: + continue + src_stix_id = src_node.stix.get("id", "") + tgt_stix_id = tgt_node.stix.get("id", "") + if not src_stix_id or not tgt_stix_id: + continue + try: + rel = Relationship( + relationship_type = edge.relationship_type, + source_ref = src_stix_id, + target_ref = tgt_stix_id, + ) + rel["x_confidence"] = edge.confidence + rel["x_reasoning"] = edge.reasoning + rel["x_source_platform"] = edge.source_platform + ws.add(rel) + except Exception as exc: # noqa: BLE001 + logger.warning("Could not add edge %s→%s: %s", + edge.source_id, edge.target_id, exc) + + # ── Store investigation metadata in workspace ────────────────────────── + summary = graph.summary() + ws.metadata = { + "investigation_title": graph.title, + "seeds": [ + {"value": s.value, "type": s.seed_type, "platform": s.hint_platform} + for s in graph.seeds + ], + "summary": summary, + "correlation": { + "shared_iocs": {k: v for k, v in graph.by_ioc.items() if len(v) > 1}, + "shared_hosts": {k: v for k, v in graph.by_hostname.items() if len(v) > 1}, + "shared_users": {k: v for k, v in graph.by_username.items() if len(v) > 1}, + "shared_campaigns": {k: v for k, v in graph.by_campaign.items() if len(v) > 1}, + "shared_tickets": {k: v for k, v in graph.by_ticket.items() if len(v) > 1}, + }, + } + ws.save() + + logger.info( + "Materialised %d nodes, %d edges into 
workspace %r", + added, len(graph.edges), ws_name, + ) + return ws + + +# ── Helpers ──────────────────────────────────────────────────────────────── + +def _node_to_stix_base(node: EvidenceNode) -> STIXBase: + """Wrap a normalised node's STIX dict as a :class:`~gnat.orm.base.STIXBase`.""" + stix_type = node.stix.get("type", "x-evidence-node") + obj = STIXBase(stix_type=stix_type, **{ + k: v for k, v in node.stix.items() if k != "type" + }) + # Tag with investigation metadata not already in the STIX dict + obj["x_evidence_node_id"] = node.node_id + obj["x_evidence_node_type"] = node.node_type + obj["x_source_platform"] = node.platform + obj["x_source_id"] = node.source_id + if node.time_window: + obj["x_time_window_start"] = node.time_window[0] + obj["x_time_window_end"] = node.time_window[1] + return obj + + +def _title_to_slug(title: str) -> str: + """Convert an investigation title to a valid workspace name slug.""" + slug = title.lower() + for ch in (" ", "/", "\\", ":", ";", ",", ".", "!", "?", "\"", "'"): + slug = slug.replace(ch, "-") + # Collapse repeated hyphens and strip leading/trailing + while "--" in slug: + slug = slug.replace("--", "-") + slug = slug.strip("-") + return f"investigation-{slug}"[:80] diff --git a/tests/unit/context/test_workspace_extended.py b/tests/unit/context/test_workspace_extended.py new file mode 100644 index 00000000..33d9cd7e --- /dev/null +++ b/tests/unit/context/test_workspace_extended.py @@ -0,0 +1,713 @@ +""" +tests/unit/context/test_workspace_extended.py +============================================== + +Extended unit tests for gnat.context.workspace and gnat.context.global_context. + +Targets uncovered lines including: +- WorkspaceManager: for_tenant(), delete(), list() with SQLite store, _default_store fallback +- Workspace: save(), export_bundle() non-FlatFile path, remove() with WorkspaceStore, + get_enrichment_history(), commit() error paths and deletion handling, + enrich() RuntimeError fallback, aenrich(), _enrich_async unknown source +- GlobalContext: get_object() +- GlobalContextRegistry: from_config(), unregister() clears default_name, all() sort +""" + +from __future__ import annotations + +import asyncio +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from gnat.context.store import FlatFileStore +from gnat.context.global_context import GlobalContext, GlobalContextRegistry +from gnat.context.workspace import Workspace, WorkspaceManager, CommitResult +from gnat.orm.indicator import Indicator +from gnat.orm.malware import Malware + + +# =========================================================================== +# Helpers (mirrors test_context.py helpers to avoid import coupling) +# =========================================================================== + +def _make_indicator(name: str = "evil.com", value: str = "evil.com") -> Indicator: + return Indicator( + name=name, + pattern=f"[domain-name:value = '{value}']", + pattern_type="stix", + indicator_types=["malicious-activity"], + ) + + +def _mock_global_context(name: str = "threatq", read_only: bool = False, + objects: list = None) -> GlobalContext: + mock_cli = MagicMock() + mock_cli.target = name + mock_cli.ping.return_value = True + + if objects is None: + objects = [_make_indicator(f"obj-{i}").to_dict() for i in range(2)] + + mock_cli.client.list_objects.return_value = [ + {"id": o["id"], "value": o.get("name", ""), "type": "indicator"} + for o in objects + ] + mock_cli.client.to_stix.side_effect = lambda raw: { + "type": 
"indicator", + "id": raw.get("id", "indicator--mock"), + "name": raw.get("value", ""), + "pattern": f"[domain-name:value = '{raw.get('value', '')}']", + "pattern_type": "stix", + "created": "", "modified": "", + "indicator_types": ["malicious-activity"], + } + mock_cli.client.from_stix.return_value = {"value": "mocked"} + mock_cli.client.upsert_object.return_value = {"id": "indicator--written", "value": "mocked"} + mock_cli.client.delete_object.return_value = None + mock_cli.client.get_object.return_value = {"id": "indicator--x", "value": "x.com", "type": "indicator"} + + gc = GlobalContext(name=name, client=mock_cli, read_only=read_only) + return gc + + +def _make_registry(names=("threatq", "recorded_future", "crowdstrike"), + default="threatq", read_only=("recorded_future",)): + registry = GlobalContextRegistry(default_name=default) + for name in names: + gc = _mock_global_context(name=name, read_only=(name in read_only)) + registry.register(gc) + return registry + + +def _make_workspace(name="test-ws", registry=None, store=None, tmp_path=None): + if registry is None: + registry = _make_registry() + if store is None: + store = FlatFileStore(base_dir=str( + (tmp_path or Path(tempfile.mkdtemp())) / "workspaces" + )) + return Workspace(name, registry, store) + + +def _sqlite_store(): + try: + from gnat.context.store import WorkspaceStore + store = WorkspaceStore("sqlite:///:memory:") + store.create_all() + return store + except ImportError: + return None + + +# =========================================================================== +# GlobalContext — extended +# =========================================================================== + +class TestGlobalContextExtended: + + def test_get_object_delegates_to_client(self): + gc = _mock_global_context("tq") + gc.client.client.get_object.return_value = {"id": "indicator--x", "value": "x.com", "type": "indicator"} + gc.client.client.to_stix.side_effect = None + gc.client.client.to_stix.return_value = { + "type": "indicator", "id": "indicator--x", + "name": "x.com", "pattern": "[domain-name:value = 'x.com']", + "pattern_type": "stix", "created": "", "modified": "", + "indicator_types": ["malicious-activity"], + } + result = gc.get_object("indicator", "indicator--x") + gc.client.client.get_object.assert_called_once_with("indicator", "indicator--x") + assert result["id"] == "indicator--x" + + def test_delete_object_delegates_to_client(self): + gc = _mock_global_context("tq") + gc.delete_object("indicator", "indicator--x") + gc.client.client.delete_object.assert_called_once_with("indicator", "indicator--x") + + def test_priority_attribute(self): + gc = GlobalContext("tq", MagicMock(), priority=5) + assert gc.priority == 5 + + def test_description_attribute(self): + gc = GlobalContext("tq", MagicMock(), description="My platform") + assert gc.description == "My platform" + + +# =========================================================================== +# GlobalContextRegistry — extended +# =========================================================================== + +class TestGlobalContextRegistryExtended: + + def test_unregister_clears_default_name(self): + """Unregistering the current default clears the _default_name.""" + registry = _make_registry() + registry.set_default("threatq") + assert registry._default_name == "threatq" + result = registry.unregister("threatq") + assert result is True + assert registry._default_name is None + + def test_all_sorted_by_priority(self): + registry = GlobalContextRegistry() + gc_low = GlobalContext("low", 
MagicMock(), priority=20) + gc_high = GlobalContext("high", MagicMock(), priority=1) + registry.register(gc_low) + registry.register(gc_high) + ordered = registry.all() + assert ordered[0].name == "high" + assert ordered[1].name == "low" + + def test_from_clients_no_default(self): + """from_clients with no default sets no default_name.""" + clients = { + "tq": MagicMock(target="threatq", ping=MagicMock(return_value=True), client=MagicMock()), + } + registry = GlobalContextRegistry.from_clients(clients) + # Should have the context registered + assert "tq" in registry + + def test_from_clients_with_read_only_list(self): + clients = { + "tq": MagicMock(target="threatq", ping=MagicMock(return_value=True), client=MagicMock()), + "rf": MagicMock(target="recordedfuture", ping=MagicMock(return_value=True), client=MagicMock()), + } + registry = GlobalContextRegistry.from_clients(clients, default="tq", read_only=["rf"]) + assert registry.get("rf").read_only is True + assert registry.get("tq").read_only is False + + def test_from_config_missing_target_warns(self, tmp_path): + """[global.noname] sections without 'target' are skipped.""" + cfg_path = tmp_path / "config.ini" + cfg_path.write_text( + "[global]\ndefault =\n\n[global.missing-target]\n# no target key\n" + ) + # Should not raise — just log warning and skip + registry = GlobalContextRegistry.from_config(str(cfg_path)) + assert "missing-target" not in registry + + def test_from_config_no_global_section(self, tmp_path): + """Config with no [global] section should return empty registry.""" + cfg_path = tmp_path / "config.ini" + cfg_path.write_text("[DEFAULT]\ntimeout = 10\n") + registry = GlobalContextRegistry.from_config(str(cfg_path)) + assert len(registry) == 0 + + def test_writable_returns_empty_when_all_read_only(self): + registry = GlobalContextRegistry() + registry.register(GlobalContext("rf", MagicMock(), read_only=True)) + assert registry.writable() == [] + + def test_default_property_auto_selects_writable(self): + registry = GlobalContextRegistry() + registry.register(GlobalContext("rf", MagicMock(), read_only=True, priority=1)) + registry.register(GlobalContext("tq", MagicMock(), read_only=False, priority=2)) + # No explicit default set — should pick lowest-priority writable + default = registry.default + assert default.name == "tq" + + +# =========================================================================== +# Workspace — save and export_bundle +# =========================================================================== + +class TestWorkspaceSave: + + def test_save_persists_all_objects(self, tmp_path): + ws = _make_workspace(tmp_path=tmp_path) + ind1 = _make_indicator("a.com") + ind2 = _make_indicator("b.com") + ws.add(ind1, mark_dirty=False) + ws.add(ind2, mark_dirty=True) + # Should not raise + ws.save() + + def test_export_bundle_non_flatfile(self, tmp_path): + """export_bundle with a non-FlatFile store returns a valid bundle dict.""" + store = _sqlite_store() + if store is None: + pytest.skip("SQLAlchemy not installed") + registry = _make_registry() + ws = Workspace("bundle-test", registry, store) + ind = _make_indicator("x.com") + ws.add(ind, mark_dirty=False) + bundle = ws.export_bundle() + assert bundle["type"] == "bundle" + assert bundle["spec_version"] == "2.1" + assert any(o["id"] == ind.id for o in bundle["objects"]) + + def test_export_bundle_flatfile(self, tmp_path): + """export_bundle with FlatFileStore delegates to store.export_bundle.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = 
_make_indicator("y.com") + ws.add(ind, mark_dirty=False) + bundle = ws.export_bundle() + assert bundle["type"] == "bundle" + + +# =========================================================================== +# Workspace — get_enrichment_history +# =========================================================================== + +class TestWorkspaceEnrichmentHistory: + + def test_get_enrichment_history_flatfile(self, tmp_path): + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator() + ws.add(ind, mark_dirty=False) + ws._apply_enrichment(ind, {"x_score": 80, "type": "indicator", + "id": f"indicator--enrich-hist", + "name": "hist.com", + "pattern": "[domain-name:value = 'hist.com']", + "pattern_type": "stix", + "created": "", "modified": "", + "indicator_types": ["malicious-activity"]}, + "recorded_future", "create_relationships") + history = ws.get_enrichment_history() + assert isinstance(history, list) + + def test_get_enrichment_history_filtered(self, tmp_path): + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("filter.com") + ws.add(ind, mark_dirty=False) + ws._log_enrichment(ind.id, "rf", {"score": 90}, "tag_only") + history = ws.get_enrichment_history(stix_id=ind.id) + assert isinstance(history, list) + + def test_get_enrichment_history_sqlite(self): + store = _sqlite_store() + if store is None: + pytest.skip("SQLAlchemy not installed") + registry = _make_registry() + ws = Workspace("hist-test", registry, store) + ind = _make_indicator() + ws.add(ind, mark_dirty=False) + ws._log_enrichment(ind.id, "rf", {"score": 90}, "tag_only") + history = ws.get_enrichment_history() + assert isinstance(history, list) + assert len(history) >= 1 + + +# =========================================================================== +# Workspace — commit error paths and deletions +# =========================================================================== + +class TestWorkspaceCommitExtended: + + def test_commit_error_on_write_failure(self, tmp_path): + """Commit records errors when write_object raises.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("fail.com") + ws.add(ind, mark_dirty=True) + + ws._registry.default.client.client.from_stix.return_value = {} + ws._registry.default.client.client.upsert_object.side_effect = RuntimeError("network error") + + result = ws.commit() + assert not result.success + assert len(result.errors) == 1 + assert "network error" in result.errors[0]["error"] + + def test_commit_deletion_of_removed_object(self, tmp_path): + """Deleted objects (in snapshot but not in objects) are committed as deletions.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("remove-me.com") + ws.add(ind, mark_dirty=False) + ws._snapshot[ind.id] = ind.to_dict() + + # Simulate removal from objects but keep in snapshot + del ws.objects[ind.id] + # Don't add to dirty — deletion is detected via snapshot diff + + ws._registry.default.client.client.delete_object.return_value = None + + result = ws.commit() + assert ind.id in result.deleted + assert ind.id not in ws._snapshot + + def test_commit_deletion_dry_run(self, tmp_path): + """dry_run includes deleted objects in would_write.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("dry-delete.com") + ws.add(ind, mark_dirty=False) + ws._snapshot[ind.id] = ind.to_dict() + del ws.objects[ind.id] + + result = ws.commit(dry_run=True) + deleted_entries = [e for e in result.would_write if e["action"] == "deleted"] + assert len(deleted_entries) == 1 + + def 
test_commit_deletion_error_path(self, tmp_path): + """Errors during deletion are recorded in result.errors.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("err-del.com") + ws.add(ind, mark_dirty=False) + ws._snapshot[ind.id] = ind.to_dict() + del ws.objects[ind.id] + + ws._registry.default.client.client.delete_object.side_effect = RuntimeError("del error") + + result = ws.commit() + assert not result.success + assert any("del error" in e["error"] for e in result.errors) + + def test_commit_marks_clean_after_success(self, tmp_path): + """After a successful commit the dirty set is cleared.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("clean.com") + ws.add(ind, mark_dirty=True) + assert ind.id in ws.dirty + + ws._registry.default.client.client.from_stix.return_value = {} + ws._registry.default.client.client.upsert_object.return_value = { + "id": ind.id, "value": ind.name, "type": "indicator" + } + ws._registry.default.client.client.to_stix.side_effect = None + ws._registry.default.client.client.to_stix.return_value = ind.to_dict() + + result = ws.commit() + assert result.success + assert ind.id not in ws.dirty + + def test_commit_with_stix_ids_subset_skips_non_dirty(self, tmp_path): + """Committing by stix_ids only commits the requested ids.""" + ws = _make_workspace(tmp_path=tmp_path) + ind1 = _make_indicator("a.com") + ind2 = _make_indicator("b.com") + ws.add(ind1, mark_dirty=True) + ws.add(ind2, mark_dirty=True) + + ws._registry.default.client.client.from_stix.return_value = {} + ws._registry.default.client.client.upsert_object.return_value = { + "id": ind1.id, "value": ind1.name, "type": "indicator" + } + ws._registry.default.client.client.to_stix.side_effect = None + ws._registry.default.client.client.to_stix.return_value = ind1.to_dict() + + result = ws.commit(stix_ids=[ind1.id]) + assert ind1.id in result.written + assert ind2.id not in result.written + + +# =========================================================================== +# Workspace — remove with WorkspaceStore +# =========================================================================== + +class TestWorkspaceRemoveWithStore: + + def test_remove_uses_soft_delete_with_sqlite(self): + store = _sqlite_store() + if store is None: + pytest.skip("SQLAlchemy not installed") + registry = _make_registry() + ws = Workspace("remove-test", registry, store) + ind = _make_indicator("remove.com") + ws.add(ind, mark_dirty=False) + + result = ws.remove(ind.id) + assert result is True + assert ind.id not in ws.objects + + def test_remove_adds_to_dirty(self, tmp_path): + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("dirty-remove.com") + ws.add(ind, mark_dirty=False) + ws.remove(ind.id) + assert ind.id in ws.dirty + + +# =========================================================================== +# Workspace — enrich() RuntimeError fallback +# =========================================================================== + +class TestWorkspaceEnrichFallback: + + def test_enrich_falls_back_to_sequential_on_runtime_error(self, tmp_path): + """enrich() should fall back to _enrich_sequential when no event loop.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("enrich.com") + ws.add(ind, mark_dirty=False) + + # Patch asyncio.get_event_loop to raise RuntimeError + with patch("asyncio.get_event_loop", side_effect=RuntimeError("no loop")): + with patch.object(ws, "_enrich_sequential") as mock_seq: + ws.enrich(sources=["recorded_future"]) + mock_seq.assert_called_once() 
+ + def test_enrich_sequential_unknown_source_skips(self, tmp_path): + """_enrich_sequential silently skips unknown sources.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("x.com") + ws.add(ind, mark_dirty=False) + # Should not raise for unknown source + ws._enrich_sequential(["nonexistent_source"], [ind.id], "tag_only", 0) + assert len(ws) == 1 # no new objects + + def test_enrich_sequential_handles_exception_in_source(self, tmp_path): + """_enrich_sequential swallows exceptions from individual source queries.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("exc.com") + ws.add(ind, mark_dirty=False) + + # Make the registry source raise on list_objects + gc = _mock_global_context("error_source", read_only=True) + gc.client.client.list_objects.side_effect = RuntimeError("source down") + ws._registry.register(gc) + + # Should not raise + ws._enrich_sequential(["error_source"], [ind.id], "tag_only", 0) + assert len(ws) == 1 + + def test_enrich_async_unknown_source_warns(self, tmp_path): + """_enrich_async logs a warning and skips unknown sources.""" + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("async-x.com") + ws.add(ind, mark_dirty=False) + # Run async method with unknown source + asyncio.run(ws._enrich_async(["nonexistent"], [ind.id], "tag_only", 0)) + assert len(ws) == 1 + + +# =========================================================================== +# Workspace — aenrich() +# =========================================================================== + +class TestWorkspaceAenrich: + + def test_aenrich_returns_self(self, tmp_path): + ws = _make_workspace(tmp_path=tmp_path) + ind = _make_indicator("aenrich.com") + ws.add(ind, mark_dirty=False) + + result = asyncio.run(ws.aenrich(sources=["recorded_future"], stix_ids=[ind.id])) + assert result is ws + + def test_aenrich_with_no_objects(self, tmp_path): + """aenrich on empty workspace should not raise.""" + ws = _make_workspace(tmp_path=tmp_path) + result = asyncio.run(ws.aenrich(sources=["recorded_future"])) + assert result is ws + + +# =========================================================================== +# WorkspaceManager — extended +# =========================================================================== + +class TestWorkspaceManagerExtended: + + def test_for_tenant_returns_tenant_manager(self, tmp_path): + store = FlatFileStore(base_dir=str(tmp_path / "workspaces")) + manager = WorkspaceManager(_make_registry(), store=store) + tenant_mgr = manager.for_tenant("acme") + assert tenant_mgr.tenant_id == "acme" + + def test_for_tenant_isolation(self, tmp_path): + """Two tenants share the same store but have isolated namespaces.""" + store = FlatFileStore(base_dir=str(tmp_path / "workspaces")) + manager = WorkspaceManager(_make_registry(), store=store) + + acme = manager.for_tenant("acme") + beta = manager.for_tenant("beta") + + acme.create("investigation") + # beta should not see acme's workspace + assert len(beta.list()) == 0 + assert len(acme.list()) == 1 + + def test_delete_workspace_flatfile(self, tmp_path): + store = FlatFileStore(base_dir=str(tmp_path / "workspaces")) + manager = WorkspaceManager(_make_registry(), store=store) + manager.create("to-delete") + assert manager.delete("to-delete") is True + + def test_delete_nonexistent_workspace(self, tmp_path): + store = FlatFileStore(base_dir=str(tmp_path / "workspaces")) + manager = WorkspaceManager(_make_registry(), store=store) + assert manager.delete("does-not-exist") is False + + def 
test_list_with_sqlite_store(self):
+        store = _sqlite_store()
+        if store is None:
+            pytest.skip("SQLAlchemy not installed")
+        registry = _make_registry()
+        manager = WorkspaceManager(registry, store=store)
+        manager.create("ws-alpha")
+        manager.create("ws-beta")
+        listed = manager.list()
+        names = [w["name"] for w in listed]
+        assert "ws-alpha" in names
+        assert "ws-beta" in names
+
+    def test_list_sqlite_includes_object_count(self):
+        store = _sqlite_store()
+        if store is None:
+            pytest.skip("SQLAlchemy not installed")
+        registry = _make_registry()
+        manager = WorkspaceManager(registry, store=store)
+        ws = manager.create("count-ws")
+        ws.add(_make_indicator("x.com"), mark_dirty=False)
+        listed = manager.list()
+        entry = next(w for w in listed if w["name"] == "count-ws")
+        assert "object_count" in entry
+        assert entry["object_count"] == 1
+
+    def test_open_with_sqlite_nonexistent_raises(self):
+        store = _sqlite_store()
+        if store is None:
+            pytest.skip("SQLAlchemy not installed")
+        registry = _make_registry()
+        manager = WorkspaceManager(registry, store=store)
+        with pytest.raises(KeyError, match="No workspace"):
+            manager.open("ghost")
+
+    def test_create_duplicate_flatfile_does_not_raise(self, tmp_path):
+        """Creating a duplicate workspace with FlatFileStore is idempotent."""
+        store = FlatFileStore(base_dir=str(tmp_path / "workspaces"))
+        manager = WorkspaceManager(_make_registry(), store=store)
+        manager.create("dup-ws")
+        # FlatFileStore's create has get-or-create semantics, so the second
+        # call succeeds instead of raising
+        ws2 = manager.create("dup-ws")
+        assert ws2.name == "dup-ws"
+
+    def test_default_store_falls_back_to_flatfile(self, tmp_path):
+        """_default_store() returns a FlatFileStore when WorkspaceStore init fails."""
+        from gnat.context.store import WorkspaceStore
+        with patch.object(WorkspaceStore, "__init__", side_effect=Exception("db error")):
+            store = WorkspaceManager._default_store("sqlite:///bad.db")
+        assert isinstance(store, FlatFileStore)
+
+    def test_from_clients_with_db_url(self):
+        clients = {
+            "tq": MagicMock(
+                target="threatq",
+                ping=MagicMock(return_value=True),
+                client=MagicMock(),
+            ),
+        }
+        manager = WorkspaceManager.from_clients(
+            clients, default="tq", db_url="sqlite:///:memory:"
+        )
+        assert manager._registry.default.name == "tq"
+
+
+# ===========================================================================
+# WorkspaceManager — with WorkspaceStore (SQLite)
+# ===========================================================================
+
+class TestWorkspaceManagerSQLite:
+
+    @pytest.fixture
+    def sqlite_manager(self):
+        store = _sqlite_store()
+        if store is None:
+            pytest.skip("SQLAlchemy not installed")
+        return WorkspaceManager(_make_registry(), store=store)
+
+    def test_create_and_open(self, sqlite_manager):
+        sqlite_manager.create("open-test")
+        ws = sqlite_manager.open("open-test")
+        assert ws.name == "open-test"
+
+    def test_delete_existing(self, sqlite_manager):
+        sqlite_manager.create("del-test")
+        assert sqlite_manager.delete("del-test") is True
+
+    def test_get_or_create_opens_existing(self, sqlite_manager):
+        sqlite_manager.create("existing")
+        ws = sqlite_manager.get_or_create("existing")
+        assert ws.name == "existing"
+
+    def test_get_or_create_creates_new(self, sqlite_manager):
+        ws = sqlite_manager.get_or_create("brand-new")
+        assert ws.name == "brand-new"
+
+    def test_persistence_across_instances(self, sqlite_manager):
+        """Objects added to one workspace instance should appear in a reopened instance."""
+        ws1 = sqlite_manager.create("persist")
+        ind = _make_indicator("persist.com")
+        ws1.add(ind, mark_dirty=False)
+
+        ws2 = sqlite_manager.open("persist")
+        assert ind.id in ws2.objects
+
+    def test_workspace_init_with_sqlite_store_hydrates(self):
+        """A second Workspace bound to the same WorkspaceStore hydrates the
+        objects added through the first instance."""
+        store = _sqlite_store()
+        if store is None:
+            pytest.skip("SQLAlchemy not installed")
+        registry = _make_registry()
+        ws = Workspace("hydrate-test", registry, store)
+        ind = _make_indicator("hydrate.com")
+        ws.add(ind, mark_dirty=False)
+
+        ws2 = Workspace("hydrate-test", registry, store)
+        assert ind.id in ws2.objects
+
+
+# ===========================================================================
+# CommitResult
+# ===========================================================================
+
+class TestCommitResultExtended:
+
+    def test_deleted_populated_on_deletion(self):
+        result = CommitResult("ws", "tq", False)
+        result.deleted.append("indicator--x")
+        assert "indicator--x" in result.deleted
+
+    def test_would_write_populated_on_dry_run(self):
+        result = CommitResult("ws", "tq", True)
+        result.would_write.append({"id": "indicator--x", "action": "added"})
+        assert len(result.would_write) == 1
+
+    def test_success_false_when_deletion_errors(self):
+        """Any recorded error marks the commit as failed."""
+        result = CommitResult("ws", "tq", False)
+        result.errors.append({"id": "indicator--x", "error": "not found"})
+        assert result.success is False
+
+    def test_success_true_with_written_and_deleted(self):
+        """Writes and deletions without errors still count as success."""
+        result = CommitResult("ws", "tq", False)
+        result.written.append("indicator--a")
+        result.deleted.append("indicator--b")
+        assert result.success is True
+
+
+# ===========================================================================
+# Workspace — _from_dict type dispatch
+# ===========================================================================
+
+class TestWorkspaceFromDict:
+
+    def test_from_dict_indicator(self):
+        from gnat.orm.indicator import Indicator
+        ind = _make_indicator()
+        obj = Workspace._from_dict(ind.to_dict())
+        assert isinstance(obj, Indicator)
+
+    def test_from_dict_malware(self):
+        from gnat.orm.malware import Malware
+        mal = Malware(name="BadMal")
+        obj = Workspace._from_dict(mal.to_dict())
+        assert isinstance(obj, Malware)
+
+    def test_from_dict_unknown_type_falls_back_to_stixbase(self):
+        from gnat.orm.base import STIXBase
+        d = {"type": "x-custom", "id": "x-custom--abc", "name": "custom"}
+        obj = Workspace._from_dict(d)
+        assert isinstance(obj, STIXBase)
+
+    def test_from_dict_vulnerability(self):
+        from gnat.orm.vulnerability import Vulnerability
+        vuln = Vulnerability(name="CVE-2024-0001")
+        obj = Workspace._from_dict(vuln.to_dict())
+        assert isinstance(obj, Vulnerability)
+
+    def test_from_dict_relationship(self):
+        from gnat.orm.relationship import Relationship
+        rel = Relationship(
+            relationship_type="related-to",
+            source_ref="indicator--a",
+            target_ref="indicator--b",
+        )
+        obj = Workspace._from_dict(rel.to_dict())
+        assert isinstance(obj, Relationship)