Skip to content

Commit 0b2364f

Browse files
committed
Implement ML-Agents WebSocket bridge for web environments.
1 parent e604178 commit 0b2364f

25 files changed

Lines changed: 3402 additions & 13 deletions

api/.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.11

api/README.md

Whitespace-only changes.

api/mlagents_bridge/__init__.py

Whitespace-only changes.

api/mlagents_bridge/env.py

Lines changed: 368 additions & 0 deletions
Large diffs are not rendered by default.

api/mlagents_bridge/server.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
2+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
3+
import asyncio
4+
import json
5+
import logging
6+
import queue
7+
from typing import Dict
8+
9+
logger = logging.getLogger("ThreeJSServer")
10+
11+
def create_server(env_map: Dict[str, "ThreeJSEnv"]):
    """
    Create a FastAPI app whose websocket endpoint bridges a browser to a ThreeJSEnv.

    env_map: dict of channel_id -> ThreeJSEnv. A client picks an entry with the
        optional ``channel`` query parameter (``/ws/mlagents?channel=<id>``);
        when the parameter is absent the ``"default"`` entry is used.

    Returns the FastAPI application with the ``/ws/mlagents`` route registered.
    """
    app = FastAPI()

    @app.websocket("/ws/mlagents")
    async def websocket_endpoint(ws: WebSocket):
        """Pump messages between the websocket and the env's two queues."""
        await ws.accept()
        logger.info("Client connected to /ws/mlagents")

        # Falls back to "default" so existing single-environment clients that
        # connect without a query string keep working.
        channel_id = ws.query_params.get("channel", "default")
        env = env_map.get(channel_id)

        # Identity check: an env object must not be rejected merely because
        # it happens to evaluate as falsy.
        if env is None:
            await ws.close(code=1000, reason="No environment found")
            return

        try:
            while True:
                # 1. Outbound (Python -> Browser): drain everything currently
                #    queued without blocking the event loop.
                try:
                    while True:
                        msg = env.outbound_queue.get_nowait()
                        await ws.send_json(msg)
                except queue.Empty:
                    pass

                # 2. Inbound (Browser -> Python): wait briefly for one frame so
                #    the loop can return to the outbound queue.
                try:
                    data = await asyncio.wait_for(ws.receive_json(), timeout=0.01)
                except asyncio.TimeoutError:
                    pass  # nothing received in this slice; continue loop
                except json.JSONDecodeError as exc:
                    # One malformed frame should not tear down the bridge.
                    logger.warning("Ignoring malformed JSON frame: %s", exc)
                else:
                    # Hand the decoded message to the Python side.
                    env.inbound_queue.put(data)

                await asyncio.sleep(0.001)

        except WebSocketDisconnect:
            logger.info("Client disconnected")
            env.close()

    return app

api/pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[project]
2+
name = "api"
3+
version = "0.1.0"
4+
description = "ML-Agents WebSocket bridge API for web environments"
5+
readme = "README.md"
6+
requires-python = ">=3.11"
7+
dependencies = [
8+
"fastapi>=0.128.0",
9+
"numpy>=2.4.0",
10+
"onnx>=1.20.0",
11+
"onnxruntime>=1.23.2",
12+
"python-multipart>=0.0.21",
13+
"shimmy[gym-v21]>=2.0.0",
14+
"stable-baselines3[extra]>=2.7.1",
15+
"tensorboard>=2.20.0",
16+
"torch>=2.9.1",
17+
"uvicorn>=0.40.0",
18+
"websockets>=15.0.1",
19+
]

api/requirements.txt

Lines changed: 0 additions & 11 deletions
This file was deleted.

api/train_migration.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
2+
import threading
3+
import uvicorn
4+
import time
5+
import queue
6+
import logging
7+
import numpy as np
8+
import uuid
9+
10+
# Configure logging
11+
logging.basicConfig(level=logging.INFO)
12+
logger = logging.getLogger("MigrationDemo")
13+
14+
# Updated Imports
15+
from mlagents_bridge.env import ThreeJSEnv
16+
from mlagents_bridge.server import create_server
17+
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel
18+
19+
def run_server(env, host="0.0.0.0", port=8000, log_level="warning"):
    """Serve *env* over the websocket bridge with uvicorn.

    This call blocks inside ``uvicorn.run``; it does not create a thread
    itself, so callers that need it in the background run it as a thread
    target.

    env: environment registered under the "default" channel.
    host: interface to bind; the default "0.0.0.0" listens on all interfaces.
    port: TCP port to listen on.
    log_level: log level name passed through to uvicorn.
    """
    app = create_server({"default": env})
    uvicorn.run(app, host=host, port=port, log_level=log_level)
24+
25+
def main():
    """Run the migration demo: serve the env, reset it, then step with random actions.

    The environment is always closed on the way out, including when the reset
    fails or the user presses Ctrl-C while waiting for a client to connect.
    """
    logger.info("Starting Migration Demo...")

    # 1. Create Side Channels
    env_params = EnvironmentParametersChannel()

    # 2. Create Environment (the side channel list is passed to the env)
    env = ThreeJSEnv(channel_id="default", side_channels=[env_params])

    # 3. Start Server in Thread; daemon so it does not keep the process alive.
    server_thread = threading.Thread(target=run_server, args=(env,), daemon=True)
    server_thread.start()

    logger.info("Server started on port 8000. Waiting for client to connect...")

    behavior_name = "Ball3D"  # Must match JS

    try:
        # 4. Wait for handshake and Reset. This sits inside the outer try so
        #    an interrupt or failure here still reaches env.close() below.
        try:
            env.reset()
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to reset: %s", e)
            return
        logger.info("Environment reset complete. Handshake successful.")

        # Set Initial Gravity
        env_params.set_float_parameter("gravity", 9.81)

        # 5. Run Loop
        for i in range(10000):  # Run longer
            # Only the decision steps are used; terminal steps are ignored.
            decision_steps, _terminal_steps = env.get_steps(behavior_name)
            n_agents = len(decision_steps)
            is_report_step = i % 100 == 0

            # Log Obs Shapes occasionally
            if is_report_step:
                logger.info("Step %d: Agents=%d", i, n_agents)
                for idx, obs in enumerate(decision_steps.obs):
                    logger.info(" Obs[%d] Shape: %s", idx, obs.shape)

            # Create random actions for every agent awaiting a decision
            if behavior_name in env.behavior_specs and n_agents > 0:
                spec = env.behavior_specs[behavior_name]
                action = spec.action_spec.random_action(n_agents)
                env.set_actions(behavior_name, action)

            # Step env
            env.step()

            # Test Side Channel: push a new random gravity every 100 steps
            if is_report_step:
                new_g = 5.0 + np.random.rand() * 10.0
                logger.info("Step %d: Setting Gravity to %.2f", i, new_g)
                env_params.set_float_parameter("gravity", float(new_g))

    except KeyboardInterrupt:
        logger.info("Stopping...")
    finally:
        env.close()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)