Skip to content

Commit a8f4972

Browse files
committed
Refactor episode management and simplify LLM integration.
1 parent 759f070 commit a8f4972

1 file changed

Lines changed: 15 additions & 40 deletions

File tree

api/examples/self_driving_car.py

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def log_to_frontend(message: str):
4848
MAX_LLM_LOGS = 30
4949
LLM_CALL_FREQUENCY = 10
5050
USE_LOCAL_OLLAMA = True
51+
MAX_EPISODE_STEPS = 1000
5152

5253
DISCRETE_ACTIONS = [
5354
"accelerate", "decelerate", "maintain_speed", "slight_left", "slight_right"
@@ -250,6 +251,7 @@ def __init__(self, agent_id: int, start_node: int, goal_node: int, path: list, g
250251
self.speed = 0.0
251252
self.color = random.choice(RETRO_SCIFI_COLORS)
252253
self.memory_stream = []
254+
self.episode_steps = 0
253255

254256
def _set_new_path(self, start_node: int, goal_node: int, path: list):
255257
self.start_node = start_node
@@ -272,6 +274,7 @@ def reset(self, start_node: int, goal_node: int, path: list):
272274
self.angular_velocity = 0.0
273275
self.speed = 0.0
274276
self.memory_stream = []
277+
self.episode_steps = 0
275278

276279
def _calculate_remaining_len(self):
277280
"""Calculates the total remaining distance along the agent's path."""
@@ -478,6 +481,7 @@ def _execute_actions(self, agent_actions: List[Tuple[str, Any]]):
478481
dones = []
479482

480483
for agent, (action, data) in zip(self.agents, agent_actions):
484+
agent.episode_steps += 1
481485
# Store state before action
482486
last_speed = agent.speed
483487
last_heading = agent.heading
@@ -503,6 +507,12 @@ def _execute_actions(self, agent_actions: List[Tuple[str, Any]]):
503507
dones.append(True)
504508
continue
505509

510+
if agent.episode_steps > MAX_EPISODE_STEPS:
511+
agent.add_to_memory_stream(f"Episode timed out after {MAX_EPISODE_STEPS} steps.", self.step_count)
512+
rewards.append(-20.0) # Timeout penalty
513+
dones.append(True)
514+
continue
515+
506516
dist_to_move = agent.speed
507517
agent.distance_on_segment += dist_to_move
508518

@@ -736,7 +746,7 @@ def get_valid_actions_mask(agent: Agent, env: "SelfDrivingCarEnv") -> np.ndarray
736746
LR = 3e-4
737747

738748
# Keep the episode count low (reduced from 256) so training runs finish quickly
739-
EPISODES = 256
749+
EPISODES = 64
740750

741751
async def train_self_driving_car(websocket: WebSocket, env: SelfDrivingCarEnv):
742752
global _current_websocket
@@ -934,45 +944,10 @@ async def receive_commands():
934944
action_name = DISCRETE_ACTIONS[actions_np[i]]
935945
agent_actions_for_env.append((action_name, None))
936946

937-
if env.step_count % LLM_CALL_FREQUENCY == 0:
938-
try:
939-
top_features = env.trained_policy.get_local_feature_importance(obs_t[[0]], actions_t[[0]])
940-
941-
prompt = (
942-
f"The self-driving car is at step {env.step_count}. "
943-
f"It's currently moving at {env.agents[0].speed:.1f} m/s with a heading of {env.agents[0].heading:.1f} degrees. "
944-
f"The chosen action is to '{DISCRETE_ACTIONS[actions_np[0]]}'.\n\n"
945-
"The policy model's decision was influenced by these top features, with their contribution to the decision shown as a percentage:\n"
946-
)
947-
for f in top_features:
948-
prompt += f"- {f['feature']} ({f['percentage']:.0f}%): Current Value = {f['value']:.2f}\n"
949-
950-
prompt += "\nBased on this context, provide a concise, one-sentence explanation for why the car chose this action. For example: 'The car is accelerating because it's on a straight path with no immediate obstacles.' or 'The car is turning left to correct its heading towards the next waypoint.'"
951-
952-
explanation_json = get_json(
953-
prompt,
954-
name="format_explanation",
955-
description="Formats the explanation into a structured JSON object.",
956-
properties={"explanation": {"type": "string"}},
957-
use_local=USE_LOCAL_OLLAMA,
958-
)
959-
explanation = explanation_json.get("explanation", "Could not generate explanation.")
960-
env.add_message(agent_id=env.agents[0].id, message=explanation)
961-
except Exception as e:
962-
logger.warning(f"LLM explanation failed: {e}")
963-
top_features_list = env.trained_policy.get_local_feature_importance(obs_t[[0]], actions_t[[0]])
964-
causes_str = ', '.join([f"{f['feature']} ({f['percentage']:.0f}%)" for f in top_features_list])
965-
explanation = f"Action: {DISCRETE_ACTIONS[actions_np[0]]}, Causes: {causes_str}"
966-
env.add_message(agent_id=env.agents[0].id, message=explanation)
967-
elif env.step_count % LLM_CALL_FREQUENCY != 0 and len(env.messages) > 0 and "Action:" in env.messages[-1].get("message", ""):
968-
# If it's not an LLM call step, do nothing to avoid replacing the detailed message with a simple one
969-
pass
970-
else:
971-
# Fallback for the very first steps or if there are no messages
972-
top_features_list = env.trained_policy.get_local_feature_importance(obs_t[[0]], actions_t[[0]])
973-
causes_str = ', '.join([f"{f['feature']} ({f['percentage']:.0f}%)" for f in top_features_list])
974-
explanation = f"Action: {DISCRETE_ACTIONS[actions_np[0]]}, Causes: {causes_str}"
975-
env.add_message(agent_id=env.agents[0].id, message=explanation)
947+
top_features_list = env.trained_policy.get_local_feature_importance(obs_t[[0]], actions_t[[0]])
948+
causes_str = ', '.join([f"{f['feature']} ({f['percentage']:.0f}%)" for f in top_features_list])
949+
explanation = f"Action: {DISCRETE_ACTIONS[actions_np[0]]}, Causes: {causes_str}"
950+
env.add_message(agent_id=env.agents[0].id, message=explanation)
976951

977952
else:
978953
for agent in env.agents:

0 commit comments

Comments
 (0)