Skip to content

Commit bd028ba

Browse files
authored
Improve ability seeding (#1591)
1 parent 26ee8a5 commit bd028ba

3 files changed

Lines changed: 167 additions & 62 deletions

File tree

agixt/Agent.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import binascii
4141
from WebhookManager import WebhookEventEmitter
4242
from sqlalchemy.exc import IntegrityError
43+
from sqlalchemy.orm import joinedload
4344

4445
# Initialize webhook event emitter
4546
webhook_emitter = WebhookEventEmitter()
@@ -50,6 +51,61 @@
5051
)
5152

5253

54+
_command_owner_cache = None
55+
56+
57+
def _get_command_owner_cache():
58+
global _command_owner_cache
59+
if _command_owner_cache is not None:
60+
return _command_owner_cache
61+
62+
cache = {}
63+
try:
64+
extensions = Extensions().get_extensions()
65+
for extension_data in extensions:
66+
extension_name = extension_data.get("extension_name")
67+
for command_data in extension_data.get("commands", []):
68+
friendly_name = command_data.get("friendly_name")
69+
if not friendly_name:
70+
continue
71+
cache.setdefault(friendly_name.lower(), set()).add(extension_name)
72+
except Exception as e:
73+
logging.debug(f"Unable to build command owner cache: {e}")
74+
75+
_command_owner_cache = cache
76+
return _command_owner_cache
77+
78+
79+
def _resolve_command_by_name(session, command_name):
80+
if not command_name:
81+
return None
82+
83+
commands = (
84+
session.query(Command)
85+
.options(joinedload(Command.extension))
86+
.filter(Command.name == command_name)
87+
.all()
88+
)
89+
90+
if not commands:
91+
return None
92+
if len(commands) == 1:
93+
return commands[0]
94+
95+
owners = _get_command_owner_cache().get(command_name.lower(), set())
96+
if owners:
97+
for command in commands:
98+
extension_name = command.extension.name if command.extension else None
99+
if extension_name in owners:
100+
return command
101+
102+
logging.warning(
103+
"Multiple database entries found for command '%s'. Defaulting to first match.",
104+
command_name,
105+
)
106+
return commands[0]
107+
108+
53109
# Define the standalone wallet creation function
54110
def create_solana_wallet() -> Tuple[str, str, str]:
55111
"""
@@ -219,7 +275,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US
219275
# Handle any additional commands passed in the commands parameter
220276
if commands:
221277
for command_name, enabled in commands.items():
222-
command = session.query(Command).filter_by(name=command_name).first()
278+
command = _resolve_command_by_name(session, command_name)
223279
if command:
224280
# Check if agent command already exists (from auto-enabled extensions)
225281
existing_agent_command = (
@@ -1415,7 +1471,7 @@ def update_agent_config(self, new_config, config_key):
14151471
continue
14161472

14171473
# First try to find an existing command
1418-
command = session.query(Command).filter_by(name=command_name).first()
1474+
command = _resolve_command_by_name(session, command_name)
14191475

14201476
if not command:
14211477
# Check if this is a chain command

agixt/Interactions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ async def format_prompt(
327327
context.append(
328328
f"## Persona\n**The assistant follows a persona and uses the following guidelines and information to remain in character.**\n{persona}\nThe assistant is {self.agent_name} and is an AGiXT agent created by DevXT, empowered with AGiXT abilities."
329329
)
330+
APP_URI = getenv("APP_URI")
331+
if "localhost:" not in APP_URI:
332+
context.append(
333+
f"The assistant is an AGiXT agent named `{self.agent_name}` running on {APP_URI}. The assistant can access the documentation about the website at {AGIXT_URI}/docs as well as information about the open source AGiXT back end repository at https://github.com/Josh-XT/AGiXT if necessary."
334+
)
330335
if "72" in kwargs and "42" in kwargs:
331336
if kwargs["72"] == True and kwargs["42"] == True:
332337
kwargs["fp"] = context

agixt/SeedImports.py

Lines changed: 104 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import json
33
import logging
4+
from collections import defaultdict
45
from DB import (
56
get_session,
67
Provider,
@@ -16,6 +17,7 @@
1617
Agent,
1718
AgentCommand,
1819
Chain,
20+
ChainStep,
1921
)
2022
from Providers import get_providers, get_provider_options
2123
from Agent import add_agent
@@ -27,6 +29,65 @@
2729
)
2830

2931

32+
def _merge_command_duplicates(session, target_command, duplicates):
33+
"""Merge duplicate command rows into a target command instance."""
34+
35+
if not duplicates:
36+
return
37+
38+
for duplicate in duplicates:
39+
if duplicate.id == target_command.id:
40+
continue
41+
42+
# Move agent command associations
43+
agent_commands = (
44+
session.query(AgentCommand)
45+
.filter(AgentCommand.command_id == duplicate.id)
46+
.all()
47+
)
48+
for agent_command in agent_commands:
49+
existing_ref = (
50+
session.query(AgentCommand)
51+
.filter(
52+
AgentCommand.agent_id == agent_command.agent_id,
53+
AgentCommand.command_id == target_command.id,
54+
)
55+
.first()
56+
)
57+
58+
if existing_ref:
59+
if agent_command.state:
60+
existing_ref.state = True
61+
session.delete(agent_command)
62+
else:
63+
agent_command.command_id = target_command.id
64+
65+
# Move arguments linked to the duplicate command
66+
arguments = (
67+
session.query(Argument).filter(Argument.command_id == duplicate.id).all()
68+
)
69+
for argument in arguments:
70+
existing_arg = (
71+
session.query(Argument)
72+
.filter_by(command_id=target_command.id, name=argument.name)
73+
.first()
74+
)
75+
if existing_arg:
76+
session.delete(argument)
77+
else:
78+
argument.command_id = target_command.id
79+
80+
# Update chain steps that target the duplicate command
81+
session.query(ChainStep).filter(
82+
ChainStep.target_command_id == duplicate.id
83+
).update(
84+
{ChainStep.target_command_id: target_command.id},
85+
synchronize_session=False,
86+
)
87+
88+
session.delete(duplicate)
89+
90+
3091
def get_extension_category(session, extension_name):
3192
"""Get or create extension category based on extension class CATEGORY attribute"""
3293
from ExtensionsHub import (
@@ -134,6 +195,7 @@ def import_extensions():
134195

135196
ext = Extensions()
136197
extensions_data = ext.get_extensions()
198+
command_owner_map = defaultdict(set)
137199
# Delete "AGiXT Chains"
138200
if "AGiXT Chains" in extensions_data:
139201
del extensions_data["AGiXT Chains"]
@@ -142,6 +204,14 @@ def import_extensions():
142204
# del extensions_data["Custom Automation"]
143205
extension_settings_data = Extensions().get_extension_settings()
144206

207+
for extension_data in extensions_data:
208+
extension_name = extension_data["extension_name"]
209+
for command in extension_data.get("commands", []):
210+
friendly_name = command.get("friendly_name", "").strip()
211+
if not friendly_name:
212+
continue
213+
command_owner_map[friendly_name.lower()].add(extension_name)
214+
145215
# Create extension database tables during seed import
146216
create_extension_tables()
147217

@@ -306,86 +376,60 @@ def import_extensions():
306376
# Process commands for this extension
307377
if "commands" in extension_data:
308378
for command_data in extension_data["commands"]:
309-
if "friendly_name" not in command_data:
379+
command_name = command_data.get("friendly_name")
380+
if not command_name:
310381
continue
311382

312-
command_name = command_data["friendly_name"]
313-
command_description = command_data.get("description", "")
383+
command_key = command_name.strip().lower()
384+
shared_command = len(command_owner_map.get(command_key, set())) > 1
314385

315-
# Check if this command exists in a different extension (moved command)
316-
existing_in_other_extension = (
386+
commands_with_name = (
317387
session.query(Command)
318388
.join(Extension)
319-
.filter(Command.name == command_name, Extension.id != extension.id)
320-
.first()
389+
.filter(Command.name == command_name)
390+
.all()
321391
)
322392

323-
if existing_in_other_extension:
393+
command = None
394+
duplicates_to_merge = []
324395

325-
# Find or create command in current extension
326-
command = (
327-
session.query(Command)
328-
.filter_by(extension_id=extension.id, name=command_name)
329-
.first()
330-
)
331-
332-
if not command:
333-
command = Command(
334-
extension_id=extension.id,
335-
name=command_name,
336-
)
337-
session.add(command)
338-
session.flush()
339-
340-
# Update all agent command references from old to new
341-
agent_commands = (
342-
session.query(AgentCommand)
343-
.filter(
344-
AgentCommand.command_id == existing_in_other_extension.id
396+
for db_command in commands_with_name:
397+
if db_command.extension_id == extension.id:
398+
if command is None:
399+
command = db_command
400+
else:
401+
duplicates_to_merge.append(db_command)
402+
elif not shared_command:
403+
duplicates_to_merge.append(db_command)
404+
405+
if command is None:
406+
if duplicates_to_merge:
407+
command = duplicates_to_merge.pop(0)
408+
old_extension_name = (
409+
command.extension.name if command.extension else "Unknown"
345410
)
346-
.all()
347-
)
348-
349-
for agent_command in agent_commands:
350-
# Check if agent already has a reference to the new command
351-
existing_ref = (
352-
session.query(AgentCommand)
353-
.filter(
354-
AgentCommand.agent_id == agent_command.agent_id,
355-
AgentCommand.command_id == command.id,
411+
if old_extension_name.lower() != extension_name.lower():
412+
logging.info(
413+
"Moving command '%s' from extension '%s' to '%s'",
414+
command_name,
415+
old_extension_name,
416+
extension_name,
356417
)
357-
.first()
358-
)
359-
360-
if existing_ref:
361-
# Merge - keep enabled if either was enabled
362-
if agent_command.state:
363-
existing_ref.state = True
364-
session.delete(agent_command)
365-
else:
366-
# Update to point to new command
367-
agent_command.command_id = command.id
368-
369-
else:
370-
# Normal case - find or create command in this extension
371-
command = (
372-
session.query(Command)
373-
.filter_by(extension_id=extension.id, name=command_name)
374-
.first()
375-
)
376-
377-
if not command:
418+
command.extension_id = extension.id
419+
command.extension = extension
420+
else:
378421
command = Command(
379422
extension_id=extension.id,
380423
name=command_name,
381424
)
382425
session.add(command)
383426
session.flush()
384427

428+
_merge_command_duplicates(session, command, duplicates_to_merge)
429+
385430
# Process command arguments if they exist
386431
if "command_args" in command_data:
387432
for arg_name, arg_type in command_data["command_args"].items():
388-
# Check if argument already exists
389433
existing_arg = (
390434
session.query(Argument)
391435
.filter_by(command_id=command.id, name=arg_name)

0 commit comments

Comments
 (0)