1+ # ---------------------------------------------------------------------------------
2+ # GridWorld environment and Q-learning trainer -------------------------------------
3+ # ---------------------------------------------------------------------------------
4+
5+ import os
6+ import asyncio
7+ from datetime import datetime
8+ import uuid
9+ from typing import List , Tuple
10+
11+ import numpy as np
12+ import torch
13+ import torch .nn as nn
14+ import torch .optim as optim
15+ from fastapi import WebSocket
16+
17+ # ---------------------------------------------------------------------------------
18+ # Simplified discrete GridWorld ----------------------------------------------------
19+ # ---------------------------------------------------------------------------------
20+
DEFAULT_GRID_SIZE = 5  # grid is N x N cells
MAX_STEPS_PER_EP = 100  # hard per-episode step cap

# Discrete action set.  Index -> (dx, dy) displacement on the grid; in the 3-D
# visualization the grid's y axis is rendered as world z, hence the z labels:
# 0: no-op, 1: up (+z), 2: down (−z), 3: left (−x), 4: right (+x)
ACTION_DELTAS: List[Tuple[int, int]] = [
    (0, 0),    # stay
    (0, 1),    # up
    (0, -1),   # down
    (-1, 0),   # left
    (1, 0),    # right
]
NUM_ACTIONS = len(ACTION_DELTAS)

# Observation is a 4-dim float32 vector:
# [dx_to_goal, dy_to_goal, goal_one_hot_0, goal_one_hot_1]
OBS_SIZE = 4


class GridWorldEnv:
    """Minimal multi-goal GridWorld with one agent and two goal types."""

    def __init__(self, grid_size: int = DEFAULT_GRID_SIZE):
        self.grid_size = grid_size
        self.reset()

    def reset(self):
        """Re-randomize agent/goal placement and return the first observation."""
        # Draw three distinct cells: agent, one green ("plus") goal and
        # one red ("ex") goal.
        cells = [(x, y) for x in range(self.grid_size) for y in range(self.grid_size)]
        np.random.shuffle(cells)
        self.agent_pos = cells[0]
        self.green_goals = [cells[1]]
        self.red_goals = [cells[2]]
        # Which goal type is rewarded this episode: 0 = green, 1 = red.
        self.current_goal_type = np.random.choice([0, 1])
        self.steps = 0
        return self._get_obs()

    # -------------------------------------------------------------------------
    def _get_obs(self):
        """Observation: normalized offset to the required goal plus a goal one-hot."""
        target = (self.green_goals if self.current_goal_type == 0 else self.red_goals)[0]
        span = max(1, self.grid_size - 1)  # guard against a degenerate 1x1 grid
        dx = (target[0] - self.agent_pos[0]) / span
        dy = (target[1] - self.agent_pos[1]) / span
        one_hot = [1.0, 0.0] if self.current_goal_type == 0 else [0.0, 1.0]
        return np.array([dx, dy, *one_hot], dtype=np.float32)

    # -------------------------------------------------------------------------
    def step(self, action_idx: int):
        """Apply one action and return ``(obs, reward, done)``."""
        ddx, ddy = ACTION_DELTAS[action_idx]
        limit = self.grid_size - 1
        # Moves that would leave the grid are clamped to the border.
        self.agent_pos = (
            int(np.clip(self.agent_pos[0] + ddx, 0, limit)),
            int(np.clip(self.agent_pos[1] + ddy, 0, limit)),
        )
        self.steps += 1

        reward = -0.01  # small per-step penalty encourages short paths
        done = False

        # Landing on a goal ends the episode: the required type pays +1,
        # the wrong type pays -1.
        if self.agent_pos in self.green_goals:
            reward = 1.0 if self.current_goal_type == 0 else -1.0
            done = True
        elif self.agent_pos in self.red_goals:
            reward = 1.0 if self.current_goal_type == 1 else -1.0
            done = True

        # Time-limit truncation.
        if self.steps >= MAX_STEPS_PER_EP:
            done = True

        return self._get_obs(), reward, done
