import os
import asyncio
from datetime import datetime
import uuid
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from fastapi import WebSocket

# ---------------------------------------------------------------------------------
# Simplified 3DBall environment ----------------------------------------------------
# ---------------------------------------------------------------------------------

G = 9.81                # gravitational acceleration (m/s^2) – only used for rough accel scaling
DT = 0.02               # physics time-step (seconds)
MAX_STEPS_PER_EP = 200  # terminate episode after this many physics steps

# Bounds for the square platform on which the ball must stay
PLATFORM_HALF_SIZE = 3.0  # ball falls when |x| or |z| > 3

# Bounds for platform rotation (approximately ±25° in radians)
MAX_TILT = np.deg2rad(25.0)
TILT_DELTA = np.deg2rad(3.0)  # amount each discrete action tilts the platform

# Observation indices for convenience
# 0: rotX, 1: rotZ, 2: ballPosX, 3: ballPosZ
OBS_SIZE = 4

# Discrete action mapping (5 actions)
# 0: tilt +x (rotate platform around Z-axis positive)
# 1: tilt −x (rotate platform around Z-axis negative)
# 2: tilt +z (rotate platform around X-axis positive)
# 3: tilt −z (rotate platform around X-axis negative)
# 4: no-op
ACTION_DELTAS = [
    np.array([TILT_DELTA, 0.0]),   # +x
    np.array([-TILT_DELTA, 0.0]),  # −x
    np.array([0.0, TILT_DELTA]),   # +z
    np.array([0.0, -TILT_DELTA]),  # −z
    np.array([0.0, 0.0]),          # no-op
]
NUM_ACTIONS = len(ACTION_DELTAS)


class Ball3DEnv:
    """A lightweight physics approximation of the ML-Agents 3DBall task."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Platform rotation (x, z) in radians
        self.rot = np.zeros(2, dtype=np.float32)
        # Ball position relative to platform centre
        self.pos = np.random.uniform(-0.5, 0.5, size=2).astype(np.float32)
        # Ball velocity (x, z)
        self.vel = np.zeros(2, dtype=np.float32)
        self.steps = 0
        return self._get_obs()

    def _get_obs(self):
        return np.array([self.rot[0], self.rot[1], self.pos[0], self.pos[1]], dtype=np.float32)

    def step(self, action_idx: int):
        # Apply the platform tilt change, clipped to the tilt limits
        delta = ACTION_DELTAS[action_idx]
        self.rot += delta
        self.rot = np.clip(self.rot, -MAX_TILT, MAX_TILT)

        # Acceleration of the ball due to gravity projected onto the tilted plane
        acc_x = G * np.sin(self.rot[0])
        acc_z = G * np.sin(self.rot[1])
        self.vel[0] += acc_x * DT
        self.vel[1] += acc_z * DT

        # Dampen velocity slightly (friction / rolling resistance)
        self.vel *= 0.98

        # Integrate position
        self.pos += self.vel * DT

        self.steps += 1

        # Check termination conditions – ball fell off or time limit reached
        off_platform = (abs(self.pos[0]) > PLATFORM_HALF_SIZE) or (abs(self.pos[1]) > PLATFORM_HALF_SIZE)
        timeout = self.steps >= MAX_STEPS_PER_EP
        done = off_platform or timeout

        # Reward scheme: +0.1 per step alive, −1 on a fall, +1 if the full
        # episode is survived
        reward = 0.1
        if done:
            reward = -1.0
            if not off_platform and timeout:
                reward = 1.0

        # A small penalty proportional to distance from the centre keeps the
        # ball near the middle of the platform
        dist_penalty = -0.02 * np.linalg.norm(self.pos)
        reward += dist_penalty

        return self._get_obs(), reward, done
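

# A minimal smoke test (not wired into the API, names here are illustrative):
# drive one episode with random actions to sanity-check the physics above.
# Under a random policy the ball should usually fall well before
# MAX_STEPS_PER_EP, which gives a useful baseline for the trained policy.
def _demo_random_rollout(seed: int = 0):
    np.random.seed(seed)
    env = Ball3DEnv()
    env.reset()
    total, done = 0.0, False
    while not done:
        _, reward, done = env.step(int(np.random.randint(NUM_ACTIONS)))
        total += reward
    print(f"random policy: {env.steps} steps, return {total:.2f}")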


# ---------------------------------------------------------------------------------
# Neural network & RL algorithm (simple Q-learning with discretised actions) -------
# ---------------------------------------------------------------------------------

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(OBS_SIZE, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, NUM_ACTIONS)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.out(x)
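

# Quick shape check (illustrative helper, never called): a batch of observations
# maps to one Q-value per discrete action. This is also the I/O contract the
# ONNX export below, and any frontend consumer of it, relies on.
def _demo_qnet_shapes():
    net = QNet()
    batch = torch.zeros((12, OBS_SIZE), dtype=torch.float32)
    assert net(batch).shape == (12, NUM_ACTIONS)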


# ---------------------------------------------------------------------------------
# Training entry point exposed to FastAPI websocket --------------------------------
# ---------------------------------------------------------------------------------

POLICIES_DIR = "policies"


def _export_model_onnx(model: nn.Module, path: str):
    dummy_input = torch.zeros((1, OBS_SIZE), dtype=torch.float32)
    torch.onnx.export(
        model,
        dummy_input,
        path,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )
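

# Optional parity check, assuming onnxruntime is installed: the exported graph
# should reproduce the torch forward pass on a random batch. The tolerances are
# a rough guess; exact agreement depends on the runtime and hardware.
def _check_onnx_parity(model: nn.Module, path: str):
    import onnxruntime as ort

    x = np.random.randn(8, OBS_SIZE).astype(np.float32)
    with torch.no_grad():
        torch_out = model(torch.from_numpy(x)).numpy()
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    (onnx_out,) = sess.run(None, {"input": x})
    np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-4, atol=1e-5)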


async def train_ball3d(websocket: WebSocket):
    # Ensure the directory for saved policies exists
    os.makedirs(POLICIES_DIR, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_uuid = str(uuid.uuid4())[:8]
    model_filename = f"ball3d_policy_{timestamp}_{session_uuid}.onnx"
    model_path = os.path.join(POLICIES_DIR, model_filename)

    # Create vectorised environments (12 instances to match the UI layout)
    envs: List[Ball3DEnv] = [Ball3DEnv() for _ in range(12)]

    net = QNet()
    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    gamma = 0.99
    epsilon = 1.0
    episodes = 100

    for ep in range(episodes):
        # Reset all envs
        obs_list = [env.reset() for env in envs]
        done_flags = [False] * len(envs)
        total_reward = 0.0
        ep_loss_accum = 0.0
        step_counter = 0

        while not all(done_flags):
            obs_batch = torch.as_tensor(np.stack(obs_list), dtype=torch.float32)
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                actions = np.random.randint(0, NUM_ACTIONS, size=len(envs))
            else:
                with torch.no_grad():
                    qvals = net(obs_batch)
                actions = torch.argmax(qvals, dim=1).cpu().numpy()

            # Step each env; finished envs contribute zero-reward terminal
            # transitions so the batch shape stays constant
            next_obs_list = []
            rewards = []
            dones = []
            for idx, env in enumerate(envs):
                if done_flags[idx]:
                    # Already finished – keep state
                    next_obs_list.append(obs_list[idx])
                    rewards.append(0.0)
                    dones.append(True)
                    continue
                nobs, rew, dn = env.step(int(actions[idx]))
                next_obs_list.append(nobs)
                rewards.append(rew)
                dones.append(dn)

            # Q-learning update per env (independent one-step TD targets)
            obs_tensor = torch.as_tensor(np.stack(obs_list), dtype=torch.float32)
            next_obs_tensor = torch.as_tensor(np.stack(next_obs_list), dtype=torch.float32)
            actions_tensor = torch.as_tensor(actions, dtype=torch.long)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            dones_tensor = torch.tensor(dones, dtype=torch.float32)

            q_pred = net(obs_tensor).gather(1, actions_tensor.view(-1, 1)).squeeze(1)
            with torch.no_grad():
                q_next_max = net(next_obs_tensor).max(dim=1).values
            q_target = rewards_tensor + gamma * (1.0 - dones_tensor) * q_next_max
            loss = ((q_pred - q_target) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ep_loss_accum += float(loss.item())
            step_counter += 1
            total_reward += float(np.sum(rewards))

            obs_list = next_obs_list
            done_flags = dones

            # Stream current state to the frontend for one of the envs (env 0)
            env0 = envs[0]
            await websocket.send_json({
                "type": "train_step",
                "state": {
                    "rotX": float(env0.rot[0]),
                    "rotZ": float(env0.rot[1]),
                    "ballX": float(env0.pos[0]),
                    "ballZ": float(env0.pos[1]),
                },
                "episode": ep + 1,
            })

            # Small sleep so the UI can update smoothly
            await asyncio.sleep(0.01)

            # Safety: break out if the step count somehow exceeds the episode cap
            if step_counter >= MAX_STEPS_PER_EP:
                break

        # Epsilon annealing
        epsilon = max(0.05, epsilon * 0.995)

        # Send a progress summary every 10 episodes
        if (ep + 1) % 10 == 0:
            avg_loss = ep_loss_accum / max(1, step_counter)
            await websocket.send_json({
                "type": "progress",
                "episode": ep + 1,
                "reward": round(total_reward / len(envs), 3),
                "loss": round(avg_loss, 5),
            })

    # Export the trained model
    _export_model_onnx(net, model_path)

    await websocket.send_json({
        "type": "trained",
        "file_url": f"/policies/{model_filename}",
        "model_filename": model_filename,
        "timestamp": timestamp,
        "session_uuid": session_uuid,
    })
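

# A sketch of how this coroutine might be wired into an app; the route and
# mount paths below are illustrative assumptions, not part of this module's
# contract. The frontend is expected to fetch the exported model from the
# "file_url" sent in the final "trained" message, hence the static mount of
# POLICIES_DIR.
def _example_app():
    from fastapi import FastAPI
    from fastapi.staticfiles import StaticFiles

    app = FastAPI()
    os.makedirs(POLICIES_DIR, exist_ok=True)
    app.mount("/policies", StaticFiles(directory=POLICIES_DIR), name="policies")

    @app.websocket("/ws/train/ball3d")  # hypothetical route
    async def ws_train(websocket: WebSocket):
        await websocket.accept()
        await train_ball3d(websocket)

    return app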


# ---------------------------------------------------------------------------------
# Inference helper ----------------------------------------------------------------
# ---------------------------------------------------------------------------------

_ort_sessions_ball3d = {}


def infer_action_ball3d(obs: List[float], model_filename: Optional[str] = None):
    """Run the ONNX policy to obtain an action index (0-4)."""
    # Lazy import to avoid the dependency when inference is unused
    import onnxruntime as ort

    # Resolve model filename (most recent if None; filenames sort by timestamp)
    if model_filename is None:
        files = [f for f in os.listdir(POLICIES_DIR) if f.startswith("ball3d_policy_") and f.endswith(".onnx")]
        if not files:
            raise FileNotFoundError("No trained ball3d policy files available")
        files.sort(reverse=True)
        model_filename = files[0]

    # Cache one ONNX Runtime session per model file
    if model_filename not in _ort_sessions_ball3d:
        model_path = os.path.join(POLICIES_DIR, model_filename)
        _ort_sessions_ball3d[model_filename] = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    inp = np.array([obs], dtype=np.float32)
    outputs = _ort_sessions_ball3d[model_filename].run(None, {"input": inp})
    return int(np.argmax(outputs[0]))
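

# Hedged usage sketch (illustrative helper, never called): drive the simplified
# environment closed-loop with the exported ONNX policy instead of the torch
# network. A well-trained policy should reach MAX_STEPS_PER_EP on most episodes.
def _demo_onnx_rollout(model_filename: Optional[str] = None):
    env = Ball3DEnv()
    obs = env.reset()
    done = False
    while not done:
        action = infer_action_ball3d(obs.tolist(), model_filename)
        obs, _, done = env.step(action)
    print(f"ONNX policy kept the ball up for {env.steps}/{MAX_STEPS_PER_EP} steps")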