|
2 | 2 | import asyncio |
3 | 3 | import uuid |
4 | 4 | import traceback |
5 | | -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query |
| 5 | +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query,Body |
6 | 6 | from fastapi.middleware.cors import CORSMiddleware |
7 | 7 | import redis.asyncio as redis |
8 | 8 | from dotenv import load_dotenv |
@@ -36,6 +36,22 @@ class CreateAgentPayload(BaseModel): |
36 | 36 | agent_prompts: Optional[Dict[str, Dict[str, str]]] |
37 | 37 |
|
38 | 38 |
|
| 39 | +@app.get("/agent/{agent_id}") |
| 40 | +async def get_agent(agent_id: str): |
| 41 | + """Fetches an agent's information by ID.""" |
| 42 | + try: |
| 43 | + agent_data = await redis_client.get(agent_id) |
| 44 | + if not agent_data: |
| 45 | + raise HTTPException(status_code=404, detail="Agent not found") |
| 46 | + |
| 47 | + return json.loads(agent_data) |
| 48 | + |
| 49 | + except Exception as e: |
| 50 | + logger.error(f"Error fetching agent {agent_id}: {e}", exc_info=True) |
| 51 | + raise HTTPException(status_code=500, detail="Internal server error") |
| 52 | + |
| 53 | + |
| 54 | + |
39 | 55 | @app.post("/agent") |
40 | 56 | async def create_agent(agent_data: CreateAgentPayload): |
41 | 57 | agent_uuid = str(uuid.uuid4()) |
@@ -66,6 +82,99 @@ async def create_agent(agent_data: CreateAgentPayload): |
66 | 82 | return {"agent_id": agent_uuid, "state": "created"} |
67 | 83 |
|
68 | 84 |
|
| 85 | +@app.put("/agent/{agent_id}") |
| 86 | +async def edit_agent(agent_id: str, agent_data: CreateAgentPayload = Body(...)): |
| 87 | + """Edits an existing agent based on the provided agent_id.""" |
| 88 | + try: |
| 89 | + |
| 90 | + existing_data = await redis_client.get(agent_id) |
| 91 | + if not existing_data: |
| 92 | + raise HTTPException(status_code=404, detail="Agent not found") |
| 93 | + |
| 94 | + existing_data = json.loads(existing_data) |
| 95 | + |
| 96 | + |
| 97 | + new_data = agent_data.agent_config.model_dump() |
| 98 | + new_data["assistant_status"] = "updated" |
| 99 | + agent_prompts = agent_data.agent_prompts |
| 100 | + |
| 101 | + logger.info(f"Updating Agent {agent_id}: {new_data}") |
| 102 | + |
| 103 | + |
| 104 | + for index, task in enumerate(new_data.get("tasks", [])): |
| 105 | + if task.get("task_type") == "extraction": |
| 106 | + extraction_prompt_llm = os.getenv("EXTRACTION_PROMPT_GENERATION_MODEL") |
| 107 | + if not extraction_prompt_llm: |
| 108 | + raise HTTPException(status_code=500, detail="Extraction model not configured") |
| 109 | + |
| 110 | + extraction_prompt_generation_llm = LiteLLM(model=extraction_prompt_llm, max_tokens=2000) |
| 111 | + extraction_details = task["tools_config"]["llm_agent"].get("extraction_details", "") |
| 112 | + |
| 113 | + extraction_prompt = await extraction_prompt_generation_llm.generate( |
| 114 | + messages=[ |
| 115 | + {"role": "system", "content": EXTRACTION_PROMPT_GENERATION_PROMPT}, |
| 116 | + {"role": "user", "content": extraction_details} |
| 117 | + ] |
| 118 | + ) |
| 119 | + |
| 120 | + new_data["tasks"][index]["tools_config"]["llm_agent"]["extraction_json"] = extraction_prompt |
| 121 | + |
| 122 | + |
| 123 | + stored_prompt_file_path = f"{agent_id}/conversation_details.json" |
| 124 | + await asyncio.gather( |
| 125 | + redis_client.set(agent_id, json.dumps(new_data)), |
| 126 | + store_file(file_key=stored_prompt_file_path, file_data=agent_prompts, local=True) |
| 127 | + ) |
| 128 | + |
| 129 | + return {"agent_id": agent_id, "state": "updated"} |
| 130 | + |
| 131 | + except Exception as e: |
| 132 | + logger.error(f"Error updating agent {agent_id}: {e}", exc_info=True) |
| 133 | + raise HTTPException(status_code=500, detail="Internal server error") |
| 134 | + |
| 135 | +@app.delete("/agent/{agent_id}") |
| 136 | +async def delete_agent(agent_id: str): |
| 137 | + """Deletes an agent by ID.""" |
| 138 | + try: |
| 139 | + agent_exists = await redis_client.exists(agent_id) |
| 140 | + if not agent_exists: |
| 141 | + raise HTTPException(status_code=404, detail="Agent not found") |
| 142 | + |
| 143 | + await redis_client.delete(agent_id) |
| 144 | + return {"agent_id": agent_id, "state": "deleted"} |
| 145 | + |
| 146 | + except Exception as e: |
| 147 | + logger.error(f"Error deleting agent {agent_id}: {e}", exc_info=True) |
| 148 | + raise HTTPException(status_code=500, detail="Internal server error") |
| 149 | + |
| 150 | + |
| 151 | +@app.get("/all") |
| 152 | +async def get_all_agents(): |
| 153 | + """Fetches all agents stored in Redis.""" |
| 154 | + try: |
| 155 | + |
| 156 | + agent_keys = await redis_client.keys("*") |
| 157 | + |
| 158 | + if not agent_keys: |
| 159 | + return {"agents": []} |
| 160 | + agents_data = [] |
| 161 | + for key in agent_keys: |
| 162 | + try: |
| 163 | + data = await redis_client.get(key) |
| 164 | + agents_data.append(data) |
| 165 | + except Exception as e: |
| 166 | + logger.error(f"An error occurred with key {key}: {e}") |
| 167 | + |
| 168 | + |
| 169 | + agents = [{ "agent_id": key, "data": json.loads(data) } for key, data in zip(agent_keys, agents_data) if data] |
| 170 | + |
| 171 | + return {"agents": agents} |
| 172 | + |
| 173 | + except Exception as e: |
| 174 | + logger.error(f"Error fetching all agents: {e}", exc_info=True) |
| 175 | + raise HTTPException(status_code=500, detail="Internal server error") |
| 176 | + |
| 177 | + |
69 | 178 | ############################################################################################# |
70 | 179 | # Websocket |
71 | 180 | ############################################################################################# |
|
0 commit comments