Commit 64ba0dc

Implement 3DBall environment with training and inference.
1 parent 6df4042 commit 64ba0dc

7 files changed

Lines changed: 633 additions & 0 deletions

api/examples/ball3d.py

Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
import os
import asyncio
from datetime import datetime
import uuid
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from fastapi import WebSocket

# ---------------------------------------------------------------------------------
# Simplified 3DBall environment ----------------------------------------------------
# ---------------------------------------------------------------------------------

G = 9.81                # gravitational acceleration (m/s^2) – only used for rough accel scaling
DT = 0.02               # physics time-step (seconds)
MAX_STEPS_PER_EP = 200  # terminate episode after this many physics steps

# Bounds for the square platform on which the ball must stay
PLATFORM_HALF_SIZE = 3.0  # ball falls when |x| or |z| > 3

# Platform rotation limit (±25° expressed in radians)
MAX_TILT = np.deg2rad(25.0)
TILT_DELTA = np.deg2rad(3.0)  # amount each discrete action tilts the platform

# Observation indices for convenience
# 0: rotX, 1: rotZ, 2: ballPosX, 3: ballPosZ
OBS_SIZE = 4

# Discrete action mapping (5 actions)
# 0: tilt +x (rotate platform around Z-axis positive)
# 1: tilt −x (rotate platform around Z-axis negative)
# 2: tilt +z (rotate platform around X-axis positive)
# 3: tilt −z (rotate platform around X-axis negative)
# 4: no-op
ACTION_DELTAS = [
    np.array([ TILT_DELTA, 0.0]),   # +x
    np.array([-TILT_DELTA, 0.0]),   # −x
    np.array([ 0.0,  TILT_DELTA]),  # +z
    np.array([ 0.0, -TILT_DELTA]),  # −z
    np.array([ 0.0,  0.0]),         # no-op
]
NUM_ACTIONS = len(ACTION_DELTAS)


class Ball3DEnv:
    """A lightweight physics approximation of the ML-Agents 3DBall task."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Platform rotation (x, z) in radians
        self.rot = np.zeros(2, dtype=np.float32)
        # Ball position relative to platform centre
        self.pos = np.random.uniform(-0.5, 0.5, size=2).astype(np.float32)
        # Ball velocity (x, z)
        self.vel = np.zeros(2, dtype=np.float32)
        self.steps = 0
        return self._get_obs()

    def _get_obs(self):
        return np.array([self.rot[0], self.rot[1], self.pos[0], self.pos[1]], dtype=np.float32)

    def step(self, action_idx: int):
        # Apply platform tilt change, clip to limits
        delta = ACTION_DELTAS[action_idx]
        self.rot += delta
        self.rot = np.clip(self.rot, -MAX_TILT, MAX_TILT)

        # Compute acceleration of ball due to gravity projected onto the tilted plane
        acc_x = G * np.sin(self.rot[0])
        acc_z = G * np.sin(self.rot[1])
        self.vel[0] += acc_x * DT
        self.vel[1] += acc_z * DT

        # Dampen velocity slightly (friction / rolling resistance)
        self.vel *= 0.98

        # Integrate position
        self.pos += self.vel * DT

        # Increment step counter
        self.steps += 1

        # Check termination conditions – ball fell off or time limit reached
        off_platform = (abs(self.pos[0]) > PLATFORM_HALF_SIZE) or (abs(self.pos[1]) > PLATFORM_HALF_SIZE)
        timeout = self.steps >= MAX_STEPS_PER_EP
        done = off_platform or timeout

        # Reward scheme: +0.1 per step alive, −1 on failure, +1 if the full episode was survived
        reward = 0.1
        if done:
            reward = -1.0
            if not off_platform and timeout:
                reward = 1.0

        # Small penalty proportional to distance from the platform centre
        dist_penalty = -0.02 * np.linalg.norm(self.pos)
        reward += dist_penalty

        return self._get_obs(), reward, done


# ---------------------------------------------------------------------------------
# Neural network & RL algorithm (simple Q-learning with discretised actions) -------
# ---------------------------------------------------------------------------------

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(OBS_SIZE, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, NUM_ACTIONS)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.out(x)


# ---------------------------------------------------------------------------------
# Training entry point exposed to FastAPI websocket --------------------------------
# ---------------------------------------------------------------------------------

POLICIES_DIR = "policies"

def _export_model_onnx(model: nn.Module, path: str):
    dummy_input = torch.zeros((1, OBS_SIZE), dtype=torch.float32)
    torch.onnx.export(
        model,
        dummy_input,
        path,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )


async def train_ball3d(websocket: WebSocket):
    # Ensure directory for saved policies exists
    os.makedirs(POLICIES_DIR, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_uuid = str(uuid.uuid4())[:8]
    model_filename = f"ball3d_policy_{timestamp}_{session_uuid}.onnx"
    model_path = os.path.join(POLICIES_DIR, model_filename)

    # Create vectorised environments (12 instances to match the UI layout)
    envs: List[Ball3DEnv] = [Ball3DEnv() for _ in range(12)]

    net = QNet()
    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    gamma = 0.99
    epsilon = 1.0
    episodes = 100

    for ep in range(episodes):
        # Reset all envs
        obs_list = [env.reset() for env in envs]
        done_flags = [False] * len(envs)
        total_reward = 0.0
        ep_loss_accum = 0.0
        step_counter = 0

        while not all(done_flags):
            obs_batch = torch.tensor(obs_list, dtype=torch.float32)
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                actions = np.random.randint(0, NUM_ACTIONS, size=len(envs))
            else:
                with torch.no_grad():
                    qvals = net(obs_batch)
                    actions = torch.argmax(qvals, dim=1).cpu().numpy()

            # Step each env
            next_obs_list = []
            rewards = []
            dones = []
            for idx, env in enumerate(envs):
                if done_flags[idx]:
                    # Already finished – keep state
                    next_obs_list.append(obs_list[idx])
                    rewards.append(0.0)
                    dones.append(True)
                    continue
                nobs, rew, dn = env.step(int(actions[idx]))
                next_obs_list.append(nobs)
                rewards.append(rew)
                dones.append(dn)

            # Q-learning update per env (independent)
            obs_tensor = torch.tensor(obs_list, dtype=torch.float32)
            next_obs_tensor = torch.tensor(next_obs_list, dtype=torch.float32)
            actions_tensor = torch.tensor(actions, dtype=torch.long)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            dones_tensor = torch.tensor(dones, dtype=torch.float32)

            q_pred = net(obs_tensor).gather(1, actions_tensor.view(-1, 1)).squeeze()
            with torch.no_grad():
                q_next_max = net(next_obs_tensor).max(dim=1).values
                q_target = rewards_tensor + gamma * (1.0 - dones_tensor) * q_next_max
            loss = ((q_pred - q_target) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ep_loss_accum += float(loss.item())
            step_counter += 1
            total_reward += float(np.sum(rewards))

            obs_list = next_obs_list
            done_flags = dones

            # Stream current state to frontend for one of the envs (choose env 0)
            env0 = envs[0]
            await websocket.send_json({
                "type": "train_step",
                "state": {
                    "rotX": float(env0.rot[0]),
                    "rotZ": float(env0.rot[1]),
                    "ballX": float(env0.pos[0]),
                    "ballZ": float(env0.pos[1]),
                },
                "episode": ep + 1,
            })

            # Small sleep so the UI can update smoothly
            await asyncio.sleep(0.01)

            # Safety break if the step count grows too large
            if step_counter >= MAX_STEPS_PER_EP:
                break

        # Epsilon annealing
        epsilon = max(0.05, epsilon * 0.995)

        # Send progress summary every 10 episodes
        if (ep + 1) % 10 == 0:
            avg_loss = ep_loss_accum / max(1, step_counter)
            await websocket.send_json({
                "type": "progress",
                "episode": ep + 1,
                "reward": round(total_reward / len(envs), 3),
                "loss": round(avg_loss, 5),
            })

    # Export trained model
    _export_model_onnx(net, model_path)

    await websocket.send_json({
        "type": "trained",
        "file_url": f"/policies/{model_filename}",
        "model_filename": model_filename,
        "timestamp": timestamp,
        "session_uuid": session_uuid,
    })


# ---------------------------------------------------------------------------------
# Inference helper ----------------------------------------------------------------
# ---------------------------------------------------------------------------------

_ort_sessions_ball3d = {}


def infer_action_ball3d(obs: List[float], model_filename: str = None):
    """Run the ONNX policy to obtain an action index (0-4)."""
    # Lazy import to avoid the dependency if unused
    import onnxruntime as ort

    # Resolve model filename (most recent if None)
    if model_filename is None:
        files = [f for f in os.listdir(POLICIES_DIR) if f.startswith("ball3d_policy_") and f.endswith(".onnx")]
        if not files:
            raise FileNotFoundError("No trained ball3d policy files available")
        files.sort(reverse=True)
        model_filename = files[0]

    if model_filename not in _ort_sessions_ball3d:
        model_path = os.path.join(POLICIES_DIR, model_filename)
        _ort_sessions_ball3d[model_filename] = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    inp = np.array([obs], dtype=np.float32)
    outputs = _ort_sessions_ball3d[model_filename].run(None, {"input": inp})
    return int(np.argmax(outputs[0]))
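
The Q-learning update inside `train_ball3d` regresses Q(s, a) toward the one-step TD target r + γ · (1 − done) · max_a′ Q(s′, a′), computed with the same online network throughout (no replay buffer or target network). As a quick sanity check of the environment itself, a short random-policy rollout can be run outside the websocket loop. The sketch below is illustrative and not part of this commit; it assumes `examples.ball3d` is importable (e.g. when run from the `api/` directory, the import root `api/main.py` uses) and touches only the public `reset`/`step` API.

# sanity_check_ball3d.py – illustrative sketch, not part of this commit.
import numpy as np

from examples.ball3d import Ball3DEnv, NUM_ACTIONS

env = Ball3DEnv()
obs = env.reset()
episode_return = 0.0
done = False
while not done:
    action = int(np.random.randint(NUM_ACTIONS))  # uniform random action
    obs, reward, done = env.step(action)
    episode_return += reward

# A random policy usually tips the ball off within a few dozen steps, so the
# return stays low; a trained policy should survive all MAX_STEPS_PER_EP steps.
print(f"steps={env.steps}  return={episode_return:.2f}  final_obs={obs}")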

api/main.py

Lines changed: 18 additions & 0 deletions
@@ -7,6 +7,7 @@
 from fastapi.staticfiles import StaticFiles
 
 from examples.basic import train_basic, infer_action
+from examples.ball3d import train_ball3d, infer_action_ball3d
 
 app = FastAPI(title="ML-Agents API")
 
@@ -78,4 +79,21 @@ async def websocket_basic(ws: WebSocket):
         elif cmd == "inference":
             position = int(data.get("obs", 0))
             act_idx = await infer_action(position)
+            await ws.send_json({"type": "action", "action": int(act_idx)})
+
+
+# WebSocket endpoint for 3DBall
+
+
+@app.websocket("/ws/ball3d")
+async def websocket_ball3d(ws: WebSocket):
+    await ws.accept()
+    async for message in ws.iter_text():
+        data = json.loads(message)
+        cmd = data.get("cmd")
+        if cmd == "train":
+            await train_ball3d(ws)
+        elif cmd == "inference":
+            obs = data.get("obs", [])  # expect list [rotX, rotZ, ballX, ballZ]
+            act_idx = infer_action_ball3d(obs)
             await ws.send_json({"type": "action", "action": int(act_idx)})
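
To exercise the new endpoint by hand, a small Python client can speak the same JSON protocol the frontend uses. The following is a sketch rather than part of this commit: it assumes the API is served on localhost:8000, that the third-party `websockets` package is installed, and that a policy has already been trained (otherwise `infer_action_ball3d` raises `FileNotFoundError`).

# ball3d_client.py – illustrative sketch, not part of this commit.
import asyncio
import json

import websockets


async def main():
    async with websockets.connect("ws://localhost:8000/ws/ball3d") as ws:
        # Request a single action: obs is [rotX, rotZ, ballX, ballZ].
        await ws.send(json.dumps({"cmd": "inference", "obs": [0.0, 0.0, 0.4, -0.2]}))
        reply = json.loads(await ws.recv())
        print(reply)  # expected shape: {"type": "action", "action": <0-4>}


asyncio.run(main())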

client/public/3d_ball_example.jpg

420 KB

client/src/App.jsx

Lines changed: 2 additions & 0 deletions
@@ -2,13 +2,15 @@ import React from 'react';
 import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom';
 import ExamplesIndex from './examples/Index.jsx';
 import BasicExample from './examples/Basic.jsx';
+import Ball3DExample from './examples/Ball3D.jsx';
 
 export default function App() {
   return (
     <BrowserRouter basename="/three-mlagents">
       <Routes>
         <Route path="/" element={<ExamplesIndex />} />
         <Route path="/basic" element={<BasicExample />} />
+        <Route path="/ball3d" element={<Ball3DExample />} />
         <Route path="*" element={<Navigate to="/" replace />} />
       </Routes>
     </BrowserRouter>