101+
102+ # ---------------------------------------------------------------------------------
103+ # Neural network – simple MLP Q-network -------------------------------------------
104+ # ---------------------------------------------------------------------------------
105+
106+
class QNet(nn.Module):
    """MLP Q-network: maps an OBS_SIZE observation to one Q-value per action."""

    def __init__(self):
        super().__init__()
        # Two tanh-activated hidden layers of 64 units, linear output head.
        self.fc1 = nn.Linear(OBS_SIZE, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, NUM_ACTIONS)

    def forward(self, x):
        hidden = torch.tanh(self.fc1(x))
        hidden = torch.tanh(self.fc2(hidden))
        return self.out(hidden)
118+
119+ # ---------------------------------------------------------------------------------
120+ # Training exposed to FastAPI via websocket ---------------------------------------
121+ # ---------------------------------------------------------------------------------
122+
# Directory (relative to the process CWD) where exported ONNX policies are
# written by training and read back by inference.
POLICIES_DIR = "policies"
124+
125+
def _export_model_onnx(model: nn.Module, path: str):
    """Serialize *model* to an ONNX file at *path* with a dynamic batch axis."""
    # Tracing input: a single zero observation of the network's input width.
    example = torch.zeros((1, OBS_SIZE), dtype=torch.float32)
    # Mark dim 0 of both input and output as variable-sized ("batch").
    batch_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    torch.onnx.export(
        model,
        example,
        path,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=batch_axes,
    )
