
Commit 181bd6f

Add GridWorld environment and training API integration
1 parent da52b3d commit 181bd6f

7 files changed

Lines changed: 597 additions & 30 deletions


api/examples/gridworld.py

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
# ---------------------------------------------------------------------------------
# GridWorld environment and Q-learning trainer -------------------------------------
# ---------------------------------------------------------------------------------

import os
import asyncio
from datetime import datetime
import uuid
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from fastapi import WebSocket

# ---------------------------------------------------------------------------------
# Simplified discrete GridWorld ----------------------------------------------------
# ---------------------------------------------------------------------------------

DEFAULT_GRID_SIZE = 5  # N x N grid
MAX_STEPS_PER_EP = 100

# Action mapping
# 0: no-op, 1: up (+y), 2: down (−y), 3: left (−x), 4: right (+x)
ACTION_DELTAS: List[Tuple[int, int]] = [
    (0, 0),   # stay
    (0, 1),   # up
    (0, -1),  # down
    (-1, 0),  # left
    (1, 0),   # right
]
NUM_ACTIONS = len(ACTION_DELTAS)

# Observation is a 4-dim float vector:
# [dx_to_goal, dy_to_goal, goal_one_hot_0, goal_one_hot_1]
OBS_SIZE = 4


class GridWorldEnv:
    """Minimal multi-goal GridWorld with one agent and two goal types."""

    def __init__(self, grid_size: int = DEFAULT_GRID_SIZE):
        self.grid_size = grid_size
        self.reset()

    def reset(self):
        # Random positions – ensure they are all unique
        all_cells = [(x, y) for x in range(self.grid_size) for y in range(self.grid_size)]
        np.random.shuffle(all_cells)
        self.agent_pos = all_cells[0]
        self.green_goals = [all_cells[1]]  # "plus" goals
        self.red_goals = [all_cells[2]]    # "ex" goals
        # Randomly assign the current target goal type
        self.current_goal_type = np.random.choice([0, 1])  # 0 = green, 1 = red
        self.steps = 0
        return self._get_obs()

    # -------------------------------------------------------------------------
    def _get_obs(self):
        # Normalized vector from the agent to the target goal of the required
        # type (there is exactly one goal per type)
        if self.current_goal_type == 0:
            goal = self.green_goals[0]
        else:
            goal = self.red_goals[0]
        dx = (goal[0] - self.agent_pos[0]) / max(1, self.grid_size - 1)
        dy = (goal[1] - self.agent_pos[1]) / max(1, self.grid_size - 1)
        one_hot_goal = [1.0, 0.0] if self.current_goal_type == 0 else [0.0, 1.0]
        return np.array([dx, dy, *one_hot_goal], dtype=np.float32)

    # -------------------------------------------------------------------------
    def step(self, action_idx: int):
        delta = ACTION_DELTAS[action_idx]
        new_x = int(np.clip(self.agent_pos[0] + delta[0], 0, self.grid_size - 1))
        new_y = int(np.clip(self.agent_pos[1] + delta[1], 0, self.grid_size - 1))
        self.agent_pos = (new_x, new_y)
        self.steps += 1

        # Base step penalty
        reward = -0.01
        done = False

        # Check for goal collision: +1 for the current target type, −1 otherwise
        if self.agent_pos in self.green_goals:
            if self.current_goal_type == 0:
                reward = 1.0
            else:
                reward = -1.0
            done = True
        elif self.agent_pos in self.red_goals:
            if self.current_goal_type == 1:
                reward = 1.0
            else:
                reward = -1.0
            done = True

        if self.steps >= MAX_STEPS_PER_EP:
            done = True

        return self._get_obs(), reward, done

# ---------------------------------------------------------------------------------
# Neural network – simple MLP Q-network --------------------------------------------
# ---------------------------------------------------------------------------------


class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(OBS_SIZE, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, NUM_ACTIONS)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.out(x)

# ---------------------------------------------------------------------------------
# Training exposed to FastAPI via websocket -----------------------------------------
# ---------------------------------------------------------------------------------

POLICIES_DIR = "policies"


def _export_model_onnx(model: nn.Module, path: str):
    dummy = torch.zeros((1, OBS_SIZE), dtype=torch.float32)
    torch.onnx.export(
        model,
        dummy,
        path,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )


async def train_gridworld(websocket: WebSocket):
    os.makedirs(POLICIES_DIR, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_uuid = str(uuid.uuid4())[:8]
    model_filename = f"gridworld_policy_{timestamp}_{session_uuid}.onnx"
    model_path = os.path.join(POLICIES_DIR, model_filename)

    envs: List[GridWorldEnv] = [GridWorldEnv() for _ in range(12)]

    net = QNet()
    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    gamma = 0.95
    epsilon = 1.0
    episodes = 300

    for ep in range(episodes):
        obs_list = [env.reset() for env in envs]
        done_flags = [False] * len(envs)
        total_reward = 0.0
        ep_loss_accum = 0.0
        step_counter = 0

        while not all(done_flags):
            obs_batch = torch.tensor(obs_list, dtype=torch.float32)
            # Epsilon-greedy action selection over the whole batch of envs
            if np.random.rand() < epsilon:
                actions = np.random.randint(0, NUM_ACTIONS, size=len(envs))
            else:
                with torch.no_grad():
                    qvals = net(obs_batch)
                actions = torch.argmax(qvals, dim=1).cpu().numpy()

            next_obs_list = []
            rewards = []
            dones = []
            for idx, env in enumerate(envs):
                if done_flags[idx]:
                    next_obs_list.append(obs_list[idx])
                    rewards.append(0.0)
                    dones.append(True)
                    continue
                nobs, rew, dn = env.step(int(actions[idx]))
                next_obs_list.append(nobs)
                rewards.append(rew)
                dones.append(dn)

            obs_tensor = torch.tensor(obs_list, dtype=torch.float32)
            next_obs_tensor = torch.tensor(next_obs_list, dtype=torch.float32)
            actions_tensor = torch.tensor(actions, dtype=torch.long)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            dones_tensor = torch.tensor(dones, dtype=torch.float32)

            # One-step TD target: r + gamma * (1 - done) * max_a' Q(s', a')
            q_pred = net(obs_tensor).gather(1, actions_tensor.view(-1, 1)).squeeze()
            with torch.no_grad():
                q_next_max = net(next_obs_tensor).max(dim=1).values
            q_target = rewards_tensor + gamma * (1.0 - dones_tensor) * q_next_max
            loss = ((q_pred - q_target) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ep_loss_accum += float(loss.item())
            step_counter += 1
            total_reward += float(np.sum(rewards))

            obs_list = next_obs_list
            done_flags = dones

            # Stream the first env's state for visualization
            env0 = envs[0]
            await websocket.send_json({
                "type": "train_step",
                "state": {
                    "agentX": int(env0.agent_pos[0]),
                    "agentY": int(env0.agent_pos[1]),
                    "gridSize": env0.grid_size,
                    "greenGoals": env0.green_goals,
                    "redGoals": env0.red_goals,
                    "currentGoalType": int(env0.current_goal_type),
                },
                "episode": ep + 1,
            })

            await asyncio.sleep(0.01)

            if step_counter >= MAX_STEPS_PER_EP:
                break

        epsilon = max(0.05, epsilon * 0.99)

        if (ep + 1) % 10 == 0:
            avg_loss = ep_loss_accum / max(1, step_counter)
            await websocket.send_json({
                "type": "progress",
                "episode": ep + 1,
                "reward": round(total_reward / len(envs), 3),
                "loss": round(avg_loss, 5),
            })

    _export_model_onnx(net, model_path)

    await websocket.send_json({
        "type": "trained",
        "file_url": f"/policies/{model_filename}",
        "model_filename": model_filename,
        "timestamp": timestamp,
        "session_uuid": session_uuid,
    })

# ---------------------------------------------------------------------------------
# Inference -------------------------------------------------------------------------
# ---------------------------------------------------------------------------------

_ort_sessions_grid = {}


def infer_action_gridworld(obs: List[float], model_filename: str | None = None):
    import onnxruntime as ort

    if model_filename is None:
        files = [f for f in os.listdir(POLICIES_DIR) if f.startswith("gridworld_policy_") and f.endswith(".onnx")]
        if not files:
            raise FileNotFoundError("No trained gridworld policy files available")
        # Timestamped filenames sort lexicographically, so a reverse sort puts
        # the most recent policy first
        files.sort(reverse=True)
        model_filename = files[0]

    if model_filename not in _ort_sessions_grid:
        model_path = os.path.join(POLICIES_DIR, model_filename)
        _ort_sessions_grid[model_filename] = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    inp = np.array([obs], dtype=np.float32)
    outputs = _ort_sessions_grid[model_filename].run(None, {"input": inp})
    return int(np.argmax(outputs[0]))
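
As a quick sanity check, the environment can be driven on its own before any websocket wiring. The sketch below is not part of this commit; it assumes it runs next to gridworld.py so the import resolves, and rolls out one episode under a uniform random policy:

import numpy as np

from gridworld import GridWorldEnv, NUM_ACTIONS

env = GridWorldEnv()
obs = env.reset()
total, done = 0.0, False
while not done:
    action = int(np.random.randint(NUM_ACTIONS))  # uniform random policy
    obs, reward, done = env.step(action)
    total += reward
print(f"return={total:.2f} steps={env.steps}")

With the −0.01 step penalty and ±1 terminal rewards, a random policy's return is typically negative, while a trained policy should approach 1.0 minus the accumulated step penalty on a 5×5 grid.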

api/main.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 from examples.basic import train_basic, infer_action
 from examples.ball3d import train_ball3d, infer_action_ball3d
+from examples.gridworld import train_gridworld, infer_action_gridworld
 
 app = FastAPI(title="ML-Agents API")
 
@@ -96,4 +97,21 @@ async def websocket_ball3d(ws: WebSocket):
         elif cmd == "inference":
             obs = data.get("obs", [])  # expect list [rotX, rotZ, ballX, ballZ]
             act_idx = infer_action_ball3d(obs)
+            await ws.send_json({"type": "action", "action": int(act_idx)})
+
+
+# WebSocket endpoint for GridWorld
+
+
+@app.websocket("/ws/gridworld")
+async def websocket_gridworld(ws: WebSocket):
+    await ws.accept()
+    async for message in ws.iter_text():
+        data = json.loads(message)
+        cmd = data.get("cmd")
+        if cmd == "train":
+            await train_gridworld(ws)
+        elif cmd == "inference":
+            obs = data.get("obs", [])  # expect [dx, dy, g0, g1] (or any agreed encoding)
+            act_idx = infer_action_gridworld(obs)
             await ws.send_json({"type": "action", "action": int(act_idx)})
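
The new endpoint speaks the same small JSON protocol as the existing examples: the client sends {"cmd": "train"} or {"cmd": "inference", "obs": [...]}, and the server streams train_step/progress/trained events or replies with a single action. A minimal client sketch using the third-party websockets package (the localhost:8000 address is an assumption; point it at wherever the FastAPI app is served):

import asyncio
import json

import websockets  # pip install websockets


async def main():
    # Hypothetical address for the API; adjust host/port as needed
    async with websockets.connect("ws://localhost:8000/ws/gridworld") as ws:
        # Request an action for a 4-dim observation: [dx, dy, g0, g1]
        await ws.send(json.dumps({"cmd": "inference", "obs": [0.5, -0.25, 1.0, 0.0]}))
        reply = json.loads(await ws.recv())
        print(reply)  # expected shape: {"type": "action", "action": <0..4>}


asyncio.run(main())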

client/src/App.jsx

Lines changed: 7 additions & 0 deletions
@@ -3,14 +3,21 @@ import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom';
 import ExamplesIndex from './examples/Index.jsx';
 import BasicExample from './examples/Basic.jsx';
 import Ball3DExample from './examples/Ball3D.jsx';
+import GridWorldExample from './examples/GridWorld.jsx';
 
 export default function App() {
   return (
     <BrowserRouter basename="/three-mlagents">
+      {/* Global style override for KaTeX display alignment */}
+      <style>{`
+        .katex-display{ text-align:left !important; }
+        .katex-display > .katex{ text-align:left !important; }
+      `}</style>
       <Routes>
         <Route path="/" element={<ExamplesIndex />} />
         <Route path="/basic" element={<BasicExample />} />
         <Route path="/ball3d" element={<Ball3DExample />} />
+        <Route path="/gridworld" element={<GridWorldExample />} />
         <Route path="*" element={<Navigate to="/" replace />} />
       </Routes>
     </BrowserRouter>
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import React, { memo } from 'react';
import { Line } from 'react-chartjs-2';
import {
  Chart as ChartJS,
  LineElement,
  PointElement,
  LinearScale,
  CategoryScale,
  Legend,
} from 'chart.js';

// Chart.js registration is idempotent, so the required elements (plus the
// Legend plugin used by the options below) can be registered unconditionally
ChartJS.register(LineElement, PointElement, LinearScale, CategoryScale, Legend);

function ChartPanel({ labels, rewards, losses, style = {} }) {
  return (
    <div style={{ width: '100%', height: '100%', ...style }}>
      <Line
        data={{
          labels,
          datasets: [
            { label: 'Reward', data: rewards, borderColor: '#0f0', backgroundColor: 'transparent', borderWidth: 1, pointRadius: 0, yAxisID: 'y' },
            { label: 'Loss', data: losses, borderColor: 'orange', backgroundColor: 'transparent', borderWidth: 1, pointRadius: 0, yAxisID: 'y1' },
          ],
        }}
        options={{
          responsive: true,
          maintainAspectRatio: false,
          scales: {
            x: { ticks: { color: '#aaa' }, grid: { color: 'rgba(255,255,255,0.1)' }, title: { display: true, text: 'Episode', color: '#aaa' } },
            y: { ticks: { color: '#aaa' }, grid: { color: 'rgba(255,255,255,0.1)' }, title: { display: true, text: 'Reward', color: '#aaa' } },
            y1: { position: 'right', ticks: { color: 'orange' }, grid: { drawOnChartArea: false }, title: { display: true, text: 'Loss', color: 'orange' } },
          },
          plugins: { legend: { labels: { color: '#ddd' } } },
        }}
      />
    </div>
  );
}

export default memo(ChartPanel);
