Skip to content

Commit 4f58d66

Browse files
authored
Merge branch 'master' into feat/ability-yaml-upload
2 parents 0825099 + 17615ac commit 4f58d66

13 files changed

Lines changed: 182 additions & 23 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ conf/*.yml
2020
!conf/default.yml
2121
data/object_store
2222
data/fact_store
23+
data/cookie_storage
2324
data/results/*
2425
!data/results/.gitkeep
2526
data/payloads/*

app/api/v2/handlers/payload_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pathlib
55
import re
66
from io import IOBase
7+
from typing import Optional
78

89
import aiohttp_apispec
910
from aiohttp import web
@@ -27,14 +28,16 @@ def add_routes(self, app: web.Application):
2728

2829
@aiohttp_apispec.docs(tags=['payloads'],
2930
summary='Retrieve payloads',
30-
description='Retrieves all stored payloads.')
31+
description='Retrieves all stored payloads. Supports optional filtering by name '
32+
'(case-insensitive substring match via the `name` query parameter).')
3133
@aiohttp_apispec.querystring_schema(PayloadQuerySchema)
3234
@aiohttp_apispec.response_schema(PayloadSchema(),
3335
description='Returns a list of all payloads in PayloadSchema format.')
3436
async def get_payloads(self, request: web.Request):
3537
sort: bool = request['querystring'].get('sort')
3638
exclude_plugins: bool = request['querystring'].get('exclude_plugins')
3739
add_path: bool = request['querystring'].get('add_path')
40+
name_filter: Optional[str] = request['querystring'].get('name')
3841

3942
cwd = pathlib.Path.cwd()
4043
payload_dirs = [cwd / 'data' / 'payloads']
@@ -52,6 +55,11 @@ async def get_payloads(self, request: web.Request):
5255
}
5356

5457
payloads = list(payloads)
58+
59+
if name_filter:
60+
name_filter_lower = name_filter.lower()
61+
payloads = [p for p in payloads if name_filter_lower in pathlib.PurePath(p).name.lower()]
62+
5563
if sort:
5664
payloads.sort()
5765

app/api/v2/managers/base_api_manager.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
22
import os
3+
import re
34
import uuid
45
import yaml
56

67
from marshmallow.schema import SchemaMeta
78
from typing import Any, List
89
from base64 import b64encode, b64decode
910

11+
from app.api.v2.errors import DataValidationError
1012
from app.utility.base_world import BaseWorld
1113

1214

@@ -64,6 +66,7 @@ def create_object_from_schema(self, schema: SchemaMeta, data: dict, access: Base
6466

6567
async def create_on_disk_object(self, data: dict, access: dict, ram_key: str, id_property: str, obj_class: type):
6668
obj_id = data.get(id_property) or str(uuid.uuid4())
69+
obj_id = self._sanitize_id(obj_id)
6770
data[id_property] = obj_id
6871

6972
file_path = await self._get_new_object_file_path(data[id_property], ram_key)
@@ -121,18 +124,34 @@ async def remove_object_from_memory_by_id(self, identifier: str, ram_key: str, i
121124
await self._data_svc.remove(ram_key, {id_property: identifier})
122125

123126
async def remove_object_from_disk_by_id(self, identifier: str, ram_key: str):
127+
identifier = self._sanitize_id(identifier)
124128
file_path = await self._get_existing_object_file_path(identifier, ram_key)
125129

126130
if os.path.exists(file_path):
127131
os.remove(file_path)
128132

133+
@staticmethod
134+
def _sanitize_id(obj_id) -> str:
135+
'''Removes any non-alphanumeric characters and non-hyphen/underscore.'''
136+
if not isinstance(obj_id, str):
137+
raise DataValidationError(message=f'Invalid id type: expected str, got {type(obj_id).__name__}', name='id', value=obj_id)
138+
original_id = obj_id
139+
obj_id = re.sub(r'[^a-zA-Z0-9_-]', '', obj_id)
140+
if not obj_id:
141+
raise DataValidationError(message=f"Invalid id: {obj_id!r}", name='id', value=obj_id)
142+
if original_id != obj_id:
143+
logging.getLogger(DEFAULT_LOGGER_NAME).warning(f"Sanitized ID: {obj_id}")
144+
return obj_id
145+
129146
@staticmethod
130147
async def _get_new_object_file_path(identifier: str, ram_key: str) -> str:
131148
"""Create file path for new object"""
149+
identifier = BaseApiManager._sanitize_id(identifier)
132150
return os.path.join('data', ram_key, f'{identifier}.yml')
133151

134152
async def _get_existing_object_file_path(self, identifier: str, ram_key: str) -> str:
135153
"""Find file path for existing object (by id)"""
154+
identifier = self._sanitize_id(identifier)
136155
_, file_path = await self._file_svc.find_file_path(f'{identifier}.yml', location=ram_key)
137156
if not file_path:
138157
file_path = await self._get_new_object_file_path(identifier, ram_key)

app/api/v2/schemas/payload_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class PayloadQuerySchema(schema.Schema):
55
sort = fields.Boolean(required=False, load_default=False)
66
exclude_plugins = fields.Boolean(required=False, load_default=False)
77
add_path = fields.Boolean(required=False, load_default=False)
8+
name = fields.String(required=False, load_default=None, allow_none=True)
89

910

1011
class PayloadSchema(schema.Schema):

app/service/auth_svc.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import base64
21
from collections import namedtuple
32
from importlib import import_module
3+
import os
44

55
from aiohttp import web, web_request
66
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPForbidden
@@ -10,7 +10,6 @@
1010
from aiohttp_security.abc import AbstractAuthorizationPolicy
1111
from aiohttp_session import setup as setup_session
1212
from aiohttp_session.cookie_storage import EncryptedCookieStorage
13-
from cryptography import fernet
1413

1514
from app.service.interfaces.i_auth_svc import AuthServiceInterface
1615
from app.service.interfaces.i_login_handler import LoginHandlerInterface
@@ -73,9 +72,35 @@ async def apply(self, app, users):
7372
for username, password in user.items():
7473
await self.create_user(username, password, group)
7574
app.user_map = self.user_map
76-
fernet_key = fernet.Fernet.generate_key()
77-
secret_key = base64.urlsafe_b64decode(fernet_key)
78-
storage = EncryptedCookieStorage(secret_key, cookie_name=COOKIE_SESSION)
75+
cookie_file = 'cookie_storage'
76+
expiration_days = self.get_config('session_expiration_days')
77+
file_svc = self.get_service('file_svc')
78+
cookie_path = os.path.join('data', cookie_file)
79+
80+
# Safely calculate max_age in seconds, allowing for fractional days
81+
try:
82+
max_age = int(float(expiration_days) * 86400) if expiration_days else None
83+
except (ValueError, TypeError):
84+
max_age = None
85+
try:
86+
if os.path.exists(cookie_path):
87+
secret_key = file_svc._read(cookie_path)
88+
self.log.debug('Loaded persistent session key from data/cookie_storage')
89+
else:
90+
# Generate a new random 32-byte key for AES encryption if no valid key is found in the config or data folder
91+
secret_key = os.urandom(32)
92+
file_svc._save(cookie_path, secret_key, encrypt=True)
93+
self.log.debug('Generated and saved new persistent session key.')
94+
except Exception as e:
95+
# Fallback if file operations fail
96+
self.log.warning('Could not manage persistent key file, falling back to ephemeral: %s', e)
97+
secret_key = os.urandom(32)
98+
if len(secret_key) != 32:
99+
secret_key = os.urandom(32)
100+
self.log.warning('Loaded session key is not 32 bytes long. Generating new key.')
101+
102+
# Pass max_age to the storage initializer
103+
storage = EncryptedCookieStorage(secret_key, cookie_name=COOKIE_SESSION, max_age=max_age)
79104
setup_session(app, storage)
80105
policy = SessionIdentityPolicy()
81106
setup_security(app, policy, DictionaryAuthorizationPolicy(self.user_map))

app/service/data_svc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
'data/results/*',
3636
'data/sources/*',
3737
'data/object_store',
38+
'data/cookie_storage',
3839
)
3940

4041
PAYLOADS_CONFIG_STANDARD_KEY = 'standard_payloads'

conf/default.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ app.contact.websocket: 0.0.0.0:7012
2727
auth.login.handler.module: default
2828
crypt_salt: REPLACE_WITH_RANDOM_VALUE
2929
encryption_key: ADMIN123
30+
session_expiration_days: 7
3031
exfil_dir: /tmp/caldera
3132
host: 0.0.0.0
3233
objects.planners.default: atomic

package-lock.json

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
aiohttp-jinja2==1.5.1
2-
aiohttp==3.13.3
2+
aiohttp==3.13.4
33
aiohttp_session==2.12.0
44
aiohttp-security==0.4.0
55
aiohttp-apispec==3.0.0b2
66
argon2-cffi==25.1.0
77
jinja2==3.1.6
88
pyyaml==6.0.1
9-
cryptography==46.0.5
9+
cryptography==46.0.7
1010
websockets==15.0
1111
Sphinx==7.1.2
1212
sphinx_rtd_theme==1.3.0
@@ -15,15 +15,15 @@ marshmallow==3.26.2
1515
dirhash==0.2.1
1616
marshmallow-enum==1.5.1
1717
ldap3==2.9.1
18-
pyasn1~=0.5.1
18+
pyasn1==0.6.3
1919
reportlab==4.0.4 # debrief
2020
rich==13.7.0
2121
lxml==6.0.2 # debrief
2222
svglib==1.5.1 # debrief
2323
Markdown==3.8.1 # training
2424
dnspython==2.6.1
2525
asyncssh==2.20.0
26-
aioftp~=0.20.0
26+
aioftp==0.27.2
2727
packaging==23.2
2828
croniter~=3.0.3
2929
setuptools==78.1.1

tests/api/v2/handlers/test_payloads_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pathlib
23
import tempfile
34
from http import HTTPStatus
45

@@ -49,6 +50,35 @@ async def test_get_payloads(self, api_v2_client, api_cookies, expected_payload_f
4950

5051
assert filtered_payload_file_names == expected_payload_file_names
5152

53+
@pytest.mark.parametrize('query_name', ['payload_', 'PAYLOAD_'])
54+
async def test_get_payloads_name_filter(self, api_v2_client, api_cookies, expected_payload_file_names, query_name):
55+
resp = await api_v2_client.get(f'/api/v2/payloads?name={query_name}', cookies=api_cookies)
56+
assert resp.status == HTTPStatus.OK
57+
payload_file_names = await resp.json()
58+
59+
# All expected payloads should be present
60+
assert expected_payload_file_names <= set(payload_file_names)
61+
# Every returned payload must match the filter (no false positives)
62+
assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_file_names)
63+
64+
async def test_get_payloads_name_filter_no_match(self, api_v2_client, api_cookies):
65+
resp = await api_v2_client.get('/api/v2/payloads?name=__no_match_xyzzy__', cookies=api_cookies)
66+
assert resp.status == HTTPStatus.OK
67+
assert await resp.json() == []
68+
69+
async def test_get_payloads_name_filter_with_sort_and_add_path(
70+
self, api_v2_client, api_cookies, expected_payload_file_names):
71+
resp = await api_v2_client.get('/api/v2/payloads?name=payload_&sort=true&add_path=true', cookies=api_cookies)
72+
assert resp.status == HTTPStatus.OK
73+
payload_paths = await resp.json()
74+
75+
# Results should be sorted
76+
assert payload_paths == sorted(payload_paths)
77+
# Every returned path's filename must match the filter
78+
assert all('payload_' in pathlib.PurePath(p).name.lower() for p in payload_paths)
79+
# Results should contain paths (not bare filenames)
80+
assert all(os.sep in p or '/' in p for p in payload_paths)
81+
5282
async def test_unauthorized_get_payloads(self, api_v2_client):
5383
resp = await api_v2_client.get('/api/v2/payloads')
5484
assert resp.status == HTTPStatus.UNAUTHORIZED

0 commit comments

Comments
 (0)