137+
138+
async def train_gridworld(websocket: WebSocket):
    """Train a Q-network on 12 parallel GridWorld envs, streaming over *websocket*.

    Runs epsilon-greedy one-step TD(0) Q-learning, sends a ``train_step``
    message with the first environment's state after every optimization step,
    a ``progress`` summary every 10 episodes, and a final ``trained`` message
    with the exported ONNX file location.  Message shapes are unchanged from
    the original implementation.

    Fix vs. the original: environments that finished their episode early were
    still included in the TD loss as (stale obs, arbitrary action, reward 0,
    done=True) transitions, regressing their Q-values toward a spurious target
    of 0 on every remaining step.  Finished envs are now masked out of the loss.
    """
    os.makedirs(POLICIES_DIR, exist_ok=True)

    # Timestamp-first artifact name so lexicographic order matches
    # chronological order (relied on by infer_action_gridworld).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    session_uuid = str(uuid.uuid4())[:8]
    model_filename = f"gridworld_policy_{timestamp}_{session_uuid}.onnx"
    model_path = os.path.join(POLICIES_DIR, model_filename)

    envs: List[GridWorldEnv] = [GridWorldEnv() for _ in range(12)]

    net = QNet()
    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    gamma = 0.95    # discount factor
    epsilon = 1.0   # exploration rate, decayed per episode
    episodes = 300

    for ep in range(episodes):
        obs_list = [env.reset() for env in envs]
        done_flags = [False] * len(envs)
        total_reward = 0.0
        ep_loss_accum = 0.0
        step_counter = 0

        while not all(done_flags):
            # np.stack + from_numpy avoids the slow list-of-ndarrays path.
            obs_batch = torch.from_numpy(np.stack(obs_list))
            if np.random.rand() < epsilon:
                actions = np.random.randint(0, NUM_ACTIONS, size=len(envs))
            else:
                with torch.no_grad():
                    qvals = net(obs_batch)
                actions = torch.argmax(qvals, dim=1).cpu().numpy()

            # Which envs are still running *before* this step; only their
            # transitions may contribute to the loss.
            active_mask = torch.tensor([not f for f in done_flags], dtype=torch.bool)

            next_obs_list = []
            rewards = []
            dones = []
            for idx, env in enumerate(envs):
                if done_flags[idx]:
                    # Keep finished envs frozen so batch shapes stay constant.
                    next_obs_list.append(obs_list[idx])
                    rewards.append(0.0)
                    dones.append(True)
                    continue
                nobs, rew, dn = env.step(int(actions[idx]))
                next_obs_list.append(nobs)
                rewards.append(rew)
                dones.append(dn)

            next_obs_tensor = torch.from_numpy(np.stack(next_obs_list))
            actions_tensor = torch.tensor(actions, dtype=torch.long)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            dones_tensor = torch.tensor(dones, dtype=torch.float32)

            # One-step TD target; terminal transitions bootstrap nothing.
            q_pred = net(obs_batch).gather(1, actions_tensor.view(-1, 1)).squeeze(1)
            with torch.no_grad():
                q_next_max = net(next_obs_tensor).max(dim=1).values
            q_target = rewards_tensor + gamma * (1.0 - dones_tensor) * q_next_max
            td_error = q_pred - q_target
            # MSE over live transitions only (loop guard guarantees >= 1).
            loss = (td_error[active_mask] ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ep_loss_accum += float(loss.item())
            step_counter += 1
            total_reward += float(np.sum(rewards))

            obs_list = next_obs_list
            done_flags = dones

            # Stream the first env's state for visualization.
            env0 = envs[0]
            await websocket.send_json({
                "type": "train_step",
                "state": {
                    "agentX": int(env0.agent_pos[0]),
                    "agentY": int(env0.agent_pos[1]),
                    "gridSize": env0.grid_size,
                    "greenGoals": env0.green_goals,
                    "redGoals": env0.red_goals,
                    "currentGoalType": int(env0.current_goal_type),
                },
                "episode": ep + 1,
            })

            # Yield so the event loop can actually flush the sends.
            await asyncio.sleep(0.01)

            # Safety cap (envs also self-terminate at MAX_STEPS_PER_EP).
            if step_counter >= MAX_STEPS_PER_EP:
                break

        epsilon = max(0.05, epsilon * 0.99)

        if (ep + 1) % 10 == 0:
            avg_loss = ep_loss_accum / max(1, step_counter)
            await websocket.send_json({
                "type": "progress",
                "episode": ep + 1,
                "reward": round(total_reward / len(envs), 3),
                "loss": round(avg_loss, 5),
            })

    _export_model_onnx(net, model_path)

    await websocket.send_json({
        "type": "trained",
        "file_url": f"/policies/{model_filename}",
        "model_filename": model_filename,
        "timestamp": timestamp,
        "session_uuid": session_uuid,
    })
248+
249+ # ---------------------------------------------------------------------------------
250+ # Inference -----------------------------------------------------------------------
251+ # ---------------------------------------------------------------------------------
252+
# Cache of onnxruntime InferenceSession objects keyed by model filename, so
# repeated inference calls reuse an already-loaded session.  NOTE(review):
# grows unboundedly as new policy files appear — acceptable for a handful of
# training runs, worth bounding if sessions accumulate.
_ort_sessions_grid = {}
254+
255+
def infer_action_gridworld(obs: List[float], model_filename: str | None = None):
    """Greedy action selection through an exported ONNX GridWorld policy.

    Parameters
    ----------
    obs:
        Observation vector fed to the network's "input" tensor (the trainer
        exports a 4-element observation; see OBS_SIZE).
    model_filename:
        Specific policy file inside POLICIES_DIR; when ``None`` the
        lexicographically newest ``gridworld_policy_*.onnx`` is used
        (timestamped names sort chronologically).

    Returns
    -------
    int
        The argmax action index.

    Raises
    ------
    FileNotFoundError
        If no trained policy file is available.  (The original let
        ``os.listdir`` raise a confusing FileNotFoundError("policies") when
        the directory had never been created; same exception type, clearer
        message.)
    """
    import onnxruntime as ort

    if model_filename is None:
        # Guard: directory may not exist before any training run.
        if not os.path.isdir(POLICIES_DIR):
            raise FileNotFoundError("No trained gridworld policy files available")
        files = [
            f
            for f in os.listdir(POLICIES_DIR)
            if f.startswith("gridworld_policy_") and f.endswith(".onnx")
        ]
        if not files:
            raise FileNotFoundError("No trained gridworld policy files available")
        # Newest first: timestamp prefix makes lexicographic == chronological.
        files.sort(reverse=True)
        model_filename = files[0]

    # Load-and-cache the session on first use of each model file.
    session = _ort_sessions_grid.get(model_filename)
    if session is None:
        model_path = os.path.join(POLICIES_DIR, model_filename)
        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        _ort_sessions_grid[model_filename] = session

    inp = np.array([obs], dtype=np.float32)
    outputs = session.run(None, {"input": inp})
    return int(np.argmax(outputs[0]))