From 0170e315df5e3f6f7970388a53d856218265cb91 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 9 Apr 2026 00:10:42 +0000 Subject: [PATCH] Implement GNAT Phase 4: Control, Reasoning, Safety MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 4A — Execution Context & Domain Boundaries - ExecutionContext dataclass (context_id, initiated_by, domain, trust_level, policy_set, workspace_id, created_at, parent_context_id, is_replay, budget) - QueryBudget: finite connector query budget; raises BudgetExceeded when exhausted - Domain enum + DOMAIN_CALL_RULES + @domain_boundary decorator (thread-local stack) - DomainBoundaryViolation, TrustLevelViolation, @require_trust_level decorator - Migration 0004: execution_log table Phase P-1 — Connector Trust & Versioning - BaseClient: TRUST_LEVEL, API_VERSION, API_PREFIX, COST_UNIT class vars - BaseClient._request(): deducts COST_UNIT from ExecutionContext budget - BudgetExceeded(GNATClientError) exception - 16 connectors updated with explicit trust/version/prefix assignments Phase 4B — Idempotency & Schema Evolution - Migration 0005: idempotency_key on workspace_objects - WorkspaceObjectModel: idempotency_key column + make_idempotency_key() helper - STIXBase: schema_version = 1 - Migration 0006: agent_sessions + agent_actions tables Phase 4C — Hypothesis Engine, Negative Evidence, Reasoning - STIXHypothesis (x-gnat-hypothesis): statement, confidence, status, supporting/refuting evidence; full lifecycle with close(verdict) - NegativeEvidenceRecord (x-gnat-negative-evidence): TTL-based suppression; is_expired(), seconds_remaining() - HypothesisEngine: propose → evaluate → close; Solr corroboration; trust-weighted confidence; auto-classify at thresholds - ReasoningEngine: prioritize() with composite scoring (trust×0.4 + age×0.3 + corroboration×0.3 − neg_penalty×0.5); structured explanation dicts; STIX note storage Phase 4D — Agent Governance & HITL - AgentActionType enum + agent_can_act() trust-level permission matrix - AgentGovernor: can_act(), require_can_act(), record_action(), rate_limit_check() (sliding window), per-agent policy overrides - AgentAction dataclass with impact_level validation - HITLGateway: bridges AgentGovernor to existing gnat/review/; low/medium auto-approve; high→ReviewItem PENDING; critical→PENDING+XSOAR Phase 4E — Isolation, Performance, Testing - Migrations 0007 (trust_boundary on workspaces) + 0008 (query_cost_log) - WorkspaceModel: trust_boundary + allowed_connector_refs columns - Workspace: trust_boundary/allowed_connector_refs attrs + check_connector_trust() - SimulationConnector: canned STIX fixtures, no network - ReplayRunner: execution_log replay with assertion support - AgentTestHarness: mock HITL approvals for deterministic tests Tests: 90 new unit tests covering all Phase 4 components (3955 total, +90) https://claude.ai/code/session_01BDoue9HxB83ijLzFARAugq --- CHANGELOG.md | 38 ++ alembic/versions/0004_add_execution_log.py | 46 ++ alembic/versions/0005_add_idempotency.py | 44 ++ alembic/versions/0006_add_agent_tables.py | 52 ++ .../versions/0007_workspace_trust_boundary.py | 44 ++ alembic/versions/0008_query_cost_log.py | 40 ++ gnat/agents/governor.py | 366 ++++++++++++++ gnat/agents/hitl.py | 285 +++++++++++ gnat/clients/base.py | 55 +++ gnat/connectors/CISA/client.py | 6 + gnat/connectors/alienvault/client.py | 6 + gnat/connectors/crowdstrike/client.py | 6 + gnat/connectors/elastic/client.py | 6 + gnat/connectors/feedly/client.py | 6 + gnat/connectors/graylog/client.py | 6 + gnat/connectors/misp/client.py | 6 + gnat/connectors/qradar/client.py | 6 + gnat/connectors/recordedfuture/client.py | 5 + gnat/connectors/recordedfuture/rfv3.py | 2 + gnat/connectors/security_onion/client.py | 6 + gnat/connectors/sentinel/client.py | 6 + gnat/connectors/shadowserver/client.py | 6 + gnat/connectors/splunk/client.py | 6 + gnat/connectors/threatq/client.py | 6 + gnat/connectors/virustotal/client.py | 6 + gnat/connectors/wazuh/client.py | 6 + gnat/connectors/xsoar/client.py | 6 + gnat/context/store.py | 71 ++- gnat/context/workspace.py | 50 ++ gnat/core/__init__.py | 26 + gnat/core/context.py | 297 ++++++++++++ gnat/core/domains.py | 280 +++++++++++ gnat/orm/base.py | 3 + gnat/policy/__init__.py | 4 + gnat/policy/models.py | 43 ++ gnat/reasoning/__init__.py | 14 + gnat/reasoning/engine.py | 273 +++++++++++ gnat/reasoning/hypothesis.py | 256 ++++++++++ gnat/stix/sdos/__init__.py | 17 + gnat/stix/sdos/hypothesis.py | 201 ++++++++ gnat/stix/sdos/negative_evidence.py | 149 ++++++ gnat/testing/__init__.py | 33 ++ gnat/testing/simulation.py | 401 +++++++++++++++ tests/unit/test_phase4_core.py | 443 +++++++++++++++++ tests/unit/test_phase4_governor.py | 409 ++++++++++++++++ tests/unit/test_phase4_reasoning.py | 458 ++++++++++++++++++ 46 files changed, 4500 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/0004_add_execution_log.py create mode 100644 alembic/versions/0005_add_idempotency.py create mode 100644 alembic/versions/0006_add_agent_tables.py create mode 100644 alembic/versions/0007_workspace_trust_boundary.py create mode 100644 alembic/versions/0008_query_cost_log.py create mode 100644 gnat/agents/governor.py create mode 100644 gnat/agents/hitl.py create mode 100644 gnat/core/__init__.py create mode 100644 gnat/core/context.py create mode 100644 gnat/core/domains.py create mode 100644 gnat/reasoning/__init__.py create mode 100644 gnat/reasoning/engine.py create mode 100644 gnat/reasoning/hypothesis.py create mode 100644 gnat/stix/sdos/__init__.py create mode 100644 gnat/stix/sdos/hypothesis.py create mode 100644 gnat/stix/sdos/negative_evidence.py create mode 100644 gnat/testing/__init__.py create mode 100644 gnat/testing/simulation.py create mode 100644 tests/unit/test_phase4_core.py create mode 100644 tests/unit/test_phase4_governor.py create mode 100644 tests/unit/test_phase4_reasoning.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c0f3eb2..02ddc500 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,44 @@ Detailed per-version release notes are available in [`docs/releases/`](docs/rele ## [Unreleased] +### Added — Phase 4: Control, Reasoning, Safety + +**4A — Execution Context & Domain Boundaries** +- `gnat/core/context.py`: `ExecutionContext` dataclass carrying `context_id` (UUID), `initiated_by`, `domain`, `trust_level`, `policy_set`, `workspace_id`, `created_at`, `parent_context_id`, `is_replay`; factory methods `create()`, `from_connector()`, `child()` +- `gnat/core/context.py`: `QueryBudget` dataclass — finite query budget for connector calls; `consume()` raises `BudgetExceeded` when exhausted; attached to `ExecutionContext` via `max_budget_units` param on `create()` +- `gnat/core/domains.py`: `Domain` enum (ingestion/analysis/investigation/reporting/execution); `DOMAIN_CALL_RULES` permission graph; `@domain_boundary(target_domain)` decorator with thread-local stack enforcement; `DomainBoundaryViolation` and `TrustLevelViolation` exceptions; `@require_trust_level(minimum)` decorator +- `alembic/versions/0004_add_execution_log.py`: `execution_log` table (context_id PK, initiated_by, domain, trust_level, policy_set, workspace_id, created_at, parent_context_id, is_replay, event_type, notes) + +**P-1 — Connector Trust & Versioning** +- `BaseClient`: added `TRUST_LEVEL: str = "semi_trusted"`, `API_VERSION: str = ""`, `API_PREFIX: str = ""`, `COST_UNIT: int = 1` class variables; added `_context: Any = None` attribute for budget tracking +- `BaseClient._request()`: deducts `COST_UNIT` from `ExecutionContext.budget` when a context is attached; raises `BudgetExceeded` when exhausted +- `BudgetExceeded(GNATClientError)`: new exception with `connector`, `cost`, `remaining` attributes +- 16 connectors updated with explicit `TRUST_LEVEL`, `API_VERSION`, `API_PREFIX`: Splunk, XSOAR, Graylog, Security Onion, Sentinel, QRadar, Elastic, Wazuh (trusted_internal); ThreatQ, CrowdStrike, Feedly, VirusTotal, MISP, Recorded Future (semi_trusted); AlienVault, Shadowserver (untrusted_external) + +**4B — Idempotency & Schema Evolution** +- `alembic/versions/0005_add_idempotency.py`: `idempotency_key VARCHAR(255)` column with partial unique index on `workspace_objects` +- `WorkspaceObjectModel`: `idempotency_key` column added; `WorkspaceStore.make_idempotency_key()` static method computing `{connector_id}:{stix_type}:{external_id}:{sha1[:12]}` +- `STIXBase`: `schema_version: int = 1` class variable for ORM versioning +- `alembic/versions/0006_add_agent_tables.py`: `agent_sessions` and `agent_actions` tables + +**4C — Hypothesis Engine, Negative Evidence, Reasoning** +- `gnat/stix/sdos/hypothesis.py`: `STIXHypothesis` custom SDO (`x-gnat-hypothesis`); fields: statement, confidence [0-1], status (pending/confirmed/refuted/inconclusive), supporting_evidence[], refuting_evidence[]; methods: `add_supporting_evidence()`, `add_refuting_evidence()`, `update_confidence()`, `close(verdict)`; full `to_dict()`/`from_dict()` serialization +- `gnat/stix/sdos/negative_evidence.py`: `NegativeEvidenceRecord` custom SDO (`x-gnat-negative-evidence`); fields: target_ref, queried_connector, ttl_seconds, query_timestamp; methods: `is_expired()`, `seconds_remaining()` +- `gnat/reasoning/hypothesis.py`: `HypothesisEngine` — `propose()`, `evaluate()` (Solr corroboration + weighted confidence), `close()`, `get()`, `list_all()`; confidence scoring: trusted_internal→0.9, semi_trusted→0.6, untrusted_external→0.3; auto-classify ≥0.75→confirmed, ≤0.15+refutation→refuted +- `gnat/reasoning/engine.py`: `ReasoningEngine` — `prioritize(observables, context, store_notes)` returning `[(observable, score, explanation)]` sorted descending; composite score: trust_weight×0.4 + age_factor×0.3 + corroboration_bonus×0.3 − neg_penalty×0.5; structured machine-readable explanation dicts; STIX `note` objects stored per scored observable + +**4D — Agent Governance & HITL** +- `gnat/policy/models.py`: `AgentActionType` enum (read_stix/write_stix/delete_stix/enrich/ingest/export/trigger_playbook/manage_workspace/escalate/hypothesize); `agent_can_act(trust_level, action_type)` matrix; `_TRUST_ACTION_PERMISSIONS` per trust level +- `gnat/agents/governor.py`: `AgentGovernor` — `can_act()`, `require_can_act()`, `record_action()`, `rate_limit_check()` (sliding-window counter), `get_action_log()`, `set_policy_override()`; `AgentAction` dataclass with `to_dict()`; `RateLimitExceeded` and `AgentPermissionDenied` exceptions +- `gnat/agents/hitl.py`: `HITLGateway` bridging `AgentGovernor` to existing `gnat/review/service.py`; four-tier impact model: low/medium auto-approve, high→ReviewItem PENDING, critical→PENDING + XSOAR notification via `XSOARClient.upsert_object()`; timeout auto-rejection; `evaluate()`, `submit_for_approval()`, `check_approval_status()`, `auto_approve_pending()` + +**4E — Isolation, Performance, Testing** +- `alembic/versions/0007_workspace_trust_boundary.py`: `trust_boundary VARCHAR(50)` and `allowed_connector_refs TEXT` columns on `workspaces` +- `alembic/versions/0008_query_cost_log.py`: `query_cost_log` table for per-connector cost tracking +- `WorkspaceModel`: `trust_boundary` and `allowed_connector_refs` columns added +- `Workspace`: `trust_boundary` and `allowed_connector_refs` attributes loaded from DB; `check_connector_trust(connector)` enforces trust rank and allowlist at connector instantiation +- `gnat/testing/__init__.py` + `gnat/testing/simulation.py`: `SimulationConnector(BaseClient)` — canned STIX fixtures, no network; `ReplayRunner` — replays `execution_log` sequences through pipeline with assertion support; `AgentTestHarness` — mock-approves all HITL submissions for deterministic agent tests + ### Added — AI & Connector Improvements **Google Gemini provider (`gnat/agents/gemini.py`)** diff --git a/alembic/versions/0004_add_execution_log.py b/alembic/versions/0004_add_execution_log.py new file mode 100644 index 00000000..4caa99be --- /dev/null +++ b/alembic/versions/0004_add_execution_log.py @@ -0,0 +1,46 @@ +"""Add execution_log table (Phase 4A). + +Append-only audit log for all GNAT operations. Every pipeline run, +enrichment call, connector request, and agent action writes one row. + +Revision ID: 0004 +Revises: 0003 +Create Date: 2026-04-08 00:00:04.000000 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "0004" +down_revision = "0003" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "execution_log", + sa.Column("context_id", sa.String(36), primary_key=True), + sa.Column("initiated_by", sa.String(128), nullable=False), + sa.Column("domain", sa.String(32), nullable=False, index=True), + sa.Column("trust_level", sa.String(32), nullable=False, index=True), + sa.Column("policy_set", sa.String(64), nullable=False, server_default="default"), + sa.Column("workspace_id", sa.String(256), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column("parent_context_id", sa.String(36), nullable=True), + sa.Column("is_replay", sa.Boolean, nullable=False, server_default="0"), + sa.Column("event_type", sa.String(64), nullable=True), # "security_event", "replay_event", etc. + sa.Column("notes", sa.Text, nullable=True), + ) + op.create_index( + "ix_execution_log_workspace_domain", + "execution_log", + ["workspace_id", "domain"], + ) + + +def downgrade() -> None: + op.drop_index("ix_execution_log_workspace_domain", "execution_log") + op.drop_table("execution_log") diff --git a/alembic/versions/0005_add_idempotency.py b/alembic/versions/0005_add_idempotency.py new file mode 100644 index 00000000..3471b619 --- /dev/null +++ b/alembic/versions/0005_add_idempotency.py @@ -0,0 +1,44 @@ +"""Add idempotency_key to workspace_objects (Phase 4B). + +Enables safe pipeline replay: re-ingesting the same STIX object produces +a single stored row with a logged replay_event rather than a duplicate. + +Revision ID: 0005 +Revises: 0004 +Create Date: 2026-04-08 00:00:05.000000 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "0005" +down_revision = "0004" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("workspace_objects") as batch_op: + batch_op.add_column( + sa.Column( + "idempotency_key", + sa.String(255), + nullable=True, + unique=False, + ) + ) + op.create_index( + "ix_workspace_objects_idempotency", + "workspace_objects", + ["idempotency_key"], + unique=True, + postgresql_where=sa.text("idempotency_key IS NOT NULL"), + ) + + +def downgrade() -> None: + op.drop_index("ix_workspace_objects_idempotency", "workspace_objects") + with op.batch_alter_table("workspace_objects") as batch_op: + batch_op.drop_column("idempotency_key") diff --git a/alembic/versions/0006_add_agent_tables.py b/alembic/versions/0006_add_agent_tables.py new file mode 100644 index 00000000..30584377 --- /dev/null +++ b/alembic/versions/0006_add_agent_tables.py @@ -0,0 +1,52 @@ +"""Add agent_sessions and agent_actions tables (Phase 4D). + +Postgres-backed agent memory and audit trail. All agent actions +pass through AgentGovernor which writes here for complete auditability. + +Revision ID: 0006 +Revises: 0005 +Create Date: 2026-04-08 00:00:06.000000 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "0006" +down_revision = "0005" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "agent_sessions", + sa.Column("session_id", sa.String(36), primary_key=True), + sa.Column("agent_id", sa.String(128), nullable=False, index=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("context_id", sa.String(36), nullable=True), # FK to execution_log.context_id + sa.Column("state_json", sa.Text, nullable=True), + ) + + op.create_table( + "agent_actions", + sa.Column("action_id", sa.String(36), primary_key=True), + sa.Column("agent_id", sa.String(128), nullable=False, index=True), + sa.Column("session_id", sa.String(36), nullable=True), # FK to agent_sessions + sa.Column("action_type", sa.String(64), nullable=False), + sa.Column("target_ref", sa.String(256), nullable=True), + sa.Column("impact_level", sa.String(16), nullable=False, server_default="low"), # low/medium/high/critical + sa.Column("approved_by", sa.String(128), nullable=True), + sa.Column("executed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("result_json", sa.Text, nullable=True), + sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("status", sa.String(32), nullable=False, server_default="pending"), # pending/approved/rejected/executed + ) + op.create_index("ix_agent_actions_agent_status", "agent_actions", ["agent_id", "status"]) + + +def downgrade() -> None: + op.drop_index("ix_agent_actions_agent_status", "agent_actions") + op.drop_table("agent_actions") + op.drop_table("agent_sessions") diff --git a/alembic/versions/0007_workspace_trust_boundary.py b/alembic/versions/0007_workspace_trust_boundary.py new file mode 100644 index 00000000..abd74eed --- /dev/null +++ b/alembic/versions/0007_workspace_trust_boundary.py @@ -0,0 +1,44 @@ +"""Add trust_boundary and allowed_connector_refs to workspaces (Phase 4E). + +Enables workspace isolation: connectors whose TRUST_LEVEL is below the +workspace's trust_boundary are rejected at instantiation time. + +Revision ID: 0007 +Revises: 0006 +Create Date: 2026-04-08 00:00:07.000000 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "0007" +down_revision = "0006" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("workspaces") as batch_op: + batch_op.add_column( + sa.Column( + "trust_boundary", + sa.String(32), + nullable=False, + server_default="semi_trusted", + ) + ) + batch_op.add_column( + sa.Column( + "allowed_connector_refs", + sa.Text, # JSON array of connector class names; NULL = all allowed + nullable=True, + ) + ) + + +def downgrade() -> None: + with op.batch_alter_table("workspaces") as batch_op: + batch_op.drop_column("allowed_connector_refs") + batch_op.drop_column("trust_boundary") diff --git a/alembic/versions/0008_query_cost_log.py b/alembic/versions/0008_query_cost_log.py new file mode 100644 index 00000000..26c5fd2c --- /dev/null +++ b/alembic/versions/0008_query_cost_log.py @@ -0,0 +1,40 @@ +"""Add query_cost_log table (Phase 4E). + +Tracks per-connector query cost for capacity planning and budget enforcement. + +Revision ID: 0008 +Revises: 0007 +Create Date: 2026-04-08 00:00:08.000000 + +""" +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "0008" +down_revision = "0007" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "query_cost_log", + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("connector_id", sa.String(128), nullable=False, index=True), + sa.Column("cost_units", sa.Integer, nullable=False), + sa.Column("context_id", sa.String(36), nullable=True), # FK to execution_log + sa.Column("operation", sa.String(64), nullable=True), # "bulk_pull", "lookup", "search" + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False, index=True), + ) + op.create_index( + "ix_query_cost_connector_ts", + "query_cost_log", + ["connector_id", "timestamp"], + ) + + +def downgrade() -> None: + op.drop_index("ix_query_cost_connector_ts", "query_cost_log") + op.drop_table("query_cost_log") diff --git a/gnat/agents/governor.py b/gnat/agents/governor.py new file mode 100644 index 00000000..18cd1010 --- /dev/null +++ b/gnat/agents/governor.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.agents.governor +==================== + +Agent governance layer — permission checks, action logging, and rate limiting. + +:class:`AgentGovernor` is the single choke-point for all agent actions. Every +action an agent wishes to perform must pass through :meth:`AgentGovernor.can_act` +before execution and be recorded via :meth:`AgentGovernor.record_action`. + +Usage +----- +:: + + from gnat.agents.governor import AgentGovernor, AgentAction + from gnat.policy.models import AgentActionType + + governor = AgentGovernor() + + if governor.can_act("research-agent-1", AgentActionType.ENRICH, "semi_trusted"): + governor.rate_limit_check("research-agent-1", window_seconds=60) + action = AgentAction( + agent_id="research-agent-1", + action_type=AgentActionType.ENRICH, + target_ref="indicator--abc123", + impact_level="low", + ) + governor.record_action(action) +""" + +from __future__ import annotations + +import logging +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +from gnat.policy.models import AgentActionType, agent_can_act + +logger = logging.getLogger(__name__) + + +class RateLimitExceeded(Exception): + """Raised when an agent exceeds its configured request rate.""" + + def __init__(self, agent_id: str, window_seconds: int, current_count: int) -> None: + """Initialize RateLimitExceeded.""" + super().__init__( + f"Agent {agent_id!r} exceeded rate limit: " + f"{current_count} calls within {window_seconds}s window" + ) + self.agent_id = agent_id + self.window_seconds = window_seconds + self.current_count = current_count + + +class AgentPermissionDenied(Exception): + """Raised when an agent action is denied by the governor.""" + + def __init__(self, agent_id: str, action_type: AgentActionType, trust_level: str) -> None: + """Initialize AgentPermissionDenied.""" + super().__init__( + f"Agent {agent_id!r} (trust={trust_level!r}) denied " + f"permission for action {action_type.value!r}" + ) + self.agent_id = agent_id + self.action_type = action_type + self.trust_level = trust_level + + +# Valid impact levels (ordered low → high) +IMPACT_LEVELS = ("low", "medium", "high", "critical") + + +@dataclass +class AgentAction: + """ + Record of a single agent action, stored in the ``agent_actions`` table. + + Parameters + ---------- + agent_id : str + Identifier of the agent requesting the action. + action_type : AgentActionType + Category of the action being performed. + target_ref : str + STIX ID or other reference to the object being acted upon. + impact_level : str + Severity classification: ``"low"``, ``"medium"``, ``"high"``, or ``"critical"``. + session_id : str, optional + UUID of the parent agent session. + context_id : str, optional + UUID of the active :class:`~gnat.core.context.ExecutionContext`. + result_json : dict, optional + Action outcome (populated after execution). + approved_by : str, optional + Human reviewer or auto-approve policy that authorised this action. + """ + + agent_id: str + action_type: AgentActionType + target_ref: str = "" + impact_level: str = "low" + session_id: str = field(default_factory=lambda: str(uuid.uuid4())) + context_id: str | None = None + result_json: dict[str, Any] = field(default_factory=dict) + approved_by: str | None = None + action_id: str = field(default_factory=lambda: str(uuid.uuid4())) + submitted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + executed_at: datetime | None = None + status: str = "pending" # pending | approved | rejected | executed + + def __post_init__(self) -> None: + """Validate impact_level.""" + if self.impact_level not in IMPACT_LEVELS: + raise ValueError( + f"impact_level must be one of {IMPACT_LEVELS}, got {self.impact_level!r}" + ) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict suitable for DB insertion.""" + return { + "action_id": self.action_id, + "agent_id": self.agent_id, + "session_id": self.session_id, + "action_type": self.action_type.value, + "target_ref": self.target_ref, + "impact_level": self.impact_level, + "context_id": self.context_id, + "result_json": self.result_json, + "approved_by": self.approved_by, + "submitted_at": self.submitted_at.isoformat(), + "executed_at": self.executed_at.isoformat() if self.executed_at else None, + "status": self.status, + } + + +class AgentGovernor: + """ + Central governance authority for all GNAT agent actions. + + Responsibilities: + + * **Permission checks** — delegates to + :func:`~gnat.policy.models.agent_can_act` and optional per-agent + override policies loaded from config. + * **Action recording** — persists :class:`AgentAction` records to the + in-memory audit log (and optionally to the ``agent_actions`` DB table). + * **Rate limiting** — sliding-window counter per agent; raises + :exc:`RateLimitExceeded` when the configured limit is exceeded. + + Parameters + ---------- + max_calls_per_window : int + Maximum number of calls allowed per agent within *window_seconds*. + Default is 100. + window_seconds : int + Sliding window size for rate limiting in seconds. Default is 60. + policy_overrides : dict, optional + Per-agent permission overrides: ``{"agent-id": {"action_type": bool}}``. + Overrides take precedence over the trust-level default matrix. + """ + + def __init__( + self, + max_calls_per_window: int = 100, + window_seconds: int = 60, + policy_overrides: dict[str, dict[str, bool]] | None = None, + ) -> None: + """Initialize AgentGovernor.""" + self._max_calls = max_calls_per_window + self._window_seconds = window_seconds + self._policy_overrides: dict[str, dict[str, bool]] = policy_overrides or {} + + # Sliding-window rate-limit state: agent_id → list of epoch timestamps + self._call_timestamps: dict[str, list[float]] = {} + + # In-memory audit log + self._action_log: list[AgentAction] = [] + + # ── Public API ───────────────────────────────────────────────────────────── + + def can_act( + self, + agent_id: str, + action_type: AgentActionType, + trust_level: str = "semi_trusted", + ) -> bool: + """ + Return ``True`` if *agent_id* with *trust_level* may perform *action_type*. + + Checks per-agent policy overrides first; falls back to the default + trust-level permission matrix from + :func:`~gnat.policy.models.agent_can_act`. + + Parameters + ---------- + agent_id : str + Identifier of the requesting agent. + action_type : AgentActionType + The action being requested. + trust_level : str + Trust classification of the agent's connector or context. + """ + # 1. Per-agent override (explicit allow/deny) + agent_overrides = self._policy_overrides.get(agent_id, {}) + action_key = action_type.value + if action_key in agent_overrides: + result = bool(agent_overrides[action_key]) + self._emit_decision(agent_id, action_type, trust_level, result, source="override") + return result + + # 2. Trust-level default matrix + result = agent_can_act(trust_level, action_type) + self._emit_decision(agent_id, action_type, trust_level, result, source="policy") + return result + + def require_can_act( + self, + agent_id: str, + action_type: AgentActionType, + trust_level: str = "semi_trusted", + ) -> None: + """ + Assert *agent_id* may perform *action_type*, raising if not. + + Parameters + ---------- + agent_id : str + action_type : AgentActionType + trust_level : str + + Raises + ------ + AgentPermissionDenied + If the action is not permitted. + """ + if not self.can_act(agent_id, action_type, trust_level): + raise AgentPermissionDenied(agent_id, action_type, trust_level) + + def record_action(self, action: AgentAction) -> None: + """ + Persist an :class:`AgentAction` to the audit log. + + The action is appended to the in-memory log and emitted via the + :class:`~gnat.plugins.hooks.HookBus` when available. + + Parameters + ---------- + action : AgentAction + The action to record. + """ + self._action_log.append(action) + logger.info( + "AgentGovernor: recorded action agent=%r type=%r impact=%r target=%r status=%r", + action.agent_id, + action.action_type.value, + action.impact_level, + action.target_ref, + action.status, + ) + try: + from gnat.plugins.hooks import HookBus + HookBus.instance().emit( + "agent_action", + agent_id=action.agent_id, + action_type=action.action_type.value, + impact_level=action.impact_level, + target_ref=action.target_ref, + status=action.status, + ) + except Exception: # noqa: BLE001 + pass + + def rate_limit_check( + self, + agent_id: str, + window_seconds: int | None = None, + max_calls: int | None = None, + ) -> None: + """ + Check if *agent_id* is within its rate limit; raise if exceeded. + + Uses a sliding-window counter. Expired timestamps outside the window + are pruned on each call. + + Parameters + ---------- + agent_id : str + The agent being checked. + window_seconds : int, optional + Override the governor's default window. + max_calls : int, optional + Override the governor's default call limit. + + Raises + ------ + RateLimitExceeded + If the agent has exceeded the configured limit. + """ + window = window_seconds or self._window_seconds + limit = max_calls or self._max_calls + now = time.monotonic() + cutoff = now - window + + timestamps = self._call_timestamps.setdefault(agent_id, []) + # Prune old entries + timestamps[:] = [t for t in timestamps if t > cutoff] + + if len(timestamps) >= limit: + raise RateLimitExceeded(agent_id, window, len(timestamps)) + + timestamps.append(now) + + def get_action_log(self, agent_id: str | None = None) -> list[AgentAction]: + """ + Return recorded actions, optionally filtered by *agent_id*. + + Parameters + ---------- + agent_id : str, optional + Filter to a specific agent; ``None`` returns all. + """ + if agent_id is None: + return list(self._action_log) + return [a for a in self._action_log if a.agent_id == agent_id] + + def set_policy_override( + self, agent_id: str, action_type: AgentActionType, allowed: bool + ) -> None: + """ + Set a per-agent permission override. + + Parameters + ---------- + agent_id : str + The agent to configure. + action_type : AgentActionType + The action type to override. + allowed : bool + Whether to allow (``True``) or deny (``False``) the action. + """ + self._policy_overrides.setdefault(agent_id, {})[action_type.value] = allowed + + # ── Internals ────────────────────────────────────────────────────────────── + + def _emit_decision( + self, + agent_id: str, + action_type: AgentActionType, + trust_level: str, + granted: bool, + source: str, + ) -> None: + logger.debug( + "AgentGovernor: agent=%r action=%r trust=%r → %s (source=%s)", + agent_id, + action_type.value, + trust_level, + "ALLOW" if granted else "DENY", + source, + ) diff --git a/gnat/agents/hitl.py b/gnat/agents/hitl.py new file mode 100644 index 00000000..f8dc12ef --- /dev/null +++ b/gnat/agents/hitl.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.agents.hitl +================ + +Human-in-the-Loop (HITL) gateway for agent actions. + +:class:`HITLGateway` is a thin bridge between :class:`AgentGovernor` and the +existing :class:`~gnat.review.service.ReviewService`. It implements a +four-tier impact model: + +* **low / medium** — auto-approved per policy; action is logged but not queued. +* **high** — submitted to :class:`~gnat.review.service.ReviewService` as a + ``PENDING`` :class:`~gnat.review.models.ReviewItem`; blocks the action until + approved or rejected. +* **critical** — same as ``high``, plus an XSOAR playbook is triggered via the + existing :class:`~gnat.connectors.xsoar.client.XSOARClient`. + +Usage +----- +:: + + from gnat.agents.hitl import HITLGateway + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gateway = HITLGateway(review_service=svc) + action = AgentAction( + agent_id="threat-hunter-1", + action_type=AgentActionType.TRIGGER_PLAYBOOK, + target_ref="indicator--abc", + impact_level="high", + ) + review_item = gateway.submit_for_approval(action) + status = gateway.check_approval_status(review_item.id) +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from gnat.agents.governor import AgentAction, IMPACT_LEVELS + +if TYPE_CHECKING: + from gnat.review.service import ReviewService + from gnat.review.models import ReviewItem, ReviewStatus + +logger = logging.getLogger(__name__) + +# Impact levels that require human review +_REVIEW_REQUIRED = frozenset({"high", "critical"}) + +# Timeout for pending approvals (seconds); actions auto-rejected after this +DEFAULT_APPROVAL_TIMEOUT_SECONDS = 3600 + + +class HITLGateway: + """ + Human-in-the-Loop gateway bridging :class:`AgentGovernor` to + :class:`~gnat.review.service.ReviewService`. + + Parameters + ---------- + review_service : ReviewService + The existing GNAT review queue service. + approval_timeout_seconds : int + Seconds before a pending approval is auto-rejected. + Defaults to 3600 (1 hour). + xsoar_client : object, optional + Pre-configured :class:`~gnat.connectors.xsoar.client.XSOARClient` + instance. When provided, critical actions trigger a SOAR playbook. + source_workspace : str + Workspace label used when submitting review items. + """ + + def __init__( + self, + review_service: "ReviewService", + approval_timeout_seconds: int = DEFAULT_APPROVAL_TIMEOUT_SECONDS, + xsoar_client: Any | None = None, + source_workspace: str = "agent-actions", + ) -> None: + """Initialize HITLGateway.""" + self._review_service = review_service + self._approval_timeout = approval_timeout_seconds + self._xsoar_client = xsoar_client + self._source_workspace = source_workspace + + # Track review_item_id → AgentAction for status polling + self._pending: dict[str, AgentAction] = {} + + # ── Public API ───────────────────────────────────────────────────────────── + + def evaluate(self, action: AgentAction) -> tuple[bool, "ReviewItem | None"]: + """ + Evaluate an agent action against the impact-tier policy. + + Returns ``(auto_approved, review_item)``. + + * ``(True, None)`` — low/medium impact; auto-approved, no review item. + * ``(False, ReviewItem)`` — high/critical; submitted for human review. + + Parameters + ---------- + action : AgentAction + The action to evaluate. + + Returns + ------- + tuple of (bool, ReviewItem or None) + ``(auto_approved, review_item)`` + """ + if action.impact_level not in _REVIEW_REQUIRED: + # Auto-approve low / medium + action.approved_by = "auto-policy" + action.status = "approved" + logger.info( + "HITLGateway: auto-approved %s action by agent %r (impact=%r)", + action.action_type.value, + action.agent_id, + action.impact_level, + ) + return True, None + + # High / critical → submit for human review + review_item = self.submit_for_approval(action) + return False, review_item + + def submit_for_approval(self, action: AgentAction) -> "ReviewItem": + """ + Submit *action* to the review queue and return the created + :class:`~gnat.review.models.ReviewItem`. + + For ``"critical"`` impact actions, an XSOAR playbook notification is + also triggered if an XSOAR client was provided. + + Parameters + ---------- + action : AgentAction + Action requiring human approval. + + Returns + ------- + ReviewItem + The newly created pending review item. + """ + stix_data = self._action_to_stix(action) + review_item = self._review_service.submit( + stix_data=stix_data, + source_workspace=self._source_workspace, + submitted_by=action.agent_id, + ) + action.status = "pending" + self._pending[review_item.id] = action + logger.info( + "HITLGateway: submitted review item %r for agent %r action %s (impact=%r)", + review_item.id, + action.agent_id, + action.action_type.value, + action.impact_level, + ) + + if action.impact_level == "critical": + self._notify_xsoar(action, review_item) + + return review_item + + def check_approval_status(self, review_id: str) -> "ReviewStatus": + """ + Return the current :class:`~gnat.review.models.ReviewStatus` for *review_id*. + + If the item has been pending longer than ``approval_timeout_seconds``, it + is automatically rejected. + + Parameters + ---------- + review_id : str + The ID of the review item to check. + + Returns + ------- + ReviewStatus + """ + from gnat.review.models import ReviewStatus + + review_item = self._review_service.get(review_id) + + # Check timeout for pending items + if review_item.status == ReviewStatus.PENDING: + submitted = review_item.submitted_at + if submitted.tzinfo is None: + submitted = submitted.replace(tzinfo=timezone.utc) + elapsed = (datetime.now(timezone.utc) - submitted).total_seconds() + if elapsed > self._approval_timeout: + logger.warning( + "HITLGateway: review %r timed out after %.0fs — auto-rejecting", + review_id, + elapsed, + ) + self._review_service.reject( + review_id, + reviewed_by="system-timeout", + reason=f"Approval timeout ({self._approval_timeout}s)", + ) + review_item = self._review_service.get(review_id) + + # Update tracked action + if review_id in self._pending: + self._pending[review_id].status = "rejected" + + return review_item.status + + def auto_approve_pending(self, review_id: str, reviewer: str = "auto-policy") -> None: + """ + Programmatically approve a pending review item (used in tests and + auto-escalation scenarios). + + Parameters + ---------- + review_id : str + ID of the item to approve. + reviewer : str + Name recorded as the approver. + """ + self._review_service.approve(review_id, reviewed_by=reviewer) + if review_id in self._pending: + self._pending[review_id].status = "approved" + self._pending[review_id].approved_by = reviewer + + # ── Internals ────────────────────────────────────────────────────────────── + + @staticmethod + def _action_to_stix(action: AgentAction) -> dict[str, Any]: + """Convert an :class:`AgentAction` to a minimal STIX-compatible dict.""" + import uuid as _uuid + + return { + "type": "x-gnat-agent-action", + "id": f"x-gnat-agent-action--{_uuid.uuid4()}", + "spec_version": "2.1", + "created": action.submitted_at.isoformat(), + "modified": action.submitted_at.isoformat(), + "agent_id": action.agent_id, + "action_type": action.action_type.value, + "target_ref": action.target_ref, + "impact_level": action.impact_level, + "session_id": action.session_id, + "context_id": action.context_id, + "action_id": action.action_id, + } + + def _notify_xsoar(self, action: AgentAction, review_item: Any) -> None: + """Fire an XSOAR playbook notification for critical actions.""" + if self._xsoar_client is None: + logger.debug( + "HITLGateway: no XSOAR client configured; skipping critical notification" + ) + return + try: + payload = { + "type": "note", + "id": f"note--{__import__('uuid').uuid4()}", + "spec_version": "2.1", + "created": datetime.now(timezone.utc).isoformat(), + "modified": datetime.now(timezone.utc).isoformat(), + "abstract": f"CRITICAL agent action pending approval: {action.action_type.value}", + "content": ( + f"Agent: {action.agent_id}\n" + f"Action: {action.action_type.value}\n" + f"Target: {action.target_ref}\n" + f"Review ID: {review_item.id}\n" + f"Impact: {action.impact_level}" + ), + "object_refs": [action.target_ref] if action.target_ref else [], + } + self._xsoar_client.upsert_object(payload) + logger.info( + "HITLGateway: XSOAR notified for critical action by agent %r", + action.agent_id, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("HITLGateway: XSOAR notification failed — %s", exc) diff --git a/gnat/clients/base.py b/gnat/clients/base.py index df28d862..803490fd 100644 --- a/gnat/clients/base.py +++ b/gnat/clients/base.py @@ -56,6 +56,34 @@ def __init__(self, message: str, status: int = 0, body: str = ""): self.body = body +class BudgetExceeded(GNATClientError): + """ + Raised when an :class:`~gnat.core.context.ExecutionContext` query budget + is exhausted before the requested operation can complete. + + Attributes + ---------- + connector : str + Name of the connector that triggered the budget check. + cost : int + Cost units the connector attempted to consume. + remaining : int + Budget units remaining when the check failed (always 0 or negative). + """ + + def __init__(self, connector: str, cost: int, remaining: int) -> None: + """Initialize BudgetExceeded.""" + super().__init__( + f"Query budget exhausted: {connector!r} requires {cost} units " + f"but only {remaining} remaining.", + status=0, + body="", + ) + self.connector = connector + self.cost = cost + self.remaining = remaining + + class BaseClient: """ urllib3-backed HTTP client base class for all GNAT connectors. @@ -79,8 +107,27 @@ class BaseClient: _auth_headers : dict Headers injected into every request after :meth:`authenticate` runs. Subclasses should populate this during authentication. + + Class Variables + --------------- + TRUST_LEVEL : str + Trust classification for this connector. Set explicitly on each + subclass. Valid values: ``"trusted_internal"``, ``"semi_trusted"``, + ``"untrusted_external"``. Defaults to ``"semi_trusted"``. + API_VERSION : str + API version string (e.g. ``"v2"``). Empty string means unversioned. + API_PREFIX : str + URL path prefix used by this connector's API version (e.g. ``"/v3"``). + COST_UNIT : int + Relative cost weight for budget accounting. Single-object lookups + use 1; bulk pulls use 10; search operations use 5. """ + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "" + API_PREFIX: str = "" + COST_UNIT: int = 1 + def __init__( self, host: str, @@ -97,6 +144,8 @@ def __init__( self.config = config or {} self._auth_headers: dict[str, str] = {} self._authenticated = False + # Optional ExecutionContext for budget tracking (set by callers) + self._context: Any = None retry = Retry( total=max_retries, @@ -232,6 +281,12 @@ def _request( self.authenticate() self._authenticated = True + # Deduct from query budget if one is attached via ExecutionContext + if self._context is not None: + budget = getattr(self._context, "budget", None) + if budget is not None: + budget.consume(self.COST_UNIT, type(self).__name__) + url = urljoin(self.host + "/", path.lstrip("/")) if params: url = f"{url}?{urlencode(params, doseq=True)}" diff --git a/gnat/connectors/CISA/client.py b/gnat/connectors/CISA/client.py index 7aefe6ee..dbba8113 100644 --- a/gnat/connectors/CISA/client.py +++ b/gnat/connectors/CISA/client.py @@ -65,6 +65,12 @@ class CISAClient(BaseClient, ConnectorMixin): host : str Base URL, default ``"https://www.cisa.gov"``. """ + TRUST_LEVEL: str = "untrusted_external" + API_VERSION: str = "v1" + API_PREFIX: str = "" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "vulnerability": "kev", diff --git a/gnat/connectors/alienvault/client.py b/gnat/connectors/alienvault/client.py index fdb5a2ea..67c9a0c7 100644 --- a/gnat/connectors/alienvault/client.py +++ b/gnat/connectors/alienvault/client.py @@ -105,6 +105,12 @@ class AlienVaultClient(BaseClient, ConnectorMixin): api_key : str OTX API key. """ + TRUST_LEVEL: str = "untrusted_external" + API_VERSION: str = "v1" + API_PREFIX: str = "/api/v1" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "indicators", diff --git a/gnat/connectors/crowdstrike/client.py b/gnat/connectors/crowdstrike/client.py index 4a37c529..c7d79b0e 100644 --- a/gnat/connectors/crowdstrike/client.py +++ b/gnat/connectors/crowdstrike/client.py @@ -23,6 +23,12 @@ class CrowdStrikeClient(BaseClient, ConnectorMixin): """HTTP client for the CrowdStrike Falcon REST API.""" + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v1" + API_PREFIX: str = "/oauth2" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "iocs", diff --git a/gnat/connectors/elastic/client.py b/gnat/connectors/elastic/client.py index ca7ba50a..b48acb7c 100644 --- a/gnat/connectors/elastic/client.py +++ b/gnat/connectors/elastic/client.py @@ -82,6 +82,12 @@ class ElasticClient: config : ElasticConfig Validated connector configuration. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "" + COST_UNIT: int = 10 + + def __init__(self, config: ElasticConfig) -> None: """Initialize ElasticClient.""" diff --git a/gnat/connectors/feedly/client.py b/gnat/connectors/feedly/client.py index 35338e27..f03e9f05 100644 --- a/gnat/connectors/feedly/client.py +++ b/gnat/connectors/feedly/client.py @@ -70,6 +70,12 @@ class FeedlyClient(BaseClient, ConnectorMixin): api_token : str Feedly access token. """ + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v3" + API_PREFIX: str = "/v3" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "iocFeed", diff --git a/gnat/connectors/graylog/client.py b/gnat/connectors/graylog/client.py index 565793fa..67413cc5 100644 --- a/gnat/connectors/graylog/client.py +++ b/gnat/connectors/graylog/client.py @@ -74,6 +74,12 @@ class GraylogClient(BaseClient, ConnectorMixin): password : str Graylog password. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "observed-data": "search", diff --git a/gnat/connectors/misp/client.py b/gnat/connectors/misp/client.py index a3a35de4..127384c2 100644 --- a/gnat/connectors/misp/client.py +++ b/gnat/connectors/misp/client.py @@ -62,6 +62,12 @@ class MISPClient: """urllib3-based HTTP client for the MISP REST API.""" + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v1" + API_PREFIX: str = "/" + COST_UNIT: int = 1 + + def __init__(self, config: MISPConfig) -> None: """Initialize MISPClient.""" diff --git a/gnat/connectors/qradar/client.py b/gnat/connectors/qradar/client.py index b0e8e692..ca77f0c3 100644 --- a/gnat/connectors/qradar/client.py +++ b/gnat/connectors/qradar/client.py @@ -96,6 +96,12 @@ class QRadarClient: config : QRadarConfig Validated connector configuration. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + def __init__(self, config: QRadarConfig) -> None: """Initialize QRadarClient.""" diff --git a/gnat/connectors/recordedfuture/client.py b/gnat/connectors/recordedfuture/client.py index 7331417a..d8cd7a7e 100644 --- a/gnat/connectors/recordedfuture/client.py +++ b/gnat/connectors/recordedfuture/client.py @@ -23,6 +23,11 @@ class RecordedFutureClient(BaseClient, ConnectorMixin): """HTTP client for the Recorded Future Connect API v2.""" + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v2" + API_PREFIX: str = "/v2" + COST_UNIT: int = 1 + stix_type_map: dict[str, str] = { "indicator": "ip", "malware": "malware", diff --git a/gnat/connectors/recordedfuture/rfv3.py b/gnat/connectors/recordedfuture/rfv3.py index 8b93c61a..e3ee0864 100644 --- a/gnat/connectors/recordedfuture/rfv3.py +++ b/gnat/connectors/recordedfuture/rfv3.py @@ -34,8 +34,10 @@ class RecordedFutureClientV3(RecordedFutureBase): """Recorded Future Connect API v3 client.""" + TRUST_LEVEL: str = "semi_trusted" API_VERSION = "v3" API_PREFIX = "/v3" + COST_UNIT: int = 1 # ------------------------------------------------------------------ # Alert API v3 (cursor-paginated; overrides v2 offset behaviour) diff --git a/gnat/connectors/security_onion/client.py b/gnat/connectors/security_onion/client.py index 9462eaec..ced72b01 100644 --- a/gnat/connectors/security_onion/client.py +++ b/gnat/connectors/security_onion/client.py @@ -78,6 +78,12 @@ class SecurityOnionClient(BaseClient, ConnectorMixin): password : str Security Onion password. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "observed-data": "alerts", diff --git a/gnat/connectors/sentinel/client.py b/gnat/connectors/sentinel/client.py index d547e410..a2048561 100644 --- a/gnat/connectors/sentinel/client.py +++ b/gnat/connectors/sentinel/client.py @@ -57,6 +57,12 @@ class SentinelClient: """urllib3-based HTTP client for the Microsoft Sentinel REST API.""" + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + def __init__(self, config: SentinelConfig) -> None: """Initialize SentinelClient.""" diff --git a/gnat/connectors/shadowserver/client.py b/gnat/connectors/shadowserver/client.py index e57454ce..725d78dc 100644 --- a/gnat/connectors/shadowserver/client.py +++ b/gnat/connectors/shadowserver/client.py @@ -57,6 +57,12 @@ class ShadowServerClient(BaseClient, ConnectorMixin): api_secret : str Shadowserver API secret for HMAC signing. """ + TRUST_LEVEL: str = "untrusted_external" + API_VERSION: str = "v1" + API_PREFIX: str = "/api-query" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "ip", diff --git a/gnat/connectors/splunk/client.py b/gnat/connectors/splunk/client.py index c9d52c6c..1336a289 100644 --- a/gnat/connectors/splunk/client.py +++ b/gnat/connectors/splunk/client.py @@ -85,6 +85,12 @@ class SplunkClient(BaseClient, ConnectorMixin): config : SplunkConfig, optional Fully-constructed config object (alternative to keyword args). """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v2" + API_PREFIX: str = "/services" + COST_UNIT: int = 10 + + def __init__( self, diff --git a/gnat/connectors/threatq/client.py b/gnat/connectors/threatq/client.py index af1bcbc1..a6783b11 100644 --- a/gnat/connectors/threatq/client.py +++ b/gnat/connectors/threatq/client.py @@ -111,6 +111,12 @@ class ThreatQClient(BaseClient, ConnectorMixin): **kwargs Forwarded to :class:`~gnat.clients.base.BaseClient`. """ + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "indicator", diff --git a/gnat/connectors/virustotal/client.py b/gnat/connectors/virustotal/client.py index 7cd0c919..1a853f50 100644 --- a/gnat/connectors/virustotal/client.py +++ b/gnat/connectors/virustotal/client.py @@ -52,6 +52,12 @@ class VirusTotalClient(BaseClient, ConnectorMixin): api_key : str VirusTotal API key. """ + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "v3" + API_PREFIX: str = "/api/v3" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "files", diff --git a/gnat/connectors/wazuh/client.py b/gnat/connectors/wazuh/client.py index df393412..f4c56ee2 100644 --- a/gnat/connectors/wazuh/client.py +++ b/gnat/connectors/wazuh/client.py @@ -84,6 +84,12 @@ class WazuhClient: config : WazuhConfig Validated connector configuration. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/api" + COST_UNIT: int = 1 + + def __init__(self, config: WazuhConfig) -> None: """Initialize WazuhClient.""" diff --git a/gnat/connectors/xsoar/client.py b/gnat/connectors/xsoar/client.py index 9629af2c..41b3fd33 100644 --- a/gnat/connectors/xsoar/client.py +++ b/gnat/connectors/xsoar/client.py @@ -92,6 +92,12 @@ class XSOARClient(BaseClient, ConnectorMixin): verify_ssl : bool TLS certificate verification. Default ``True``. """ + TRUST_LEVEL: str = "trusted_internal" + API_VERSION: str = "v1" + API_PREFIX: str = "/xsoar" + COST_UNIT: int = 1 + + stix_type_map: dict[str, str] = { "indicator": "indicator", diff --git a/gnat/context/store.py b/gnat/context/store.py index bcfa4eb3..3314c9f9 100644 --- a/gnat/context/store.py +++ b/gnat/context/store.py @@ -56,6 +56,7 @@ from __future__ import annotations +import hashlib import json import logging from datetime import datetime, timezone @@ -131,6 +132,10 @@ class WorkspaceModel(_Base): ) # Arbitrary JSON metadata (analyst notes, tags, config overrides, etc.) metadata_json = Column(Text, nullable=True, default="{}") + # Trust boundary for workspace isolation (4E-1): minimum connector trust level + trust_boundary = Column(String(50), nullable=True, default="semi_trusted") + # JSON list of connector class names allowed to read/write this workspace + allowed_connector_refs = Column(Text, nullable=True, default="[]") objects = relationship( "WorkspaceObjectModel", @@ -188,6 +193,8 @@ class WorkspaceObjectModel(_Base): is_dirty = Column(Boolean, default=False, nullable=False) # Soft-delete — marks objects removed from workspace without DB purge is_deleted = Column(Boolean, default=False, nullable=False) + # Idempotency key for replay-safe writes: (connector_id, stix_type, external_id, content_hash) + idempotency_key = Column(String(255), nullable=True, unique=True) workspace = relationship("WorkspaceModel", back_populates="objects") @@ -381,16 +388,75 @@ def delete_workspace(self, name: str) -> bool: # ── Object CRUD ──────────────────────────────────────────────────────── + @staticmethod + def make_idempotency_key( + connector_id: str, + stix_type: str, + external_id: str, + stix_dict: dict, + ) -> str: + """ + Compute a stable idempotency key for a STIX object write. + + Key format: ``{connector_id}:{stix_type}:{external_id}:{content_hash}`` + + The content hash is a SHA-1 of the serialised stix_dict so that + re-ingesting the identical payload produces the same key (collision + safe for idempotency purposes; not for security). + """ + content_hash = hashlib.sha1( # noqa: S324 # nosec — not security use + json.dumps(stix_dict, sort_keys=True).encode() + ).hexdigest()[:12] + return f"{connector_id}:{stix_type}:{external_id}:{content_hash}" + def upsert_object( - self, workspace_id: int, stix_dict: dict, source_platform: str = "", is_dirty: bool = False + self, + workspace_id: int, + stix_dict: dict, + source_platform: str = "", + is_dirty: bool = False, + idempotency_key: str | None = None, ) -> WorkspaceObjectModel: """ Insert or update a STIX object in a workspace. If an object with the same ``workspace_id`` + ``stix_id`` already exists it is updated in-place; otherwise a new row is inserted. + + When *idempotency_key* is provided, the write is skipped silently if + an existing row already carries that key (indicating the same content + was already ingested from the same connector). This makes pipeline + replay safe — re-running produces identical storage state. + + Parameters + ---------- + workspace_id : int + DB id of the target workspace. + stix_dict : dict + Full STIX 2.1 object dict; must contain ``"id"`` and ``"type"``. + source_platform : str + Originating platform/connector name. + is_dirty : bool + Mark object as modified (needs push-back to platform). + idempotency_key : str, optional + Pre-computed idempotency key. Use :meth:`make_idempotency_key` + to generate. When ``None`` no idempotency check is performed. """ with self.session() as sess: + # Idempotency check: skip if exact same content already stored + if idempotency_key: + dup = ( + sess.query(WorkspaceObjectModel) + .filter_by(idempotency_key=idempotency_key) + .first() + ) + if dup is not None: + logger.debug( + "upsert_object: idempotency hit for key %s — skipping write", + idempotency_key, + ) + return dup + existing = ( sess.query(WorkspaceObjectModel) .filter_by(workspace_id=workspace_id, stix_id=stix_dict["id"]) @@ -404,6 +470,8 @@ def upsert_object( existing.is_deleted = False if source_platform: existing.source_platform = source_platform + if idempotency_key: + existing.idempotency_key = idempotency_key sess.commit() return existing @@ -415,6 +483,7 @@ def upsert_object( stix_json=json.dumps(stix_dict), source_platform=source_platform, is_dirty=is_dirty, + idempotency_key=idempotency_key, ) sess.add(obj) sess.commit() diff --git a/gnat/context/workspace.py b/gnat/context/workspace.py index 5131cb82..d82ce3c4 100644 --- a/gnat/context/workspace.py +++ b/gnat/context/workspace.py @@ -142,6 +142,10 @@ def __init__( self._snapshot: dict[str, dict] = {} # Workspace DB id (WorkspaceStore only) self._ws_id: int | None = None + # Trust boundary — minimum connector trust level required (4E-1) + self.trust_boundary: str = "semi_trusted" + # Explicit connector allowlist (empty = no restriction) + self.allowed_connector_refs: list[str] = [] self._init_store() @@ -154,6 +158,14 @@ def _init_store(self) -> None: if isinstance(self._store, WorkspaceStore): ws_model = self._store.get_or_create_workspace(self.name, description=self.description) self._ws_id = ws_model.id + # Load trust isolation settings (4E-1) + self.trust_boundary = getattr(ws_model, "trust_boundary", None) or "semi_trusted" + import json as _json + raw_refs = getattr(ws_model, "allowed_connector_refs", None) or "[]" + try: + self.allowed_connector_refs = _json.loads(raw_refs) + except (ValueError, TypeError): + self.allowed_connector_refs = [] # Re-hydrate in-memory cache from persisted objects for stix_dict in self._store.get_objects(self._ws_id): obj = self._from_dict(stix_dict) @@ -702,6 +714,44 @@ def _mark_clean(self) -> None: if isinstance(self._store, WorkspaceStore): self._store.mark_clean(self._ws_id) + # ── Trust boundary enforcement (4E-1) ────────────────────────────────────── + + def check_connector_trust(self, connector: Any) -> None: + """ + Verify that *connector* satisfies this workspace's trust boundary. + + Raises ``PermissionError`` if the connector's ``TRUST_LEVEL`` is + below the workspace ``trust_boundary``, or if the connector is not + in the ``allowed_connector_refs`` allowlist (when the list is non-empty). + + Parameters + ---------- + connector : object + Any :class:`~gnat.clients.base.BaseClient` subclass instance. + """ + _TRUST_ORDER = {"trusted_internal": 2, "semi_trusted": 1, "untrusted_external": 0} + connector_trust = getattr(type(connector), "TRUST_LEVEL", "semi_trusted") + required_trust = self.trust_boundary or "semi_trusted" + + connector_rank = _TRUST_ORDER.get(connector_trust, 0) + required_rank = _TRUST_ORDER.get(required_trust, 1) + + if connector_rank < required_rank: + connector_name = type(connector).__name__ + raise PermissionError( + f"Connector {connector_name!r} (trust={connector_trust!r}) " + f"does not meet workspace {self.name!r} trust boundary " + f"({required_trust!r} required). Access denied." + ) + + if self.allowed_connector_refs: + connector_name = type(connector).__name__ + if connector_name not in self.allowed_connector_refs: + raise PermissionError( + f"Connector {connector_name!r} is not in the allowed connector " + f"list for workspace {self.name!r}." + ) + def _resolve_source(self, name: str | None) -> GlobalContext: """Internal helper for resolve source.""" if name: diff --git a/gnat/core/__init__.py b/gnat/core/__init__.py new file mode 100644 index 00000000..97ce847b --- /dev/null +++ b/gnat/core/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.core +========= + +Cross-cutting infrastructure for GNAT Phase 4: execution tracing, +domain boundary enforcement, and trust-level control. +""" + +from gnat.core.context import ExecutionContext, QueryBudget +from gnat.core.domains import ( + Domain, + DomainBoundaryViolation, + domain_boundary, + require_trust_level, +) + +__all__ = [ + "ExecutionContext", + "QueryBudget", + "Domain", + "DomainBoundaryViolation", + "domain_boundary", + "require_trust_level", +] diff --git a/gnat/core/context.py b/gnat/core/context.py new file mode 100644 index 00000000..c1c53f35 --- /dev/null +++ b/gnat/core/context.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.core.context +================= + +Unified execution context that every GNAT operation carries. + +Every pipeline run, enrichment call, connector request, and agent action +is tagged with an :class:`ExecutionContext`. This gives GNAT end-to-end +traceability: you can reconstruct exactly which connector, in which domain, +under which trust level, produced any stored object. + +Usage +----- +:: + + from gnat.core.context import ExecutionContext + + ctx = ExecutionContext.create( + initiated_by="splunk", + domain="ingestion", + workspace_id="ws-threats-2026", + ) + # Pass ctx to pipeline entry points; it propagates automatically. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Query Budget (4E-2) +# --------------------------------------------------------------------------- + +@dataclass +class QueryBudget: + """ + Finite query budget carried by an :class:`ExecutionContext`. + + Each :class:`~gnat.clients.base.BaseClient` call deducts + :attr:`~gnat.clients.base.BaseClient.COST_UNIT` units from the budget. + When the budget reaches zero a + :class:`~gnat.clients.base.BudgetExceeded` exception is raised. + + Parameters + ---------- + max_units : int + Maximum total cost units for this execution. Default ``1000``. + """ + + max_units: int = 1000 + _consumed: int = field(default=0, init=False, repr=False) + + @property + def remaining(self) -> int: + """Remaining budget units.""" + return max(0, self.max_units - self._consumed) + + @property + def is_exhausted(self) -> bool: + """True when budget is fully consumed.""" + return self._consumed >= self.max_units + + def consume(self, units: int, connector: str = "") -> None: + """ + Deduct *units* from the budget. + + Parameters + ---------- + units : int + Cost units to consume. + connector : str + Name of the consuming connector (for the error message). + + Raises + ------ + gnat.clients.base.BudgetExceeded + If the budget would be exceeded. + """ + if self._consumed + units > self.max_units: + from gnat.clients.base import BudgetExceeded + raise BudgetExceeded( + connector=connector, + cost=units, + remaining=self.remaining, + ) + self._consumed += units + logger.debug( + "QueryBudget: consumed %d units by %r; remaining=%d/%d", + units, connector, self.remaining, self.max_units, + ) + +# Valid trust levels (mirrors BaseClient.TRUST_LEVEL values) +TRUST_LEVELS = frozenset({"trusted_internal", "semi_trusted", "untrusted_external"}) + +# Valid domain names (mirrors gnat.core.domains.Domain) +VALID_DOMAINS = frozenset({"ingestion", "analysis", "investigation", "reporting", "execution"}) + + +@dataclass +class ExecutionContext: + """ + Unified execution context propagated through all GNAT operations. + + Parameters + ---------- + context_id : str + UUID identifying this specific execution trace. + initiated_by : str + Connector name, agent ID, or ``"manual"`` for human-triggered runs. + domain : str + Operational domain: ``"ingestion"``, ``"analysis"``, ``"investigation"``, + ``"reporting"``, or ``"execution"``. + trust_level : str + Trust classification inherited from the initiating connector: + ``"trusted_internal"``, ``"semi_trusted"``, or ``"untrusted_external"``. + policy_set : str + Name of the active policy set (from ``[agent_policy]`` INI section). + workspace_id : str + Workspace isolation boundary for this operation. + created_at : datetime + UTC timestamp when this context was created. + parent_context_id : str, optional + UUID of the parent context for sub-operations; enables trace trees. + is_replay : bool + ``True`` if this execution is replaying a previously recorded run. + Pipeline runners suppress side-effects (SOAR triggers, etc.) when set. + """ + + context_id: str + initiated_by: str + domain: str + trust_level: str + policy_set: str + workspace_id: str + created_at: datetime + parent_context_id: str | None = None + is_replay: bool = False + budget: QueryBudget | None = None + + # ── Factory ──────────────────────────────────────────────────────────────── + + @classmethod + def create( + cls, + initiated_by: str, + domain: str, + workspace_id: str, + trust_level: str = "semi_trusted", + policy_set: str = "default", + parent_context_id: str | None = None, + is_replay: bool = False, + max_budget_units: int | None = None, + ) -> ExecutionContext: + """ + Create a new :class:`ExecutionContext` with a fresh UUID and UTC timestamp. + + Parameters + ---------- + initiated_by : str + Connector name, agent ID, or ``"manual"``. + domain : str + One of: ``"ingestion"``, ``"analysis"``, ``"investigation"``, + ``"reporting"``, ``"execution"``. + workspace_id : str + Target workspace identifier. + trust_level : str + Trust classification. Defaults to ``"semi_trusted"``. + policy_set : str + Active policy set name. Defaults to ``"default"``. + parent_context_id : str, optional + UUID of the parent context for nested operations. + is_replay : bool + Mark this as a replay run. + + Returns + ------- + ExecutionContext + """ + budget = QueryBudget(max_units=max_budget_units) if max_budget_units is not None else None + return cls( + context_id=str(uuid.uuid4()), + initiated_by=initiated_by, + domain=domain, + trust_level=trust_level, + policy_set=policy_set, + workspace_id=workspace_id, + created_at=datetime.now(timezone.utc), + parent_context_id=parent_context_id, + is_replay=is_replay, + budget=budget, + ) + + @classmethod + def from_connector( + cls, + connector: Any, + domain: str, + workspace_id: str, + policy_set: str = "default", + parent_context_id: str | None = None, + is_replay: bool = False, + ) -> ExecutionContext: + """ + Create an :class:`ExecutionContext` inheriting trust from a connector. + + The connector's ``TRUST_LEVEL`` class variable is read automatically. + Falls back to ``"semi_trusted"`` if the connector has no ``TRUST_LEVEL``. + + Parameters + ---------- + connector : object + Any :class:`~gnat.clients.base.BaseClient` subclass instance. + domain : str + Operational domain for this execution. + workspace_id : str + Target workspace. + """ + trust_level = getattr(type(connector), "TRUST_LEVEL", "semi_trusted") + initiated_by = type(connector).__name__ + return cls.create( + initiated_by=initiated_by, + domain=domain, + workspace_id=workspace_id, + trust_level=trust_level, + policy_set=policy_set, + parent_context_id=parent_context_id, + is_replay=is_replay, + ) + + # ── Serialization ────────────────────────────────────────────────────────── + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict suitable for Postgres insertion.""" + return { + "context_id": self.context_id, + "initiated_by": self.initiated_by, + "domain": self.domain, + "trust_level": self.trust_level, + "policy_set": self.policy_set, + "workspace_id": self.workspace_id, + "created_at": self.created_at.isoformat(), + "parent_context_id": self.parent_context_id, + "is_replay": self.is_replay, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ExecutionContext: + """Deserialize from a plain dict (e.g. from a DB row).""" + created_at = data["created_at"] + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + return cls( + context_id=data["context_id"], + initiated_by=data["initiated_by"], + domain=data["domain"], + trust_level=data["trust_level"], + policy_set=data.get("policy_set", "default"), + workspace_id=data["workspace_id"], + created_at=created_at, + parent_context_id=data.get("parent_context_id"), + is_replay=bool(data.get("is_replay", False)), + ) + + # ── Helpers ──────────────────────────────────────────────────────────────── + + def child(self, initiated_by: str, domain: str | None = None) -> ExecutionContext: + """ + Create a child context for a sub-operation. + + Inherits ``workspace_id``, ``trust_level``, and ``policy_set`` from the + parent. ``parent_context_id`` is set to this context's ``context_id``. + """ + return ExecutionContext.create( + initiated_by=initiated_by, + domain=domain or self.domain, + workspace_id=self.workspace_id, + trust_level=self.trust_level, + policy_set=self.policy_set, + parent_context_id=self.context_id, + is_replay=self.is_replay, + ) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"ExecutionContext(id={self.context_id[:8]}…, " + f"by={self.initiated_by!r}, domain={self.domain!r}, " + f"trust={self.trust_level!r})" + ) diff --git a/gnat/core/domains.py b/gnat/core/domains.py new file mode 100644 index 00000000..7b160126 --- /dev/null +++ b/gnat/core/domains.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.core.domains +================= + +Operational domain model and cross-domain boundary enforcement. + +Five domains partition GNAT operations: + +.. code-block:: text + + ingestion — connector pulls, normalisation, raw data ingest + analysis — enrichment, correlation, STIX assembly + investigation— hypothesis testing, evidence linking (read-only from ingestion) + reporting — export, visualization, alerting + execution — SOAR actions, automated response + +Domain boundaries are enforced via the :func:`domain_boundary` decorator. +Any call that violates the allowed-caller graph raises +:class:`DomainBoundaryViolation`, giving a clear, logged signal rather than +a silent constraint failure. + +Usage +----- +:: + + from gnat.core.domains import Domain, domain_boundary + + @domain_boundary(allowed_callers=[Domain.INGESTION, Domain.ANALYSIS]) + def run_enrichment(observable, context): + ... + +Trust-level enforcement +----------------------- +:: + + from gnat.core.domains import require_trust_level + + @require_trust_level("trusted_internal") + def trigger_soar_playbook(playbook_id, context): + ... +""" + +from __future__ import annotations + +import functools +import logging +import threading +from enum import Enum +from typing import Any + +from gnat.clients.base import GNATClientError + +logger = logging.getLogger(__name__) + +# Thread-local storage for the active domain stack +_domain_stack = threading.local() + + +class Domain(str, Enum): + """Enumeration of GNAT operational domains.""" + + INGESTION = "ingestion" + ANALYSIS = "analysis" + INVESTIGATION = "investigation" + REPORTING = "reporting" + EXECUTION = "execution" + + +# Permitted caller domains for each target domain. +# A domain may only be entered from the domains listed here. +DOMAIN_CALL_RULES: dict[Domain, frozenset[Domain]] = { + Domain.INGESTION: frozenset({Domain.INGESTION}), + Domain.ANALYSIS: frozenset({Domain.INGESTION, Domain.ANALYSIS}), + Domain.INVESTIGATION: frozenset({Domain.ANALYSIS, Domain.INVESTIGATION}), + Domain.REPORTING: frozenset({Domain.INVESTIGATION, Domain.REPORTING}), + Domain.EXECUTION: frozenset({Domain.INVESTIGATION, Domain.EXECUTION}), +} + +# Trust levels ordered from least to most privileged +_TRUST_ORDER = ["untrusted_external", "semi_trusted", "trusted_internal"] + + +class DomainBoundaryViolation(GNATClientError): + """ + Raised when an operation attempts an illegal cross-domain call. + + Attributes + ---------- + caller_domain : str + The domain that attempted the cross-boundary call. + target_domain : str + The domain that was called illegally. + """ + + def __init__( + self, + caller_domain: str, + target_domain: str, + detail: str = "", + ) -> None: + """Initialize DomainBoundaryViolation.""" + self.caller_domain = caller_domain + self.target_domain = target_domain + msg = ( + f"Domain boundary violation: {caller_domain!r} cannot call into " + f"{target_domain!r} domain." + ) + if detail: + msg = f"{msg} {detail}" + super().__init__(msg) + + +class TrustLevelViolation(GNATClientError): + """ + Raised when an operation requires a higher trust level than present. + + Attributes + ---------- + required : str + The required minimum trust level. + actual : str + The trust level of the active context. + """ + + def __init__(self, required: str, actual: str) -> None: + """Initialize TrustLevelViolation.""" + self.required = required + self.actual = actual + super().__init__( + f"Trust level violation: operation requires {required!r} " + f"but active trust level is {actual!r}." + ) + + +# ── Active domain stack helpers ──────────────────────────────────────────────── + +def _get_domain_stack() -> list[Domain]: + """Return the thread-local domain stack, initialising if absent.""" + if not hasattr(_domain_stack, "stack"): + _domain_stack.stack = [] + return _domain_stack.stack # type: ignore[return-value] + + +def _current_domain() -> Domain | None: + """Return the domain at the top of the thread-local stack, or None.""" + stack = _get_domain_stack() + return stack[-1] if stack else None + + +# ── Decorators ───────────────────────────────────────────────────────────────── + +def domain_boundary(target_domain: Domain, allowed_callers: list[Domain] | None = None): + """ + Decorator that enforces domain boundary rules on the wrapped function. + + The decorated function is tagged with *target_domain*. When called, the + decorator checks that the currently active domain (from the thread-local + stack) is in the set of allowed callers. If not, + :class:`DomainBoundaryViolation` is raised before the function executes. + + If no caller domain is active (i.e. this is a top-level call), the call + is permitted — external code can always enter any domain at the top level. + + Parameters + ---------- + target_domain : Domain + The domain this function belongs to. + allowed_callers : list of Domain, optional + Domains permitted to call this function. Defaults to the global + :data:`DOMAIN_CALL_RULES` for *target_domain*. + + Examples + -------- + :: + + @domain_boundary(Domain.INGESTION) + def run_ingest_pipeline(source, context): + ... + + @domain_boundary(Domain.REPORTING, allowed_callers=[Domain.INVESTIGATION]) + def generate_report(workspace, context): + ... + """ + effective_callers = ( + frozenset(allowed_callers) + if allowed_callers is not None + else DOMAIN_CALL_RULES.get(target_domain, frozenset()) + ) + + def decorator(func): # type: ignore[no-untyped-def] + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + caller = _current_domain() + if caller is not None and caller not in effective_callers: + logger.warning( + "DomainBoundaryViolation: %r -> %r in %s", + caller.value, + target_domain.value, + func.__qualname__, + ) + raise DomainBoundaryViolation( + caller_domain=caller.value, + target_domain=target_domain.value, + detail=f"(function: {func.__qualname__})", + ) + + stack = _get_domain_stack() + stack.append(target_domain) + try: + return func(*args, **kwargs) + finally: + stack.pop() + + wrapper._domain = target_domain # type: ignore[attr-defined] + return wrapper + + return decorator + + +def require_trust_level(minimum: str): + """ + Decorator that enforces a minimum trust level on the wrapped function. + + The *minimum* trust level is compared to the ``trust_level`` attribute of + the first argument named ``context`` (or the first positional arg if no + ``context`` kwarg is present). Falls back to no-op if no context is found. + + Parameters + ---------- + minimum : str + Minimum required trust level: ``"trusted_internal"``, + ``"semi_trusted"``, or ``"untrusted_external"``. + + Raises + ------ + TrustLevelViolation + If the active context's trust level is below *minimum*. + + Examples + -------- + :: + + @require_trust_level("trusted_internal") + def trigger_soar_playbook(playbook_id, context): + ... + """ + min_rank = _TRUST_ORDER.index(minimum) if minimum in _TRUST_ORDER else 0 + + def decorator(func): # type: ignore[no-untyped-def] + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Try to find context in kwargs first, then positional args + ctx = kwargs.get("context") + if ctx is None: + for arg in args: + if hasattr(arg, "trust_level"): + ctx = arg + break + + if ctx is not None: + actual = getattr(ctx, "trust_level", "untrusted_external") + actual_rank = ( + _TRUST_ORDER.index(actual) if actual in _TRUST_ORDER else 0 + ) + if actual_rank < min_rank: + logger.warning( + "TrustLevelViolation in %s: required=%r actual=%r", + func.__qualname__, + minimum, + actual, + ) + raise TrustLevelViolation(required=minimum, actual=actual) + + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/gnat/orm/base.py b/gnat/orm/base.py index 8090c7ee..cea824bf 100644 --- a/gnat/orm/base.py +++ b/gnat/orm/base.py @@ -72,6 +72,9 @@ class STIXBase: stix_type: str = "stix-object" + # Schema version for GNAT ORM evolution (increment on breaking changes) + schema_version: int = 1 + def __init__( self, client: Optional["GNATClient"] = None, diff --git a/gnat/policy/__init__.py b/gnat/policy/__init__.py index 0fe86d4d..5b798f08 100644 --- a/gnat/policy/__init__.py +++ b/gnat/policy/__init__.py @@ -42,9 +42,11 @@ def create(key=Depends(engine.require(Permission.WRITE_INVESTIGATIONS, from gnat.policy.engine import PolicyEngine from gnat.policy.middleware import build_audit_middleware from gnat.policy.models import ( + AgentActionType, Permission, Role, ROLE_PERMISSIONS, + agent_can_act, permissions_for, roles_with, ) @@ -52,9 +54,11 @@ def create(key=Depends(engine.require(Permission.WRITE_INVESTIGATIONS, __all__ = [ "PolicyEngine", "build_audit_middleware", + "AgentActionType", "Permission", "Role", "ROLE_PERMISSIONS", + "agent_can_act", "permissions_for", "roles_with", ] diff --git a/gnat/policy/models.py b/gnat/policy/models.py index 1455ed92..9127c6da 100644 --- a/gnat/policy/models.py +++ b/gnat/policy/models.py @@ -102,3 +102,46 @@ def permissions_for(role: Role) -> set[Permission]: def roles_with(permission: Permission) -> list[Role]: """Return all roles that have *permission*.""" return [r for r, perms in ROLE_PERMISSIONS.items() if permission in perms] + + +# --------------------------------------------------------------------------- +# Agent permission matrix +# --------------------------------------------------------------------------- + +class AgentActionType(str, Enum): + """Action types that agents can request permission to perform.""" + + READ_STIX = "read_stix" # Read STIX objects from workspace + WRITE_STIX = "write_stix" # Create/modify STIX objects + DELETE_STIX = "delete_stix" # Remove STIX objects + ENRICH = "enrich" # Call connector enrichment + INGEST = "ingest" # Pull data from external sources + EXPORT = "export" # Push data to external systems + TRIGGER_PLAYBOOK = "trigger_playbook" # Fire SOAR playbook + MANAGE_WORKSPACE = "manage_workspace" # Create/delete workspaces + ESCALATE = "escalate" # Raise impact level / request human review + HYPOTHESIZE = "hypothesize" # Propose or evaluate hypotheses + + +# Default permission matrix: which action types are permitted per trust level +_TRUST_ACTION_PERMISSIONS: dict[str, set[AgentActionType]] = { + "trusted_internal": set(AgentActionType), # all actions + "semi_trusted": { + AgentActionType.READ_STIX, + AgentActionType.WRITE_STIX, + AgentActionType.ENRICH, + AgentActionType.INGEST, + AgentActionType.HYPOTHESIZE, + AgentActionType.ESCALATE, + }, + "untrusted_external": { + AgentActionType.READ_STIX, + AgentActionType.ENRICH, + AgentActionType.HYPOTHESIZE, + }, +} + + +def agent_can_act(trust_level: str, action_type: AgentActionType) -> bool: + """Return True if *trust_level* permits *action_type* by default.""" + return action_type in _TRUST_ACTION_PERMISSIONS.get(trust_level, set()) diff --git a/gnat/reasoning/__init__.py b/gnat/reasoning/__init__.py new file mode 100644 index 00000000..3f91c1a8 --- /dev/null +++ b/gnat/reasoning/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.reasoning +============== + +Phase 4C reasoning layer: hypothesis testing, negative evidence tracking, +and evidence-weighted prioritisation. +""" + +from gnat.reasoning.hypothesis import HypothesisEngine +from gnat.reasoning.engine import ReasoningEngine + +__all__ = ["HypothesisEngine", "ReasoningEngine"] diff --git a/gnat/reasoning/engine.py b/gnat/reasoning/engine.py new file mode 100644 index 00000000..e6133881 --- /dev/null +++ b/gnat/reasoning/engine.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.reasoning.engine +===================== + +Evidence-weighted reasoning engine for observable prioritisation. + +:class:`ReasoningEngine` takes a set of STIX observables and scores them +based on connector trust level, hypothesis confidence, negative evidence +TTL, object age, and cross-connector corroboration. Outputs are stored +as STIX ``note`` objects linked to the scored observables. + +Usage +----- +:: + + from gnat.reasoning.engine import ReasoningEngine + from gnat.core.context import ExecutionContext + + engine = ReasoningEngine(manager=manager, workspace_name="analysis-ws") + ctx = ExecutionContext.create( + initiated_by="manual", domain="analysis", workspace_id="analysis-ws" + ) + results = engine.prioritize(observables, context=ctx) + for observable, score, explanation in results: + print(f"{score:.2f} {explanation['summary']}") +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + +if TYPE_CHECKING: + from gnat.core.context import ExecutionContext + from gnat.context.workspace import WorkspaceManager + from gnat.orm.base import STIXBase + +logger = logging.getLogger(__name__) + +# Trust level → scoring weight +_TRUST_WEIGHTS: dict[str, float] = { + "trusted_internal": 0.9, + "semi_trusted": 0.6, + "untrusted_external": 0.3, +} + +# Object age scoring: confidence decays by this fraction per 24-hour period +_AGE_DECAY_PER_DAY = 0.05 + +# Maximum corroboration bonus +_MAX_CORROBORATION_BONUS = 0.25 + + +class ReasoningEngine: + """ + Evidence-weighted observable prioritisation engine. + + Scores STIX observables using multiple signal sources: + + * **Connector trust level** — objects from trusted internal platforms + receive higher base weight. + * **Hypothesis confidence** — observables linked to high-confidence + hypotheses receive a boost. + * **Negative evidence** — recent negative evidence records suppress + the score (connector returned nothing for this observable). + * **Object age** — older objects decay gradually. + * **Cross-connector corroboration** — corroboration count from Solr + search provides a bounded bonus. + + Outputs are structured dicts (not free text) for machine readability. + Results are stored as STIX ``note`` objects in the workspace. + + Parameters + ---------- + manager : WorkspaceManager + Workspace manager. + workspace_name : str + Name of the target workspace. + search_index : SearchIndex, optional + Solr (or Null) search index for corroboration queries. + """ + + def __init__( + self, + manager: "WorkspaceManager", + workspace_name: str = "analysis", + search_index: Any | None = None, + ) -> None: + """Initialize ReasoningEngine.""" + self._manager = manager + self._workspace_name = workspace_name + if search_index is not None: + self._search_index = search_index + else: + from gnat.search.index import NullSearchIndex + self._search_index = NullSearchIndex() + + # ── Public API ───────────────────────────────────────────────────────────── + + def prioritize( + self, + observable_set: list["STIXBase"], + context: "ExecutionContext | None" = None, + store_notes: bool = True, + ) -> list[tuple["STIXBase", float, dict[str, Any]]]: + """ + Score and rank a set of observables. + + Parameters + ---------- + observable_set : list of STIXBase + Observables to evaluate. + context : ExecutionContext, optional + Active execution context. Provides trust level and workspace info. + store_notes : bool + If ``True``, write scoring results as STIX ``note`` objects. + Default ``True``. + + Returns + ------- + list of (STIXBase, float, dict) + Tuples of ``(observable, score, explanation)`` sorted by score + descending. Score is in ``[0.0, 1.0]``. Explanation dict is + machine-readable (not free text). + """ + results: list[tuple["STIXBase", float, dict[str, Any]]] = [] + ws = self._manager.open(self._workspace_name) + + # Gather negative evidence records for fast lookup + neg_evidence_by_target: dict[str, list[NegativeEvidenceRecord]] = {} + for obj in ws.objects.values(): + raw = obj.to_dict() if hasattr(obj, "to_dict") else {} + if raw.get("type") == NegativeEvidenceRecord.stix_type: + rec = NegativeEvidenceRecord.from_dict(raw) + target = rec._properties.get("target_ref", "") + neg_evidence_by_target.setdefault(target, []).append(rec) + + for observable in observable_set: + score, explanation = self._score_observable( + observable=observable, + context=context, + neg_evidence=neg_evidence_by_target.get(observable.id, []), + ) + results.append((observable, score, explanation)) + + if store_notes: + self._store_note(observable, score, explanation) + + results.sort(key=lambda t: t[1], reverse=True) + logger.info( + "ReasoningEngine.prioritize: scored %d observables in workspace %r", + len(results), + self._workspace_name, + ) + return results + + # ── Scoring ──────────────────────────────────────────────────────────────── + + def _score_observable( + self, + observable: "STIXBase", + context: "ExecutionContext | None", + neg_evidence: list[NegativeEvidenceRecord], + ) -> tuple[float, dict[str, Any]]: + """Compute a composite score for one observable.""" + explanation: dict[str, Any] = { + "observable_id": observable.id, + "observable_type": getattr(observable, "stix_type", "unknown"), + "components": {}, + } + + # 1. Connector trust weight + trust = "semi_trusted" + if context is not None: + trust = context.trust_level + trust_weight = _TRUST_WEIGHTS.get(trust, 0.5) + explanation["components"]["trust_weight"] = { + "trust_level": trust, + "weight": trust_weight, + } + + # 2. Object age decay + age_factor = self._age_factor(observable) + explanation["components"]["age_factor"] = age_factor + + # 3. Negative evidence penalty + neg_penalty = 0.0 + fresh_neg = [r for r in neg_evidence if not r.is_expired()] + if fresh_neg: + neg_penalty = min(0.3 * len(fresh_neg), 0.6) + explanation["components"]["negative_evidence"] = { + "count": len(fresh_neg), + "penalty": neg_penalty, + } + + # 4. Corroboration bonus from search index + corroboration_bonus = 0.0 + try: + obj_id = observable.id + hits = self._search_index.search(obj_id, limit=10) + corroboration_bonus = min(len(hits) * 0.05, _MAX_CORROBORATION_BONUS) + except Exception as exc: # noqa: BLE001 + logger.debug("ReasoningEngine: search index unavailable — %s", exc) + explanation["components"]["corroboration"] = { + "hits": int(corroboration_bonus / 0.05) if corroboration_bonus else 0, + "bonus": corroboration_bonus, + } + + # 5. Composite score + raw_score = ( + trust_weight * 0.4 + + age_factor * 0.3 + + corroboration_bonus * 0.3 + - neg_penalty * 0.5 + ) + score = max(0.0, min(1.0, raw_score)) + explanation["score"] = round(score, 4) + explanation["summary"] = ( + f"score={score:.2f} trust={trust} age_factor={age_factor:.2f} " + f"neg_penalty={neg_penalty:.2f} corroboration_bonus={corroboration_bonus:.2f}" + ) + return score, explanation + + @staticmethod + def _age_factor(observable: "STIXBase") -> float: + """Return a 0–1 factor where 1.0 = fresh, decaying with object age.""" + modified_str = getattr(observable, "modified", "") + if not modified_str: + return 0.5 + try: + ts = datetime.fromisoformat(modified_str.rstrip("Z")) + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400.0 + factor = max(0.0, 1.0 - _AGE_DECAY_PER_DAY * age_days) + return round(factor, 4) + except (ValueError, TypeError): + return 0.5 + + # ── Note storage ─────────────────────────────────────────────────────────── + + def _store_note( + self, + observable: "STIXBase", + score: float, + explanation: dict[str, Any], + ) -> None: + """Persist a STIX note object recording the scoring rationale.""" + import json + + from gnat.orm.base import STIXBase as _STIXBase, _utcnow + + note_dict = { + "type": "note", + "id": f"note--{__import__('uuid').uuid4()}", + "spec_version": "2.1", + "created": _utcnow(), + "modified": _utcnow(), + "abstract": f"Reasoning score: {score:.4f}", + "content": json.dumps(explanation, indent=2), + "object_refs": [observable.id], + "x_gnat_reasoning_score": score, + } + try: + ws = self._manager.open(self._workspace_name) + ws._add_object(note_dict, mark_dirty=False) + except Exception as exc: # noqa: BLE001 + logger.debug("ReasoningEngine: could not store note — %s", exc) diff --git a/gnat/reasoning/hypothesis.py b/gnat/reasoning/hypothesis.py new file mode 100644 index 00000000..0026b4c4 --- /dev/null +++ b/gnat/reasoning/hypothesis.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.reasoning.hypothesis +========================= + +Hypothesis lifecycle engine. + +:class:`HypothesisEngine` manages the full ``propose → evaluate → close`` +lifecycle for :class:`~gnat.stix.sdos.hypothesis.STIXHypothesis` objects. + +All state is persisted via the existing workspace/store path — hypotheses +are STIX objects and follow the same storage patterns as indicators or +threat actors. + +Usage +----- +:: + + from gnat.reasoning.hypothesis import HypothesisEngine + from gnat.context.workspace import WorkspaceManager + + manager = WorkspaceManager.default() + engine = HypothesisEngine(manager=manager, workspace_name="analysis-ws") + + h = engine.propose( + statement="192.0.2.1 is a Lazarus Group C2.", + initial_evidence=[indicator_stix_id], + ) + h = engine.evaluate(h.id) + print(h.confidence, h.status) + h = engine.close(h.id, verdict="confirmed") +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from gnat.stix.sdos.hypothesis import STIXHypothesis + +if TYPE_CHECKING: + from gnat.context.workspace import WorkspaceManager + +logger = logging.getLogger(__name__) + +# Trust level → evidence weight mapping +_TRUST_WEIGHTS: dict[str, float] = { + "trusted_internal": 0.9, + "semi_trusted": 0.6, + "untrusted_external": 0.3, +} +_DEFAULT_WEIGHT = 0.5 + + +class HypothesisEngine: + """ + Manages the propose → evaluate → close lifecycle for STIX hypotheses. + + Parameters + ---------- + manager : WorkspaceManager + Workspace manager used to open the target workspace. + workspace_name : str + Name of the workspace where hypotheses are stored. + search_index : SearchIndex, optional + Solr (or Null) index used for evidence corroboration queries. + Defaults to ``NullSearchIndex`` if omitted. + + Examples + -------- + :: + + engine = HypothesisEngine(manager=manager, workspace_name="threats") + h = engine.propose("APT29 behind Q1 phishing", ["indicator--abc"]) + h = engine.evaluate(h.id) + h = engine.close(h.id, "confirmed") + """ + + def __init__( + self, + manager: "WorkspaceManager", + workspace_name: str = "analysis", + search_index: Any | None = None, + ) -> None: + """Initialize HypothesisEngine.""" + self._manager = manager + self._workspace_name = workspace_name + if search_index is not None: + self._search_index = search_index + else: + from gnat.search.index import NullSearchIndex + self._search_index = NullSearchIndex() + + # ── Public API ───────────────────────────────────────────────────────────── + + def propose( + self, + statement: str, + initial_evidence: list[str] | None = None, + confidence: float = 0.2, + ) -> STIXHypothesis: + """ + Create a new hypothesis with an initial confidence score. + + Parameters + ---------- + statement : str + The assertion being tested. + initial_evidence : list of str, optional + STIX relationship IDs or object IDs linking initial evidence. + confidence : float + Initial confidence in ``[0.0, 1.0]``. Defaults to 0.2 (low). + + Returns + ------- + STIXHypothesis + The newly created, persisted hypothesis. + """ + h = STIXHypothesis( + statement=statement, + confidence=confidence, + status="pending", + ) + for ev_id in (initial_evidence or []): + h.add_supporting_evidence(ev_id) + + self._persist(h) + logger.info( + "HypothesisEngine: proposed %r (confidence=%.2f, evidence=%d)", + statement[:80], + confidence, + len(initial_evidence or []), + ) + return h + + def evaluate(self, hypothesis_id: str) -> STIXHypothesis: + """ + Re-evaluate a hypothesis using the search index for corroboration. + + Queries the search index with the hypothesis statement. Matching + objects are counted and weighted by their source connector's trust + level (from ``source_platform`` metadata). The confidence score is + updated in-place. + + Parameters + ---------- + hypothesis_id : str + STIX ID of the hypothesis to evaluate. + + Returns + ------- + STIXHypothesis + The updated hypothesis. + + Raises + ------ + KeyError + If no hypothesis with *hypothesis_id* is found. + """ + h = self._load(hypothesis_id) + + # Query search index for corroborating evidence + statement = h._properties.get("statement", "") + corroborating_ids: list[str] = [] + try: + corroborating_ids = self._search_index.search(statement, limit=20) + except Exception as exc: # noqa: BLE001 + logger.debug("HypothesisEngine: search index unavailable — %s", exc) + + # Compute weighted confidence from evidence counts + support_count = len(h._properties.get("supporting_evidence", [])) + refute_count = len(h._properties.get("refuting_evidence", [])) + corroboration_boost = min(len(corroborating_ids) * 0.05, 0.3) + + if support_count + refute_count == 0 and not corroborating_ids: + # No evidence at all — stay at initial confidence + pass + else: + # Weighted ratio: support boosts, refutation reduces + total = support_count + refute_count + 1 # +1 avoids div-by-zero + raw = (support_count / total) + corroboration_boost + confidence = max(0.0, min(1.0, raw)) + h.update_confidence(confidence) + + # Auto-classify status thresholds + conf = h._properties.get("confidence", 0.0) + if conf >= 0.75: + h._properties["status"] = "confirmed" + elif conf <= 0.15 and refute_count > 0: + h._properties["status"] = "refuted" + + self._persist(h) + logger.info( + "HypothesisEngine: evaluated %s — confidence=%.2f status=%r", + hypothesis_id, + conf, + h._properties.get("status"), + ) + return h + + def close(self, hypothesis_id: str, verdict: str) -> STIXHypothesis: + """ + Finalise a hypothesis with an explicit verdict. + + Parameters + ---------- + hypothesis_id : str + STIX ID of the hypothesis. + verdict : str + Final verdict: ``"confirmed"``, ``"refuted"``, or ``"inconclusive"``. + + Returns + ------- + STIXHypothesis + The closed hypothesis. + """ + h = self._load(hypothesis_id) + h.close(verdict) + self._persist(h) + logger.info( + "HypothesisEngine: closed %s with verdict %r", hypothesis_id, verdict + ) + return h + + def get(self, hypothesis_id: str) -> STIXHypothesis | None: + """Return a hypothesis by ID, or ``None`` if not found.""" + try: + return self._load(hypothesis_id) + except KeyError: + return None + + def list_all(self) -> list[STIXHypothesis]: + """Return all hypotheses in the workspace.""" + ws = self._manager.open(self._workspace_name) + result = [] + for obj in ws.objects.values(): + if getattr(obj, "stix_type", "") == STIXHypothesis.stix_type: + result.append(obj) + elif isinstance(obj, dict) and obj.get("type") == STIXHypothesis.stix_type: + result.append(STIXHypothesis.from_dict(obj)) + return result + + # ── Internal helpers ─────────────────────────────────────────────────────── + + def _persist(self, h: STIXHypothesis) -> None: + ws = self._manager.open(self._workspace_name) + ws._add_object(h.to_dict(), mark_dirty=True) + + def _load(self, hypothesis_id: str) -> STIXHypothesis: + ws = self._manager.open(self._workspace_name) + obj = ws.objects.get(hypothesis_id) + if obj is None: + raise KeyError(f"No hypothesis found with id {hypothesis_id!r}") + raw = obj.to_dict() if hasattr(obj, "to_dict") else obj + return STIXHypothesis.from_dict(raw) diff --git a/gnat/stix/sdos/__init__.py b/gnat/stix/sdos/__init__.py new file mode 100644 index 00000000..fe062c0f --- /dev/null +++ b/gnat/stix/sdos/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.stix.sdos +============== + +Custom STIX 2.1 Domain Objects (SDOs) for GNAT Phase 4 reasoning and +evidence tracking. + +Custom SDO types follow the STIX 2.1 ``x--`` naming +convention and are stored via the existing workspace ORM path. +""" + +from gnat.stix.sdos.hypothesis import STIXHypothesis +from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + +__all__ = ["STIXHypothesis", "NegativeEvidenceRecord"] diff --git a/gnat/stix/sdos/hypothesis.py b/gnat/stix/sdos/hypothesis.py new file mode 100644 index 00000000..db623eaf --- /dev/null +++ b/gnat/stix/sdos/hypothesis.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.stix.sdos.hypothesis +========================== + +Custom STIX 2.1 SDO representing a reasoning hypothesis. + +A :class:`STIXHypothesis` is an assertion about a threat, actor, campaign, +or relationship that can be confirmed or refuted through evidence. It is +stored via the existing STIX ORM path (``workspace._add_object()``) so no +new storage infrastructure is needed. + +STIX type: ``x-gnat-hypothesis`` + +Usage +----- +:: + + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis( + statement="APT29 is responsible for the Q1 2026 phishing campaign.", + confidence=0.4, + ) + h.add_supporting_evidence("relationship--abc123") + print(h.to_dict()) +""" + +from __future__ import annotations + +from typing import Any, Optional + +from gnat.orm.base import STIXBase, _utcnow + + +class STIXHypothesis(STIXBase): + """ + Custom STIX 2.1 SDO — ``x-gnat-hypothesis``. + + Represents an analyst or engine hypothesis about threat attribution, + campaign linkage, or actor identity. Evidence is tracked through + STIX relationship IDs pointing to supporting or refuting objects. + + Parameters + ---------- + statement : str + Human-readable assertion being tested. + confidence : float + Initial confidence score in the range ``[0.0, 1.0]``. + status : str + Lifecycle status: ``"pending"``, ``"confirmed"``, ``"refuted"``, + or ``"inconclusive"``. + + Examples + -------- + :: + + h = STIXHypothesis( + statement="192.0.2.1 is a C2 server for Lazarus Group.", + confidence=0.3, + ) + h.add_supporting_evidence("relationship--uuid1") + h.add_refuting_evidence("relationship--uuid2") + h.update_confidence(0.7) + h.close("confirmed") + """ + + stix_type = "x-gnat-hypothesis" + schema_version = 1 + + # Valid lifecycle statuses + STATUSES = frozenset({"pending", "confirmed", "refuted", "inconclusive"}) + + def __init__( + self, + statement: str = "", + confidence: float = 0.0, + status: str = "pending", + supporting_evidence: list[str] | None = None, + refuting_evidence: list[str] | None = None, + client: Optional[Any] = None, + **kwargs: Any, + ) -> None: + """Initialize STIXHypothesis.""" + super().__init__(client=client, **kwargs) + if status not in self.STATUSES: + raise ValueError( + f"Invalid hypothesis status {status!r}. " + f"Must be one of: {sorted(self.STATUSES)}" + ) + if not (0.0 <= confidence <= 1.0): + raise ValueError( + f"confidence must be in [0.0, 1.0], got {confidence!r}" + ) + self._properties["statement"] = statement + self._properties["confidence"] = float(confidence) + self._properties["status"] = status + self._properties["supporting_evidence"] = list(supporting_evidence or []) + self._properties["refuting_evidence"] = list(refuting_evidence or []) + + # ── Evidence management ──────────────────────────────────────────────────── + + def add_supporting_evidence(self, relationship_id: str) -> None: + """ + Link a supporting STIX relationship to this hypothesis. + + Parameters + ---------- + relationship_id : str + STIX ID of a ``relationship`` object linking evidence to this hypothesis. + """ + if relationship_id not in self._properties["supporting_evidence"]: + self._properties["supporting_evidence"].append(relationship_id) + self.modified = _utcnow() + + def add_refuting_evidence(self, relationship_id: str) -> None: + """ + Link a refuting STIX relationship to this hypothesis. + + Parameters + ---------- + relationship_id : str + STIX ID of a ``relationship`` object linking contradicting evidence. + """ + if relationship_id not in self._properties["refuting_evidence"]: + self._properties["refuting_evidence"].append(relationship_id) + self.modified = _utcnow() + + # ── Lifecycle ────────────────────────────────────────────────────────────── + + def update_confidence(self, confidence: float) -> None: + """ + Update the confidence score. + + Parameters + ---------- + confidence : float + New score in ``[0.0, 1.0]``. + """ + if not (0.0 <= confidence <= 1.0): + raise ValueError(f"confidence must be in [0.0, 1.0], got {confidence!r}") + self._properties["confidence"] = float(confidence) + self.modified = _utcnow() + + def close(self, verdict: str) -> None: + """ + Finalise this hypothesis with a verdict. + + Parameters + ---------- + verdict : str + One of ``"confirmed"``, ``"refuted"``, or ``"inconclusive"``. + """ + valid = {"confirmed", "refuted", "inconclusive"} + if verdict not in valid: + raise ValueError(f"verdict must be one of {sorted(valid)}, got {verdict!r}") + self._properties["status"] = verdict + self.modified = _utcnow() + + # ── Serialization ────────────────────────────────────────────────────────── + + def to_dict(self) -> dict[str, Any]: + """Serialise to a STIX-compatible dict.""" + return { + "type": self.stix_type, + "id": self.id, + "spec_version": self.spec_version, + "created": self.created, + "modified": self.modified, + "statement": self._properties.get("statement", ""), + "confidence": self._properties.get("confidence", 0.0), + "status": self._properties.get("status", "pending"), + "supporting_evidence": self._properties.get("supporting_evidence", []), + "refuting_evidence": self._properties.get("refuting_evidence", []), + } + + @classmethod + def from_dict(cls, data: dict[str, Any], client: Optional[Any] = None) -> STIXHypothesis: + """Deserialise from a STIX dict.""" + obj = cls( + statement=data.get("statement", ""), + confidence=float(data.get("confidence", 0.0)), + status=data.get("status", "pending"), + supporting_evidence=data.get("supporting_evidence", []), + refuting_evidence=data.get("refuting_evidence", []), + client=client, + id=data.get("id"), + created=data.get("created"), + modified=data.get("modified"), + spec_version=data.get("spec_version", "2.1"), + ) + return obj + + def __repr__(self) -> str: # pragma: no cover + stmt = self._properties.get("statement", "")[:50] + return ( + f"STIXHypothesis(status={self._properties.get('status')!r}, " + f"confidence={self._properties.get('confidence'):.2f}, " + f"statement={stmt!r})" + ) diff --git a/gnat/stix/sdos/negative_evidence.py b/gnat/stix/sdos/negative_evidence.py new file mode 100644 index 00000000..faee16f0 --- /dev/null +++ b/gnat/stix/sdos/negative_evidence.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.stix.sdos.negative_evidence +================================= + +Custom STIX 2.1 SDO representing a negative enrichment result. + +A :class:`NegativeEvidenceRecord` is written when a connector returns no +results for a lookup. It prevents redundant re-queries within a configurable +TTL window and contributes (negatively) to hypothesis confidence scoring. + +STIX type: ``x-gnat-negative-evidence`` + +Usage +----- +:: + + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + rec = NegativeEvidenceRecord( + target_ref="indicator--abc123", + queried_connector="VirusTotalClient", + ttl_seconds=3600, + ) + print(rec.is_expired()) # False immediately after creation +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Optional + +from gnat.orm.base import STIXBase, _utcnow + + +class NegativeEvidenceRecord(STIXBase): + """ + Custom STIX 2.1 SDO — ``x-gnat-negative-evidence``. + + Records that a connector returned no results for a specific lookup. + Callers should check :meth:`is_expired` before re-querying — if the + record is fresh, skip the query entirely. + + Parameters + ---------- + target_ref : str + STIX ID of the object that was queried (e.g. an ``indicator`` id). + queried_connector : str + Class name of the connector that returned no results. + ttl_seconds : int + Seconds after creation before this record expires and a re-query + is permitted. Default ``3600`` (1 hour). + query_timestamp : str, optional + ISO 8601 UTC timestamp of the failed query. Defaults to now. + """ + + stix_type = "x-gnat-negative-evidence" + schema_version = 1 + + def __init__( + self, + target_ref: str = "", + queried_connector: str = "", + ttl_seconds: int = 3600, + query_timestamp: str | None = None, + client: Optional[Any] = None, + **kwargs: Any, + ) -> None: + """Initialize NegativeEvidenceRecord.""" + super().__init__(client=client, **kwargs) + self._properties["target_ref"] = target_ref + self._properties["queried_connector"] = queried_connector + self._properties["ttl_seconds"] = int(ttl_seconds) + self._properties["query_timestamp"] = query_timestamp or _utcnow() + + # ── TTL helpers ──────────────────────────────────────────────────────────── + + def is_expired(self) -> bool: + """ + Return ``True`` if the TTL has elapsed since ``query_timestamp``. + + A re-query is safe only when this returns ``True``. + """ + ts_str = self._properties.get("query_timestamp", "") + if not ts_str: + return True + try: + ts = datetime.fromisoformat(ts_str.rstrip("Z")) + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + elapsed = (datetime.now(timezone.utc) - ts).total_seconds() + return elapsed >= self._properties.get("ttl_seconds", 3600) + except (ValueError, TypeError): + return True + + def seconds_remaining(self) -> float: + """Return seconds until TTL expiry (0 if already expired).""" + ts_str = self._properties.get("query_timestamp", "") + if not ts_str: + return 0.0 + try: + ts = datetime.fromisoformat(ts_str.rstrip("Z")) + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + elapsed = (datetime.now(timezone.utc) - ts).total_seconds() + remaining = self._properties.get("ttl_seconds", 3600) - elapsed + return max(0.0, remaining) + except (ValueError, TypeError): + return 0.0 + + # ── Serialization ────────────────────────────────────────────────────────── + + def to_dict(self) -> dict[str, Any]: + """Serialise to a STIX-compatible dict.""" + return { + "type": self.stix_type, + "id": self.id, + "spec_version": self.spec_version, + "created": self.created, + "modified": self.modified, + "target_ref": self._properties.get("target_ref", ""), + "queried_connector": self._properties.get("queried_connector", ""), + "ttl_seconds": self._properties.get("ttl_seconds", 3600), + "query_timestamp": self._properties.get("query_timestamp", ""), + } + + @classmethod + def from_dict(cls, data: dict[str, Any], client: Optional[Any] = None) -> NegativeEvidenceRecord: + """Deserialise from a STIX dict.""" + return cls( + target_ref=data.get("target_ref", ""), + queried_connector=data.get("queried_connector", ""), + ttl_seconds=int(data.get("ttl_seconds", 3600)), + query_timestamp=data.get("query_timestamp"), + client=client, + id=data.get("id"), + created=data.get("created"), + modified=data.get("modified"), + spec_version=data.get("spec_version", "2.1"), + ) + + def __repr__(self) -> str: # pragma: no cover + return ( + f"NegativeEvidenceRecord(" + f"connector={self._properties.get('queried_connector')!r}, " + f"target={self._properties.get('target_ref')!r}, " + f"expired={self.is_expired()})" + ) diff --git a/gnat/testing/__init__.py b/gnat/testing/__init__.py new file mode 100644 index 00000000..87f1b418 --- /dev/null +++ b/gnat/testing/__init__.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.testing +============ + +Phase 4E-3 testing framework for GNAT. + +Provides simulation and replay primitives that allow full pipeline tests +without network access: + +* :class:`SimulationConnector` — a ``BaseClient`` subclass that returns + canned STIX fixtures for any query; no real HTTP calls. +* :class:`ReplayRunner` — replays a recorded ``execution_log`` sequence + through the current pipeline, asserting output matches. +* :class:`AgentTestHarness` — wraps :class:`~gnat.agents.governor.AgentGovernor` + and :class:`~gnat.agents.hitl.HITLGateway` with mock approval responses + for deterministic agent action tests. + +Usage +----- +:: + + from gnat.testing import SimulationConnector, AgentTestHarness + + connector = SimulationConnector(fixtures=[indicator_dict]) + harness = AgentTestHarness() + harness.governor.can_act("agent-1", AgentActionType.ENRICH, "semi_trusted") +""" + +from gnat.testing.simulation import AgentTestHarness, ReplayRunner, SimulationConnector + +__all__ = ["SimulationConnector", "ReplayRunner", "AgentTestHarness"] diff --git a/gnat/testing/simulation.py b/gnat/testing/simulation.py new file mode 100644 index 00000000..afb38cf4 --- /dev/null +++ b/gnat/testing/simulation.py @@ -0,0 +1,401 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Bill Halpin +""" +gnat.testing.simulation +======================== + +Simulation primitives for GNAT pipeline tests. + +:class:`SimulationConnector` + A :class:`~gnat.clients.base.BaseClient` subclass that returns canned + STIX fixtures without making any network calls. Useful for unit and + integration tests that must exercise the full pipeline without live + credentials. + +:class:`ReplayRunner` + Replays a sequence of ``execution_log`` rows through the current + pipeline, asserting that the output matches expected state. + +:class:`AgentTestHarness` + Wraps :class:`~gnat.agents.governor.AgentGovernor` and + :class:`~gnat.agents.hitl.HITLGateway` with mock approval responses + so agent action tests are deterministic. +""" + +from __future__ import annotations + +import logging +from typing import Any, Iterator + +from gnat.clients.base import BaseClient, GNATClientError +from gnat.agents.governor import AgentAction, AgentGovernor, RateLimitExceeded +from gnat.policy.models import AgentActionType + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# SimulationConnector +# --------------------------------------------------------------------------- + +class SimulationConnector(BaseClient): + """ + A no-network :class:`~gnat.clients.base.BaseClient` that serves canned + STIX fixtures for any query. + + All HTTP methods return data from the fixture list rather than making + real HTTP calls. Useful for unit and integration tests. + + Parameters + ---------- + fixtures : list of dict + STIX objects (or any JSON-serialisable dicts) to return from + ``list_objects()`` / ``get()`` calls. + host : str + Nominal host URL (not actually used for connections). + trust_level : str + Trust classification for this simulation connector. + Defaults to ``"semi_trusted"``. + raise_on_request : bool + If ``True``, any call to the underlying ``_request`` helper raises + ``GNATClientError``. Useful for testing error-handling paths. + + Examples + -------- + :: + + sim = SimulationConnector(fixtures=[indicator_dict, malware_dict]) + objects = sim.list_objects() # returns all fixtures + obj = sim.get_object("indicator--abc") # returns matching fixture + """ + + TRUST_LEVEL: str = "semi_trusted" + API_VERSION: str = "sim-v1" + API_PREFIX: str = "/sim" + COST_UNIT: int = 1 + + def __init__( + self, + fixtures: list[dict[str, Any]] | None = None, + host: str = "http://simulation.local", + trust_level: str = "semi_trusted", + raise_on_request: bool = False, + ) -> None: + """Initialize SimulationConnector.""" + # Bypass real connection pool setup with a dummy host + self.host = host.rstrip("/") + self.verify_ssl = False + self.timeout = 30.0 + self.config: dict[str, Any] = {} + self._auth_headers: dict[str, str] = {} + self._authenticated = True # pre-authenticated + self._context: Any = None + + self._fixtures: list[dict[str, Any]] = list(fixtures or []) + self.TRUST_LEVEL = trust_level # type: ignore[assignment] + self._raise_on_request = raise_on_request + + # Build a simple STIX-id index for fast lookup + self._index: dict[str, dict[str, Any]] = { + obj.get("id", ""): obj for obj in self._fixtures if obj.get("id") + } + logger.debug( + "SimulationConnector: loaded %d fixtures", len(self._fixtures) + ) + + # ── ConnectorMixin-compatible interface ──────────────────────────────────── + + def authenticate(self) -> None: + """No-op — simulation connector is always authenticated.""" + + def health_check(self) -> bool: + """Always healthy.""" + return True + + def list_objects( + self, + stix_type: str | None = None, + filters: dict[str, Any] | None = None, + page: int = 1, + page_size: int = 100, + ) -> list[dict[str, Any]]: + """ + Return fixture objects, optionally filtered by ``stix_type``. + + Parameters + ---------- + stix_type : str, optional + If provided, only fixtures with ``type == stix_type`` are returned. + filters : dict, optional + Currently ignored (no filter logic in simulation). + page : int + 1-indexed page number (pagination supported). + page_size : int + Objects per page. + """ + results = self._fixtures + if stix_type: + results = [f for f in results if f.get("type") == stix_type] + # Simple pagination + start = (page - 1) * page_size + return results[start: start + page_size] + + def get_object(self, stix_id: str) -> dict[str, Any] | None: + """Return the fixture with matching STIX id, or ``None``.""" + return self._index.get(stix_id) + + def upsert_object(self, stix_dict: dict[str, Any]) -> dict[str, Any]: + """Add or replace a fixture by STIX id.""" + stix_id = stix_dict.get("id", "") + self._index[stix_id] = stix_dict + existing = next( + (i for i, f in enumerate(self._fixtures) if f.get("id") == stix_id), None + ) + if existing is not None: + self._fixtures[existing] = stix_dict + else: + self._fixtures.append(stix_dict) + return stix_dict + + def delete_object(self, stix_id: str) -> None: + """Remove a fixture by STIX id (no-op if not found).""" + self._fixtures = [f for f in self._fixtures if f.get("id") != stix_id] + self._index.pop(stix_id, None) + + def to_stix(self, obj: Any) -> dict[str, Any]: + """Pass-through — fixtures are already STIX dicts.""" + return obj if isinstance(obj, dict) else {} + + def from_stix(self, stix_obj: dict[str, Any]) -> Any: + """Pass-through — return the dict as-is.""" + return stix_obj + + def add_fixture(self, stix_dict: dict[str, Any]) -> None: + """Dynamically add a fixture at runtime.""" + self.upsert_object(stix_dict) + + def iter_fixtures(self) -> Iterator[dict[str, Any]]: + """Iterate over all fixtures.""" + yield from self._fixtures + + # ── Override _request to avoid real HTTP ────────────────────────────────── + + def _request(self, method: str, path: str, **kwargs: Any) -> Any: # type: ignore[override] + if self._raise_on_request: + raise GNATClientError( + f"SimulationConnector: _request blocked (raise_on_request=True) " + f"— {method} {path}" + ) + # Budget deduction still applies + if self._context is not None: + budget = getattr(self._context, "budget", None) + if budget is not None: + budget.consume(self.COST_UNIT, type(self).__name__) + logger.debug("SimulationConnector: simulated %s %s", method, path) + return {} + + +# --------------------------------------------------------------------------- +# ReplayRunner +# --------------------------------------------------------------------------- + +class ReplayRunner: + """ + Replays a recorded ``execution_log`` sequence through the current pipeline. + + Reads :class:`~gnat.core.context.ExecutionContext` records from a log + (list of dicts) and re-executes them with ``is_replay=True``, asserting + that the output objects match ``expected_stix_ids``. + + Parameters + ---------- + pipeline_fn : callable + Function ``(context) → list[dict]`` representing the pipeline to replay. + Must accept an :class:`~gnat.core.context.ExecutionContext` and return + a list of STIX dicts produced by the run. + """ + + def __init__(self, pipeline_fn: Any) -> None: + """Initialize ReplayRunner.""" + self._pipeline_fn = pipeline_fn + + def replay( + self, + execution_log: list[dict[str, Any]], + expected_stix_ids: list[str] | None = None, + ) -> list[dict[str, Any]]: + """ + Replay log entries through the pipeline. + + Parameters + ---------- + execution_log : list of dict + Rows from ``execution_log`` table (as plain dicts). + expected_stix_ids : list of str, optional + If provided, asserts that every ID appears in the pipeline output. + + Returns + ------- + list of dict + All STIX objects produced across all replayed contexts. + + Raises + ------ + AssertionError + If *expected_stix_ids* are not all present in the output. + """ + from gnat.core.context import ExecutionContext + + all_output: list[dict[str, Any]] = [] + for row in execution_log: + ctx = ExecutionContext.from_dict({**row, "is_replay": True}) + output = self._pipeline_fn(ctx) + all_output.extend(output or []) + + if expected_stix_ids: + produced_ids = {o.get("id") for o in all_output} + missing = [sid for sid in expected_stix_ids if sid not in produced_ids] + assert not missing, ( + f"ReplayRunner: expected STIX IDs not produced: {missing}" + ) + + logger.info( + "ReplayRunner: replayed %d contexts, produced %d objects", + len(execution_log), + len(all_output), + ) + return all_output + + +# --------------------------------------------------------------------------- +# AgentTestHarness +# --------------------------------------------------------------------------- + +class _MockReviewService: + """Minimal ReviewService stub that auto-approves everything.""" + + class _MockItem: + def __init__(self, agent_id: str) -> None: + """Initialize _MockItem.""" + import uuid + from datetime import datetime, timezone + self.id = str(uuid.uuid4()) + self.submitted_by = agent_id + self.submitted_at = datetime.now(timezone.utc) + + class _Status: + value = "approved" + # Use simple string-compatible status + self.status = "approved" + + def submit(self, stix_data: Any, source_workspace: str, submitted_by: str, **_: Any) -> Any: + """Submit.""" + return self._MockItem(submitted_by) + + def approve(self, item_id: str, reviewed_by: str = "auto", **_: Any) -> Any: + """Approve.""" + return self._MockItem("") + + def reject(self, item_id: str, reviewed_by: str = "auto", **_: Any) -> Any: + """Reject.""" + item = self._MockItem("") + item.status = "rejected" + return item + + def get(self, item_id: str) -> Any: + """Get.""" + item = self._MockItem("") + from gnat.review.models import ReviewStatus + item.status = ReviewStatus.APPROVED + return item + + +class AgentTestHarness: + """ + Wraps :class:`~gnat.agents.governor.AgentGovernor` and + :class:`~gnat.agents.hitl.HITLGateway` with mock approval responses. + + All HITL submissions are auto-approved. Recorded actions are accessible + via :attr:`recorded_actions` for post-test assertion. + + Parameters + ---------- + max_calls_per_window : int + Rate limit ceiling. Defaults to ``10_000`` (effectively unlimited + for tests). + policy_overrides : dict, optional + Per-agent permission overrides forwarded to :class:`AgentGovernor`. + + Examples + -------- + :: + + harness = AgentTestHarness() + action = AgentAction( + agent_id="test-agent", + action_type=AgentActionType.ENRICH, + target_ref="indicator--abc", + impact_level="low", + ) + approved, review_item = harness.hitl.evaluate(action) + assert approved is True + harness.governor.record_action(action) + assert len(harness.recorded_actions) == 1 + """ + + def __init__( + self, + max_calls_per_window: int = 10_000, + policy_overrides: dict[str, dict[str, bool]] | None = None, + ) -> None: + """Initialize AgentTestHarness.""" + self.governor = AgentGovernor( + max_calls_per_window=max_calls_per_window, + policy_overrides=policy_overrides, + ) + + from gnat.agents.hitl import HITLGateway + self.hitl = HITLGateway( + review_service=_MockReviewService(), # type: ignore[arg-type] + approval_timeout_seconds=86400, # 24h — tests won't time out + ) + + @property + def recorded_actions(self) -> list[AgentAction]: + """Return all actions recorded by the governor.""" + return self.governor.get_action_log() + + def run_action( + self, + agent_id: str, + action_type: AgentActionType, + target_ref: str = "", + impact_level: str = "low", + trust_level: str = "semi_trusted", + ) -> tuple[bool, AgentAction]: + """ + Convenience method: check permission, rate-limit, evaluate HITL, record. + + Returns ``(approved, action)`` where *approved* reflects the HITL + decision. + + Parameters + ---------- + agent_id : str + action_type : AgentActionType + target_ref : str + impact_level : str + trust_level : str + """ + self.governor.require_can_act(agent_id, action_type, trust_level) + self.governor.rate_limit_check(agent_id) + + action = AgentAction( + agent_id=agent_id, + action_type=action_type, + target_ref=target_ref, + impact_level=impact_level, + ) + approved, _ = self.hitl.evaluate(action) + self.governor.record_action(action) + return approved, action diff --git a/tests/unit/test_phase4_core.py b/tests/unit/test_phase4_core.py new file mode 100644 index 00000000..1896719e --- /dev/null +++ b/tests/unit/test_phase4_core.py @@ -0,0 +1,443 @@ +""" +tests/unit/test_phase4_core.py +================================ +Unit tests for Phase 4A/4B/4E — ExecutionContext, QueryBudget, +Domain boundaries, SimulationConnector, ReplayRunner. +""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# ExecutionContext tests +# --------------------------------------------------------------------------- + +class TestExecutionContext: + def test_create_defaults(self): + from gnat.core.context import ExecutionContext + + ctx = ExecutionContext.create( + initiated_by="test-connector", + domain="ingestion", + workspace_id="ws-1", + ) + assert ctx.initiated_by == "test-connector" + assert ctx.domain == "ingestion" + assert ctx.trust_level == "semi_trusted" + assert ctx.policy_set == "default" + assert ctx.is_replay is False + assert ctx.budget is None + assert len(ctx.context_id) == 36 # UUID format + + def test_create_with_budget(self): + from gnat.core.context import ExecutionContext, QueryBudget + + ctx = ExecutionContext.create( + initiated_by="test", + domain="analysis", + workspace_id="ws-1", + max_budget_units=500, + ) + assert ctx.budget is not None + assert isinstance(ctx.budget, QueryBudget) + assert ctx.budget.max_units == 500 + assert ctx.budget.remaining == 500 + + def test_from_connector(self): + from gnat.core.context import ExecutionContext + + connector = MagicMock() + type(connector).TRUST_LEVEL = "trusted_internal" + type(connector).__name__ = "SplunkClient" + + ctx = ExecutionContext.from_connector( + connector=connector, + domain="ingestion", + workspace_id="ws-splunk", + ) + assert ctx.trust_level == "trusted_internal" + assert ctx.initiated_by == "SplunkClient" + + def test_child_context(self): + from gnat.core.context import ExecutionContext + + parent = ExecutionContext.create( + initiated_by="pipeline", + domain="ingestion", + workspace_id="ws-1", + ) + child = parent.child("enrichment-agent", domain="analysis") + assert child.parent_context_id == parent.context_id + assert child.workspace_id == parent.workspace_id + assert child.trust_level == parent.trust_level + assert child.domain == "analysis" + + def test_to_dict_from_dict_round_trip(self): + from gnat.core.context import ExecutionContext + + ctx = ExecutionContext.create( + initiated_by="manual", + domain="investigation", + workspace_id="ws-inv", + trust_level="trusted_internal", + policy_set="strict", + is_replay=True, + ) + d = ctx.to_dict() + ctx2 = ExecutionContext.from_dict(d) + + assert ctx2.context_id == ctx.context_id + assert ctx2.domain == "investigation" + assert ctx2.trust_level == "trusted_internal" + assert ctx2.is_replay is True + + +# --------------------------------------------------------------------------- +# QueryBudget tests +# --------------------------------------------------------------------------- + +class TestQueryBudget: + def test_initial_state(self): + from gnat.core.context import QueryBudget + + budget = QueryBudget(max_units=100) + assert budget.remaining == 100 + assert budget.is_exhausted is False + + def test_consume(self): + from gnat.core.context import QueryBudget + + budget = QueryBudget(max_units=100) + budget.consume(10, "TestConnector") + assert budget.remaining == 90 + assert budget.is_exhausted is False + + def test_consume_exact(self): + from gnat.core.context import QueryBudget + + budget = QueryBudget(max_units=10) + budget.consume(10, "TestConnector") + assert budget.remaining == 0 + assert budget.is_exhausted is True + + def test_budget_exceeded(self): + from gnat.core.context import QueryBudget + from gnat.clients.base import BudgetExceeded + + budget = QueryBudget(max_units=5) + budget.consume(3, "Connector1") + with pytest.raises(BudgetExceeded) as exc_info: + budget.consume(3, "Connector1") + assert exc_info.value.connector == "Connector1" + assert exc_info.value.cost == 3 + assert exc_info.value.remaining == 2 + + def test_budget_deducted_by_base_client(self): + from gnat.core.context import ExecutionContext, QueryBudget + from gnat.clients.base import BaseClient, BudgetExceeded + + class MockClient(BaseClient): + COST_UNIT = 5 + + def authenticate(self): + self._authenticated = True + + def _request(self, method, path, **kwargs): + # Call parent budget deduction then return empty + if self._context is not None: + budget = getattr(self._context, "budget", None) + if budget is not None: + budget.consume(self.COST_UNIT, type(self).__name__) + return {} + + ctx = ExecutionContext.create( + initiated_by="test", + domain="ingestion", + workspace_id="ws1", + max_budget_units=10, + ) + client = MockClient(host="http://test.local") + client._context = ctx + client._authenticated = True + + client._request("GET", "/test") # costs 5 + assert ctx.budget.remaining == 5 + + client._request("GET", "/test") # costs another 5 + assert ctx.budget.remaining == 0 + + with pytest.raises(BudgetExceeded): + client._request("GET", "/test") # exceeds budget + + +# --------------------------------------------------------------------------- +# Domain boundary tests +# --------------------------------------------------------------------------- + +class TestDomainBoundary: + def test_domain_enum_values(self): + from gnat.core.domains import Domain + + assert Domain.INGESTION == "ingestion" + assert Domain.ANALYSIS == "analysis" + assert Domain.INVESTIGATION == "investigation" + assert Domain.REPORTING == "reporting" + assert Domain.EXECUTION == "execution" + + def test_domain_boundary_violation_raised(self): + from gnat.core.domains import Domain, DomainBoundaryViolation, domain_boundary + + @domain_boundary(Domain.REPORTING, allowed_callers=[Domain.REPORTING]) + def report_fn(): + return "ok" + + @domain_boundary(Domain.INGESTION) + def ingestion_fn(): + # Calling report_fn from ingestion context should violate the boundary + return report_fn() + + # Calling report (only allowed from reporting) from inside ingestion → violation + with pytest.raises(DomainBoundaryViolation): + ingestion_fn() + + def test_no_violation_within_allowed(self): + from gnat.core.domains import Domain, domain_boundary + + @domain_boundary(Domain.INGESTION, allowed_callers=None) + def ingest_fn(): + return "ingested" + + # Without a domain stack (top-level call), any domain is allowed + result = ingest_fn() + assert result == "ingested" + + def test_no_violation_allowed_caller(self): + from gnat.core.domains import Domain, domain_boundary + + @domain_boundary(Domain.ANALYSIS, allowed_callers=[Domain.INGESTION, Domain.ANALYSIS]) + def analysis_fn(): + return "analyzed" + + @domain_boundary(Domain.INGESTION) + def ingestion_fn(): + return analysis_fn() # ingestion → analysis is allowed + + result = ingestion_fn() + assert result == "analyzed" + + +# --------------------------------------------------------------------------- +# SimulationConnector tests +# --------------------------------------------------------------------------- + +class TestSimulationConnector: + def _make_fixtures(self): + return [ + {"type": "indicator", "id": "indicator--abc", "spec_version": "2.1"}, + {"type": "indicator", "id": "indicator--xyz", "spec_version": "2.1"}, + {"type": "malware", "id": "malware--def", "spec_version": "2.1"}, + ] + + def test_list_all(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=self._make_fixtures()) + result = sim.list_objects() + assert len(result) == 3 + + def test_list_by_type(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=self._make_fixtures()) + result = sim.list_objects(stix_type="indicator") + assert len(result) == 2 + assert all(r["type"] == "indicator" for r in result) + + def test_get_object_found(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=self._make_fixtures()) + obj = sim.get_object("indicator--abc") + assert obj is not None + assert obj["id"] == "indicator--abc" + + def test_get_object_not_found(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=self._make_fixtures()) + assert sim.get_object("indicator--nonexistent") is None + + def test_upsert_adds_fixture(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=[]) + new_obj = {"type": "vulnerability", "id": "vuln--new", "spec_version": "2.1"} + sim.upsert_object(new_obj) + assert sim.get_object("vuln--new") is not None + + def test_delete_object(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(fixtures=self._make_fixtures()) + sim.delete_object("indicator--abc") + assert sim.get_object("indicator--abc") is None + assert len(sim.list_objects()) == 2 + + def test_health_check_always_true(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector() + assert sim.health_check() is True + + def test_pagination(self): + from gnat.testing import SimulationConnector + + fixtures = [{"type": "indicator", "id": f"indicator--{i}"} for i in range(10)] + sim = SimulationConnector(fixtures=fixtures) + page1 = sim.list_objects(page=1, page_size=4) + page2 = sim.list_objects(page=2, page_size=4) + page3 = sim.list_objects(page=3, page_size=4) + assert len(page1) == 4 + assert len(page2) == 4 + assert len(page3) == 2 + + def test_trust_level_configurable(self): + from gnat.testing import SimulationConnector + + sim = SimulationConnector(trust_level="trusted_internal") + assert sim.TRUST_LEVEL == "trusted_internal" + + def test_raise_on_request(self): + from gnat.testing import SimulationConnector + from gnat.clients.base import GNATClientError + + sim = SimulationConnector(raise_on_request=True) + with pytest.raises(GNATClientError): + sim._request("GET", "/some/path") + + +# --------------------------------------------------------------------------- +# ReplayRunner tests +# --------------------------------------------------------------------------- + +class TestReplayRunner: + def test_replay_produces_output(self): + from gnat.testing import ReplayRunner + from gnat.core.context import ExecutionContext + + indicator = {"type": "indicator", "id": "indicator--replay-1"} + + def fake_pipeline(ctx): + assert ctx.is_replay is True + return [indicator] + + runner = ReplayRunner(fake_pipeline) + log = [ + ExecutionContext.create( + initiated_by="test", domain="ingestion", workspace_id="ws1" + ).to_dict() + ] + result = runner.replay(log) + assert len(result) == 1 + assert result[0]["id"] == "indicator--replay-1" + + def test_replay_asserts_expected_ids(self): + from gnat.testing import ReplayRunner + from gnat.core.context import ExecutionContext + + def fake_pipeline(ctx): + return [{"id": "indicator--expected"}] + + runner = ReplayRunner(fake_pipeline) + log = [ + ExecutionContext.create( + initiated_by="test", domain="ingestion", workspace_id="ws1" + ).to_dict() + ] + + # Should pass when expected ID is present + runner.replay(log, expected_stix_ids=["indicator--expected"]) + + def test_replay_asserts_fails_on_missing(self): + from gnat.testing import ReplayRunner + from gnat.core.context import ExecutionContext + + def fake_pipeline(ctx): + return [{"id": "indicator--produced"}] + + runner = ReplayRunner(fake_pipeline) + log = [ + ExecutionContext.create( + initiated_by="test", domain="ingestion", workspace_id="ws1" + ).to_dict() + ] + + with pytest.raises(AssertionError, match="indicator--expected"): + runner.replay(log, expected_stix_ids=["indicator--expected"]) + + +# --------------------------------------------------------------------------- +# Workspace trust boundary tests (4E-1) +# --------------------------------------------------------------------------- + +class TestWorkspaceTrustBoundary: + def _make_workspace(self, trust_boundary="semi_trusted", allowed_refs=None): + from gnat.context.workspace import Workspace + + ws = Workspace.__new__(Workspace) + ws.name = "test-ws" + ws.trust_boundary = trust_boundary + ws.allowed_connector_refs = allowed_refs or [] + return ws + + def test_trusted_internal_passes_semi_trusted_boundary(self): + ws = self._make_workspace(trust_boundary="semi_trusted") + + connector = MagicMock() + type(connector).TRUST_LEVEL = "trusted_internal" + type(connector).__name__ = "SplunkClient" + + # Should not raise + ws.check_connector_trust(connector) + + def test_untrusted_external_fails_semi_trusted_boundary(self): + ws = self._make_workspace(trust_boundary="semi_trusted") + + connector = MagicMock() + type(connector).TRUST_LEVEL = "untrusted_external" + type(connector).__name__ = "AlienVaultClient" + + with pytest.raises(PermissionError, match="does not meet workspace"): + ws.check_connector_trust(connector) + + def test_untrusted_external_passes_untrusted_boundary(self): + ws = self._make_workspace(trust_boundary="untrusted_external") + + connector = MagicMock() + type(connector).TRUST_LEVEL = "untrusted_external" + type(connector).__name__ = "AlienVaultClient" + + # Should not raise + ws.check_connector_trust(connector) + + def test_allowlist_enforcement(self): + ws = self._make_workspace( + trust_boundary="semi_trusted", + allowed_refs=["SplunkClient", "SentinelClient"], + ) + + # Allowed connector + allowed = MagicMock() + type(allowed).TRUST_LEVEL = "trusted_internal" + type(allowed).__name__ = "SplunkClient" + ws.check_connector_trust(allowed) # should not raise + + # Disallowed connector + disallowed = MagicMock() + type(disallowed).TRUST_LEVEL = "trusted_internal" + type(disallowed).__name__ = "QRadarClient" + with pytest.raises(PermissionError, match="not in the allowed connector"): + ws.check_connector_trust(disallowed) diff --git a/tests/unit/test_phase4_governor.py b/tests/unit/test_phase4_governor.py new file mode 100644 index 00000000..78f93a46 --- /dev/null +++ b/tests/unit/test_phase4_governor.py @@ -0,0 +1,409 @@ +""" +tests/unit/test_phase4_governor.py +==================================== +Unit tests for Phase 4D — AgentGovernor, HITLGateway, and policy models. +""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Policy model tests (AgentActionType) +# --------------------------------------------------------------------------- + +class TestAgentActionType: + def test_all_action_types_defined(self): + from gnat.policy.models import AgentActionType + + expected = { + "read_stix", "write_stix", "delete_stix", "enrich", + "ingest", "export", "trigger_playbook", "manage_workspace", + "escalate", "hypothesize", + } + actual = {a.value for a in AgentActionType} + assert expected == actual + + def test_trusted_internal_has_all_actions(self): + from gnat.policy.models import AgentActionType, agent_can_act + + for action in AgentActionType: + assert agent_can_act("trusted_internal", action) is True + + def test_untrusted_external_limited(self): + from gnat.policy.models import AgentActionType, agent_can_act + + # Untrusted cannot trigger playbooks + assert agent_can_act("untrusted_external", AgentActionType.TRIGGER_PLAYBOOK) is False + assert agent_can_act("untrusted_external", AgentActionType.EXPORT) is False + + # But can read and hypothesize + assert agent_can_act("untrusted_external", AgentActionType.READ_STIX) is True + assert agent_can_act("untrusted_external", AgentActionType.HYPOTHESIZE) is True + + def test_semi_trusted_can_enrich(self): + from gnat.policy.models import AgentActionType, agent_can_act + + assert agent_can_act("semi_trusted", AgentActionType.ENRICH) is True + assert agent_can_act("semi_trusted", AgentActionType.TRIGGER_PLAYBOOK) is False + + def test_unknown_trust_level_denied(self): + from gnat.policy.models import AgentActionType, agent_can_act + + assert agent_can_act("unknown_level", AgentActionType.ENRICH) is False + + +# --------------------------------------------------------------------------- +# AgentGovernor tests +# --------------------------------------------------------------------------- + +class TestAgentGovernor: + def _make_governor(self, **kwargs): + from gnat.agents.governor import AgentGovernor + return AgentGovernor(**kwargs) + + def test_can_act_trusted_internal(self): + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + assert gov.can_act("agent-1", AgentActionType.TRIGGER_PLAYBOOK, "trusted_internal") is True + + def test_can_act_untrusted_denied(self): + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + assert gov.can_act("agent-1", AgentActionType.TRIGGER_PLAYBOOK, "untrusted_external") is False + + def test_require_can_act_raises(self): + from gnat.agents.governor import AgentPermissionDenied + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + with pytest.raises(AgentPermissionDenied): + gov.require_can_act("agent-1", AgentActionType.TRIGGER_PLAYBOOK, "untrusted_external") + + def test_policy_override_allow(self): + from gnat.policy.models import AgentActionType + + gov = self._make_governor( + policy_overrides={"agent-special": {"trigger_playbook": True}} + ) + # Override allows untrusted external to trigger playbooks + assert gov.can_act("agent-special", AgentActionType.TRIGGER_PLAYBOOK, "untrusted_external") is True + + def test_policy_override_deny(self): + from gnat.policy.models import AgentActionType + + gov = self._make_governor( + policy_overrides={"agent-restricted": {"enrich": False}} + ) + # Override denies trusted_internal from enriching + assert gov.can_act("agent-restricted", AgentActionType.ENRICH, "trusted_internal") is False + + def test_set_policy_override_runtime(self): + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + gov.set_policy_override("agent-X", AgentActionType.EXPORT, True) + assert gov.can_act("agent-X", AgentActionType.EXPORT, "untrusted_external") is True + + def test_record_action(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.ENRICH, + target_ref="indicator--abc", + impact_level="low", + ) + gov.record_action(action) + log = gov.get_action_log() + assert len(log) == 1 + assert log[0].agent_id == "agent-1" + + def test_get_action_log_filtered(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gov = self._make_governor() + a1 = AgentAction(agent_id="agent-A", action_type=AgentActionType.ENRICH) + a2 = AgentAction(agent_id="agent-B", action_type=AgentActionType.READ_STIX) + gov.record_action(a1) + gov.record_action(a2) + + assert len(gov.get_action_log("agent-A")) == 1 + assert len(gov.get_action_log("agent-B")) == 1 + assert len(gov.get_action_log()) == 2 + + def test_rate_limit_check_passes(self): + gov = self._make_governor(max_calls_per_window=5, window_seconds=60) + for _ in range(5): + gov.rate_limit_check("agent-1") + # 5th call should succeed + + def test_rate_limit_check_raises(self): + from gnat.agents.governor import RateLimitExceeded + + gov = self._make_governor(max_calls_per_window=3, window_seconds=60) + gov.rate_limit_check("agent-1") + gov.rate_limit_check("agent-1") + gov.rate_limit_check("agent-1") + with pytest.raises(RateLimitExceeded) as exc_info: + gov.rate_limit_check("agent-1") + assert exc_info.value.agent_id == "agent-1" + assert exc_info.value.window_seconds == 60 + + def test_rate_limit_window_expires(self): + import time + from gnat.agents.governor import AgentGovernor + + gov = AgentGovernor(max_calls_per_window=2, window_seconds=1) + gov.rate_limit_check("agent-1") + gov.rate_limit_check("agent-1") + # Wait for window to expire + time.sleep(1.1) + # Should be allowed again + gov.rate_limit_check("agent-1") + + +# --------------------------------------------------------------------------- +# AgentAction tests +# --------------------------------------------------------------------------- + +class TestAgentAction: + def test_create(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.ENRICH, + target_ref="indicator--abc", + impact_level="high", + ) + assert action.agent_id == "agent-1" + assert action.impact_level == "high" + assert action.status == "pending" + assert action.action_id # auto-generated + + def test_invalid_impact_level(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + with pytest.raises(ValueError, match="impact_level must be one of"): + AgentAction( + agent_id="agent-1", + action_type=AgentActionType.ENRICH, + impact_level="extreme", + ) + + def test_to_dict(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + action = AgentAction( + agent_id="test", + action_type=AgentActionType.WRITE_STIX, + target_ref="indicator--xyz", + impact_level="medium", + ) + d = action.to_dict() + assert d["agent_id"] == "test" + assert d["action_type"] == "write_stix" + assert d["impact_level"] == "medium" + assert "action_id" in d + assert "submitted_at" in d + + +# --------------------------------------------------------------------------- +# HITLGateway tests +# --------------------------------------------------------------------------- + +class TestHITLGateway: + def _make_gateway(self, auto_approve=True): + from gnat.agents.hitl import HITLGateway + + review_service = MagicMock() + mock_item = MagicMock() + mock_item.id = "review-item-123" + from datetime import datetime, timezone + mock_item.submitted_at = datetime.now(timezone.utc) + + from gnat.review.models import ReviewStatus + mock_item.status = ReviewStatus.PENDING + review_service.submit.return_value = mock_item + review_service.get.return_value = mock_item + review_service.approve.return_value = mock_item + + gateway = HITLGateway(review_service=review_service, approval_timeout_seconds=3600) + return gateway, review_service, mock_item + + def test_low_impact_auto_approved(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gateway, _, _ = self._make_gateway() + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.READ_STIX, + impact_level="low", + ) + approved, review_item = gateway.evaluate(action) + assert approved is True + assert review_item is None + assert action.approved_by == "auto-policy" + assert action.status == "approved" + + def test_medium_impact_auto_approved(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gateway, _, _ = self._make_gateway() + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.ENRICH, + impact_level="medium", + ) + approved, review_item = gateway.evaluate(action) + assert approved is True + + def test_high_impact_creates_review_item(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + + gateway, review_service, mock_item = self._make_gateway() + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.TRIGGER_PLAYBOOK, + impact_level="high", + ) + approved, review_item = gateway.evaluate(action) + assert approved is False + assert review_item is not None + assert review_item.id == "review-item-123" + review_service.submit.assert_called_once() + assert action.status == "pending" + + def test_critical_impact_notifies_xsoar(self): + from gnat.agents.governor import AgentAction + from gnat.policy.models import AgentActionType + from gnat.agents.hitl import HITLGateway + + review_service = MagicMock() + mock_item = MagicMock() + mock_item.id = "review-crit-999" + from datetime import datetime, timezone + mock_item.submitted_at = datetime.now(timezone.utc) + from gnat.review.models import ReviewStatus + mock_item.status = ReviewStatus.PENDING + review_service.submit.return_value = mock_item + + xsoar = MagicMock() + gateway = HITLGateway( + review_service=review_service, + xsoar_client=xsoar, + ) + + action = AgentAction( + agent_id="agent-1", + action_type=AgentActionType.TRIGGER_PLAYBOOK, + impact_level="critical", + ) + approved, review_item = gateway.evaluate(action) + assert approved is False + xsoar.upsert_object.assert_called_once() + + def test_check_approval_status_timeout(self): + from gnat.agents.hitl import HITLGateway + from gnat.review.models import ReviewStatus + from datetime import datetime, timezone, timedelta + + review_service = MagicMock() + mock_item = MagicMock() + mock_item.id = "review-timeout" + mock_item.submitted_at = datetime.now(timezone.utc) - timedelta(hours=2) + mock_item.status = ReviewStatus.PENDING + review_service.get.return_value = mock_item + + rejected_item = MagicMock() + rejected_item.status = ReviewStatus.REJECTED + + # After reject, get returns the rejected item + call_count = [0] + def get_side_effect(item_id): + call_count[0] += 1 + if call_count[0] == 1: + return mock_item + return rejected_item + review_service.get.side_effect = get_side_effect + + gateway = HITLGateway(review_service=review_service, approval_timeout_seconds=60) + status = gateway.check_approval_status("review-timeout") + review_service.reject.assert_called_once() + assert status == ReviewStatus.REJECTED + + def test_auto_approve_pending(self): + from gnat.agents.hitl import HITLGateway + from gnat.review.models import ReviewStatus + + review_service = MagicMock() + mock_item = MagicMock() + mock_item.status = ReviewStatus.APPROVED + review_service.approve.return_value = mock_item + + gateway = HITLGateway(review_service=review_service) + gateway.auto_approve_pending("review-123", reviewer="system-test") + review_service.approve.assert_called_once_with( + "review-123", reviewed_by="system-test" + ) + + +# --------------------------------------------------------------------------- +# AgentTestHarness tests +# --------------------------------------------------------------------------- + +class TestAgentTestHarness: + def test_run_action_low_impact(self): + from gnat.testing import AgentTestHarness + from gnat.policy.models import AgentActionType + + harness = AgentTestHarness() + approved, action = harness.run_action( + agent_id="test-agent", + action_type=AgentActionType.ENRICH, + impact_level="low", + trust_level="semi_trusted", + ) + assert approved is True + assert action.status == "approved" + assert len(harness.recorded_actions) == 1 + + def test_run_action_denied(self): + from gnat.testing import AgentTestHarness + from gnat.agents.governor import AgentPermissionDenied + from gnat.policy.models import AgentActionType + + harness = AgentTestHarness() + with pytest.raises(AgentPermissionDenied): + harness.run_action( + agent_id="test-agent", + action_type=AgentActionType.TRIGGER_PLAYBOOK, + trust_level="untrusted_external", + ) + + def test_multiple_actions_recorded(self): + from gnat.testing import AgentTestHarness + from gnat.policy.models import AgentActionType + + harness = AgentTestHarness() + for _ in range(5): + harness.run_action( + agent_id="bulk-agent", + action_type=AgentActionType.READ_STIX, + trust_level="semi_trusted", + ) + assert len(harness.recorded_actions) == 5 diff --git a/tests/unit/test_phase4_reasoning.py b/tests/unit/test_phase4_reasoning.py new file mode 100644 index 00000000..744ab98a --- /dev/null +++ b/tests/unit/test_phase4_reasoning.py @@ -0,0 +1,458 @@ +""" +tests/unit/test_phase4_reasoning.py +===================================== +Unit tests for Phase 4C — Hypothesis Engine, Negative Evidence, Reasoning Engine. +""" + +from __future__ import annotations + +import time +from datetime import datetime, timezone, timedelta +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# STIXHypothesis tests +# --------------------------------------------------------------------------- + +class TestSTIXHypothesis: + def test_create_defaults(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="APT29 behind Q1 campaign", confidence=0.4) + assert h._properties["statement"] == "APT29 behind Q1 campaign" + assert h._properties["confidence"] == 0.4 + assert h._properties["status"] == "pending" + assert h._properties["supporting_evidence"] == [] + assert h._properties["refuting_evidence"] == [] + assert h.id.startswith("x-gnat-hypothesis--") + + def test_invalid_status_raises(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + with pytest.raises(ValueError, match="Invalid hypothesis status"): + STIXHypothesis(statement="x", confidence=0.5, status="bad-status") + + def test_invalid_confidence_raises(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + with pytest.raises(ValueError, match="confidence must be in"): + STIXHypothesis(statement="x", confidence=1.5) + + def test_add_supporting_evidence(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.3) + h.add_supporting_evidence("relationship--abc") + assert "relationship--abc" in h._properties["supporting_evidence"] + # Duplicate not added + h.add_supporting_evidence("relationship--abc") + assert h._properties["supporting_evidence"].count("relationship--abc") == 1 + + def test_add_refuting_evidence(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.3) + h.add_refuting_evidence("relationship--xyz") + assert "relationship--xyz" in h._properties["refuting_evidence"] + + def test_update_confidence(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.3) + h.update_confidence(0.8) + assert h._properties["confidence"] == 0.8 + + def test_update_confidence_invalid(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.3) + with pytest.raises(ValueError, match="confidence must be in"): + h.update_confidence(1.1) + + def test_close_confirmed(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.9) + h.close("confirmed") + assert h._properties["status"] == "confirmed" + + def test_close_invalid_verdict(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="test", confidence=0.5) + with pytest.raises(ValueError, match="verdict must be one of"): + h.close("maybe") + + def test_to_dict_round_trip(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + h = STIXHypothesis(statement="round trip test", confidence=0.6, status="pending") + h.add_supporting_evidence("rel--1") + d = h.to_dict() + + assert d["type"] == "x-gnat-hypothesis" + assert d["statement"] == "round trip test" + assert d["confidence"] == 0.6 + assert "rel--1" in d["supporting_evidence"] + + h2 = STIXHypothesis.from_dict(d) + assert h2._properties["statement"] == "round trip test" + assert h2._properties["confidence"] == 0.6 + assert "rel--1" in h2._properties["supporting_evidence"] + + +# --------------------------------------------------------------------------- +# NegativeEvidenceRecord tests +# --------------------------------------------------------------------------- + +class TestNegativeEvidenceRecord: + def test_create(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + rec = NegativeEvidenceRecord( + target_ref="indicator--abc", + queried_connector="VirusTotalClient", + ttl_seconds=3600, + ) + assert rec._properties["target_ref"] == "indicator--abc" + assert rec._properties["queried_connector"] == "VirusTotalClient" + assert rec._properties["ttl_seconds"] == 3600 + assert rec.id.startswith("x-gnat-negative-evidence--") + + def test_not_expired_immediately(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + rec = NegativeEvidenceRecord( + target_ref="indicator--abc", + queried_connector="TestClient", + ttl_seconds=3600, + ) + assert rec.is_expired() is False + + def test_expired_with_past_timestamp(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + past = (datetime.now(timezone.utc) - timedelta(hours=2)).isoformat() + rec = NegativeEvidenceRecord( + target_ref="indicator--abc", + queried_connector="TestClient", + ttl_seconds=3600, + query_timestamp=past, + ) + assert rec.is_expired() is True + + def test_seconds_remaining_fresh(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + rec = NegativeEvidenceRecord( + target_ref="indicator--abc", + queried_connector="TestClient", + ttl_seconds=3600, + ) + remaining = rec.seconds_remaining() + assert remaining > 3590 # Should be close to 3600 + + def test_seconds_remaining_expired(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + past = (datetime.now(timezone.utc) - timedelta(hours=2)).isoformat() + rec = NegativeEvidenceRecord( + target_ref="indicator--abc", + queried_connector="TestClient", + ttl_seconds=3600, + query_timestamp=past, + ) + assert rec.seconds_remaining() == 0.0 + + def test_round_trip(self): + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + + rec = NegativeEvidenceRecord( + target_ref="indicator--xyz", + queried_connector="CrowdStrikeClient", + ttl_seconds=7200, + ) + d = rec.to_dict() + rec2 = NegativeEvidenceRecord.from_dict(d) + assert rec2._properties["target_ref"] == "indicator--xyz" + assert rec2._properties["ttl_seconds"] == 7200 + + +# --------------------------------------------------------------------------- +# HypothesisEngine tests +# --------------------------------------------------------------------------- + +class TestHypothesisEngine: + def _make_engine(self, fixtures=None): + from gnat.reasoning.hypothesis import HypothesisEngine + from gnat.search.index import NullSearchIndex + + manager = MagicMock() + ws = MagicMock() + ws.objects = {} + manager.open.return_value = ws + + engine = HypothesisEngine( + manager=manager, + workspace_name="test-ws", + search_index=NullSearchIndex(), + ) + return engine, ws + + def test_propose_creates_hypothesis(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + engine, ws = self._make_engine() + h = engine.propose("APT29 is behind Q1 campaign", confidence=0.3) + assert isinstance(h, STIXHypothesis) + assert h._properties["statement"] == "APT29 is behind Q1 campaign" + assert h._properties["confidence"] == 0.3 + assert h._properties["status"] == "pending" + ws._add_object.assert_called_once() + + def test_propose_with_evidence(self): + engine, ws = self._make_engine() + h = engine.propose( + "Lazarus Group C2", + initial_evidence=["rel--1", "rel--2"], + confidence=0.5, + ) + assert "rel--1" in h._properties["supporting_evidence"] + assert "rel--2" in h._properties["supporting_evidence"] + + def test_evaluate_no_evidence_unchanged(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + engine, ws = self._make_engine() + h_obj = STIXHypothesis(statement="test", confidence=0.2) + ws.objects = {h_obj.id: MagicMock(to_dict=lambda: h_obj.to_dict())} + + h_result = engine.evaluate(h_obj.id) + # No evidence → stays at initial confidence + assert h_result._properties["confidence"] == 0.2 + + def test_evaluate_missing_hypothesis_raises(self): + engine, ws = self._make_engine() + ws.objects = {} + + with pytest.raises(KeyError, match="No hypothesis found"): + engine.evaluate("x-gnat-hypothesis--nonexistent") + + def test_evaluate_high_support_confirms(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + engine, ws = self._make_engine() + h_obj = STIXHypothesis( + statement="confirmed", + confidence=0.5, + supporting_evidence=["r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10"], + ) + ws.objects = {h_obj.id: MagicMock(to_dict=lambda: h_obj.to_dict())} + + h_result = engine.evaluate(h_obj.id) + # High support count → status should be confirmed (confidence ≥ 0.75) + assert h_result._properties["confidence"] >= 0.75 + assert h_result._properties["status"] == "confirmed" + + def test_close_sets_verdict(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + engine, ws = self._make_engine() + h_obj = STIXHypothesis(statement="test", confidence=0.4) + ws.objects = {h_obj.id: MagicMock(to_dict=lambda: h_obj.to_dict())} + + h_result = engine.close(h_obj.id, verdict="refuted") + assert h_result._properties["status"] == "refuted" + + def test_get_returns_none_for_missing(self): + engine, ws = self._make_engine() + ws.objects = {} + assert engine.get("missing--id") is None + + def test_list_all(self): + from gnat.stix.sdos.hypothesis import STIXHypothesis + + engine, ws = self._make_engine() + h1 = STIXHypothesis(statement="h1", confidence=0.3) + h2 = STIXHypothesis(statement="h2", confidence=0.5) + + class FakeObj: + stix_type = "x-gnat-hypothesis" + def to_dict(self): return h1.to_dict() + + ws.objects = { + h1.id: FakeObj(), + h2.id: FakeObj(), + } + result = engine.list_all() + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# ReasoningEngine tests +# --------------------------------------------------------------------------- + +class TestReasoningEngine: + def _make_observable(self, obj_id="indicator--abc", modified_days_ago=1): + obs = MagicMock() + obs.id = obj_id + obs.stix_type = "indicator" + modified = (datetime.now(timezone.utc) - timedelta(days=modified_days_ago)).isoformat() + obs.modified = modified + return obs + + def _make_engine(self, search_hits=0): + from gnat.reasoning.engine import ReasoningEngine + from gnat.search.index import NullSearchIndex + + manager = MagicMock() + ws = MagicMock() + ws.objects = {} + manager.open.return_value = ws + + search_index = MagicMock() + search_index.search.return_value = list(range(search_hits)) + + engine = ReasoningEngine( + manager=manager, + workspace_name="test-ws", + search_index=search_index, + ) + return engine, ws + + def test_prioritize_returns_sorted(self): + engine, ws = self._make_engine(search_hits=3) + + obs1 = self._make_observable("indicator--1", modified_days_ago=1) + obs2 = self._make_observable("indicator--2", modified_days_ago=100) + + from gnat.core.context import ExecutionContext + ctx = ExecutionContext.create( + initiated_by="test", + domain="analysis", + workspace_id="ws1", + trust_level="trusted_internal", + ) + results = engine.prioritize([obs1, obs2], context=ctx, store_notes=False) + + assert len(results) == 2 + # Sorted descending by score + assert results[0][1] >= results[1][1] + + def test_score_in_range(self): + engine, ws = self._make_engine(search_hits=0) + + obs = self._make_observable("indicator--x", modified_days_ago=5) + results = engine.prioritize([obs], store_notes=False) + + score = results[0][1] + assert 0.0 <= score <= 1.0 + + def test_explanation_structure(self): + engine, ws = self._make_engine(search_hits=2) + + obs = self._make_observable("indicator--y", modified_days_ago=2) + results = engine.prioritize([obs], store_notes=False) + + _, score, explanation = results[0] + assert "components" in explanation + assert "trust_weight" in explanation["components"] + assert "age_factor" in explanation["components"] + assert "negative_evidence" in explanation["components"] + assert "corroboration" in explanation["components"] + assert "score" in explanation + assert "summary" in explanation + + def test_negative_evidence_reduces_score(self): + from gnat.reasoning.engine import ReasoningEngine + from gnat.stix.sdos.negative_evidence import NegativeEvidenceRecord + from gnat.search.index import NullSearchIndex + + manager = MagicMock() + ws = MagicMock() + + obs_id = "indicator--neg-test" + + # Build negative evidence record pointing at our observable + neg_rec = NegativeEvidenceRecord( + target_ref=obs_id, + queried_connector="VirusTotal", + ttl_seconds=3600, + ) + neg_dict = neg_rec.to_dict() + + # Workspace returns the negative evidence on iteration + fake_obj = MagicMock() + fake_obj.to_dict.return_value = neg_dict + ws.objects = {neg_rec.id: fake_obj} + manager.open.return_value = ws + + engine = ReasoningEngine( + manager=manager, + workspace_name="test-ws", + search_index=NullSearchIndex(), + ) + + obs = self._make_observable(obs_id, modified_days_ago=1) + results_with_neg = engine.prioritize([obs], store_notes=False) + neg_score = results_with_neg[0][1] + + # Score with neg evidence should be lower than without + ws.objects = {} + results_without = engine.prioritize([obs], store_notes=False) + clean_score = results_without[0][1] + + assert neg_score < clean_score + + def test_age_factor_decay(self): + from gnat.reasoning.engine import ReasoningEngine + + # Fresh object (today) + obs_fresh = MagicMock() + obs_fresh.modified = datetime.now(timezone.utc).isoformat() + fresh_factor = ReasoningEngine._age_factor(obs_fresh) + assert fresh_factor >= 0.95 # barely decayed + + # Old object (20 days ago = 1.0 - 0.05*20 = 0.0) + obs_old = MagicMock() + obs_old.modified = (datetime.now(timezone.utc) - timedelta(days=20)).isoformat() + old_factor = ReasoningEngine._age_factor(obs_old) + assert old_factor == 0.0 + + def test_age_factor_no_modified(self): + from gnat.reasoning.engine import ReasoningEngine + + obs = MagicMock() + obs.modified = "" + assert ReasoningEngine._age_factor(obs) == 0.5 + + def test_trusted_internal_scores_higher(self): + from gnat.reasoning.engine import ReasoningEngine + from gnat.core.context import ExecutionContext + from gnat.search.index import NullSearchIndex + + manager = MagicMock() + ws = MagicMock() + ws.objects = {} + manager.open.return_value = ws + + engine = ReasoningEngine(manager=manager, search_index=NullSearchIndex()) + + obs = self._make_observable("indicator--trust-test", modified_days_ago=1) + + ctx_trusted = ExecutionContext.create( + initiated_by="splunk", domain="analysis", workspace_id="ws1", + trust_level="trusted_internal", + ) + ctx_untrusted = ExecutionContext.create( + initiated_by="otx", domain="analysis", workspace_id="ws1", + trust_level="untrusted_external", + ) + + res_trusted = engine.prioritize([obs], context=ctx_trusted, store_notes=False) + res_untrusted = engine.prioritize([obs], context=ctx_untrusted, store_notes=False) + + assert res_trusted[0][1] > res_untrusted[0][1]