Skip to content

Commit 3bd8cbe

Browse files
committed
Rename and refactor pirate ship to kraken game.
1 parent 427c480 commit 3bd8cbe

4 files changed

Lines changed: 111 additions & 20 deletions

File tree

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def reset(self, seed=None, options=None):
4343
self.ship_healths = np.full(NUM_SHIPS, SHIP_HEALTH)
4444
self.kraken_position = np.array([GRID_SIZE / 2, GRID_SIZE / 2])
4545
self.kraken_health = KRAKEN_HEALTH
46-
self.tentacle_positions = np.random.uniform(0, GRID_SIZE, (NUM_TENTACLES, 2))
46+
self.tentacle_offsets = np.random.uniform(-10, 10, (NUM_TENTACLES, 2))
47+
self.tentacle_positions = self.kraken_position[None, :] + self.tentacle_offsets
4748
return self._get_obs(), {}
4849

4950
def _get_obs(self):
@@ -98,6 +99,7 @@ def step(self, actions):
9899
direction /= (np.linalg.norm(direction) + 1e-8)
99100
self.kraken_position += direction * KRAKEN_SPEED
100101
self.kraken_position = np.clip(self.kraken_position, 0, GRID_SIZE)
102+
self.tentacle_positions = self.kraken_position[None, :] + self.tentacle_offsets
101103

102104
# Check done
103105
if self.kraken_health <= 0 or np.all(self.ship_healths <= 0) or self.steps >= MAX_STEPS:
@@ -118,14 +120,36 @@ class WebSocketCallback(BaseCallback):
118120
def __init__(self, websocket, verbose=0):
119121
super().__init__(verbose)
120122
self.websocket = websocket
123+
self.episode_rewards = []
124+
self.episode_lengths = []
121125

122126
def _on_step(self):
123127
# Send state every few steps
124128
if self.n_calls % 10 == 0:
125129
state = self.training_env.get_attr("get_state_for_viz")[0]()
126130
asyncio.run(self.websocket.send_json({"type": "train_step", "state": state}))
131+
132+
# Collect rewards
133+
done = self.locals['dones'][0]
134+
reward = self.locals['rewards'][0]
135+
if done:
136+
self.episode_rewards.append(self.locals['infos'][0].get('episode', {}).get('r', 0))
137+
self.episode_lengths.append(self.locals['infos'][0].get('episode', {}).get('l', 0))
127138
return True
128139

140+
def _on_rollout_end(self):
141+
if self.episode_rewards:
142+
avg_reward = np.mean(self.episode_rewards)
143+
avg_length = np.mean(self.episode_lengths)
144+
asyncio.run(self.websocket.send_json({
145+
"type": "progress",
146+
"episode": self.num_timesteps // 1000, # Approximate episode count
147+
"reward": float(avg_reward),
148+
"loss": 0.0 # Placeholder, as SB3 doesn't directly provide loss here; can be extended if needed
149+
}))
150+
self.episode_rewards = []
151+
self.episode_lengths = []
152+
129153
async def train_pirate_ship(websocket: WebSocket):
130154
env = make_vec_env(PirateShipEnv, n_envs=8)
131155
model = PPO("MlpPolicy", env, verbose=1)

api/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import logging
4141
from examples.simcity import run_simcity, train_simcity, SimCityEnv
4242
from examples.simcity_deckgl import run_simcity as run_simcity_deckgl, train_simcity as train_simcity_deckgl, SimCityEnv as SimCityDeckGLEnv
43-
from examples.pirate_ship import train_pirate_ship, run_pirate_ship, PirateShipEnv
43+
from examples.kraken import train_pirate_ship, run_pirate_ship, PirateShipEnv
4444

4545
logging.basicConfig(level=logging.INFO)
4646
logger = logging.getLogger(__name__)
@@ -676,8 +676,8 @@ async def websocket_endpoint_simcity_deckgl(websocket: WebSocket):
676676
await websocket.send_json({"type": "error", "message": str(e)})
677677

678678

679-
@app.websocket("/ws/pirate-ship")
680-
async def websocket_pirate_ship(websocket: WebSocket):
679+
@app.websocket("/ws/kraken")
680+
async def websocket_kraken(websocket: WebSocket):
681681
await websocket.accept()
682682
preview_env = PirateShipEnv()
683683
preview_state = preview_env.get_state_for_viz()
@@ -690,7 +690,7 @@ async def websocket_pirate_ship(websocket: WebSocket):
690690
elif data['cmd'] == 'run':
691691
await run_pirate_ship(websocket)
692692
except Exception as e:
693-
print(f"Pirate Ship websocket disconnected: {e}")
693+
print(f"Kraken websocket disconnected: {e}")
694694

695695

696696
if __name__ == "__main__":

client/src/App.jsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import SimCityDeckGLExample from './examples/SimCityDeckGL.jsx';
2020
import FishExample from './examples/Fish.jsx';
2121
import IntersectionExample from './examples/Intersection.jsx';
2222
import SelfDrivingCarExample from './examples/SelfDrivingCar.jsx';
23-
import PirateShip from './examples/PirateShip.jsx';
23+
import KrakenGame from './examples/KrakenGame.jsx';
2424

2525
export default function App() {
2626
return (
@@ -68,7 +68,7 @@ export default function App() {
6868
<Route path="/fish" element={<FishExample />} />
6969
<Route path="/intersection" element={<IntersectionExample />} />
7070
<Route path="/self-driving-car" element={<SelfDrivingCarExample />} />
71-
<Route path="/pirate-ship" element={<PirateShip />} />
71+
<Route path="/kraken" element={<KrakenGame />} />
7272
<Route path="*" element={<Navigate to="/" replace />} />
7373
</Routes>
7474
</BrowserRouter>
Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ import { Canvas, useFrame } from '@react-three/fiber';
33
import { OrbitControls, Text as DreiText, Stars, Line } from '@react-three/drei';
44
import { Button, Text, Card } from '@geist-ui/core';
55
import { Link } from 'react-router-dom';
6+
import InfoPanel from '../components/InfoPanel.jsx';
67
import * as THREE from 'three';
78
import config from '../config.js';
89
import { useResponsive } from '../hooks/useResponsive.js';
910
import { EffectComposer, Bloom } from '@react-three/postprocessing';
1011

11-
const WS_URL = `${config.WS_BASE_URL}/ws/pirate-ship`;
12+
const WS_URL = `${config.WS_BASE_URL}/ws/kraken`;
13+
const GRID_SIZE = 200;
1214

1315
// Water surface component (adapted from example)
1416
const WaterSurface = ({ waterSize }) => {
@@ -71,14 +73,37 @@ const Kraken = ({ position, health }) => {
7173
);
7274
};
7375

74-
// Tentacle component
75-
const Tentacle = ({ position }) => {
76-
return (
77-
<mesh position={[position[0] - GRID_SIZE/2, 1, position[1] - GRID_SIZE/2]}>
78-
<cylinderGeometry args={[0.5, 0.5, 3]} />
79-
<meshStandardMaterial color="darkgreen" />
80-
</mesh>
81-
);
76+
// Tentacle component - PURPLE tentacles emerging from water and reaching UP!
77+
const Tentacle = ({ start, end }) => {
78+
// Safety check for undefined end position
79+
if (!end || !Array.isArray(end) || end.length < 2) {
80+
return null;
81+
}
82+
83+
// Start at water surface (y = 0)
84+
const startAdj = [end[0] - GRID_SIZE/2, 0, end[1] - GRID_SIZE/2];
85+
// End point high in the air (y = 15)
86+
const endAdj = [end[0] - GRID_SIZE/2, 15, end[1] - GRID_SIZE/2];
87+
// Mid-point for nice curve
88+
const midRef = useRef([startAdj[0], 8, startAdj[2]]);
89+
const lineRef = useRef();
90+
91+
useFrame(({ clock }) => {
92+
const time = clock.getElapsedTime();
93+
const wave = Math.sin(time * 2 + (start[0] + start[1])) * 2;
94+
const sway = Math.cos(time * 1.5 + (start[0] + start[1])) * 1.5;
95+
// Animate mid-point for writhing effect
96+
midRef.current[1] = 8 + wave;
97+
midRef.current[0] = startAdj[0] + sway;
98+
midRef.current[2] = startAdj[2] + sway * 0.7;
99+
if (lineRef.current) {
100+
lineRef.current.points = [startAdj, midRef.current, endAdj];
101+
lineRef.current.geometry.setPositions(lineRef.current.points.flat());
102+
}
103+
});
104+
105+
const points = [startAdj, midRef.current, endAdj];
106+
return <Line ref={lineRef} points={points} color="purple" lineWidth={5} />;
82107
};
83108

84109
const StatusPanel = ({ ships, kraken }) => {
@@ -91,26 +116,63 @@ const StatusPanel = ({ ships, kraken }) => {
91116
);
92117
};
93118

94-
export default function PirateShip() {
95-
const [state, setState] = useState(null);
119+
export default function KrakenGame() {
120+
const initialState = {
121+
ships: Array(4).fill().map(() => ({ pos: [Math.random() * 200, Math.random() * 200], health: 100 })),
122+
kraken: { pos: [100, 100], health: 500 },
123+
tentacles: Array(6).fill().map(() => [Math.random() * 200, Math.random() * 200]),
124+
grid_size: 200
125+
};
126+
const [state, setState] = useState(initialState);
96127
const [running, setRunning] = useState(false);
97128
const [training, setTraining] = useState(false);
98129
const [trained, setTrained] = useState(false);
130+
const [logs, setLogs] = useState([]);
131+
const [chartState, setChartState] = useState({ labels: [], rewards: [], losses: [] });
99132
const wsRef = useRef(null);
100133
const { isMobile } = useResponsive();
101134

135+
const addLog = (txt) => {
136+
setLogs((l) => {
137+
const upd = [...l, txt];
138+
return upd.length > 200 ? upd.slice(upd.length - 200) : upd;
139+
});
140+
};
141+
142+
const resetTraining = () => {
143+
setTraining(false);
144+
setTrained(false);
145+
setChartState({ labels: [], rewards: [], losses: [] });
146+
setState(initialState);
147+
addLog('Training has been reset.');
148+
};
149+
102150
useEffect(() => {
103151
const ws = new WebSocket(WS_URL);
104152
wsRef.current = ws;
153+
ws.onopen = () => addLog('WS opened');
105154
ws.onmessage = (ev) => {
106-
const parsed = JSON.parse(ev.data);
155+
addLog(ev.data);
156+
let parsed;
157+
try {
158+
parsed = JSON.parse(ev.data);
159+
} catch {
160+
return;
161+
}
107162
if (parsed.type === 'train_step' || parsed.type === 'run_step') {
108163
setState(parsed.state);
164+
} else if (parsed.type === 'progress') {
165+
setChartState((prev) => ({
166+
labels: [...prev.labels, parsed.episode],
167+
rewards: [...prev.rewards, parsed.reward],
168+
losses: [...prev.losses, parsed.loss ?? null],
169+
}));
109170
} else if (parsed.type === 'trained') {
110171
setTraining(false);
111172
setTrained(true);
112173
}
113174
};
175+
ws.onclose = () => addLog('WS closed');
114176
return () => ws.close();
115177
}, []);
116178

@@ -121,11 +183,14 @@ export default function PirateShip() {
121183
};
122184

123185
const startTraining = () => {
186+
if (training || trained) return;
124187
setTraining(true);
188+
addLog('Starting training...');
125189
send({ cmd: 'train' });
126190
};
127191

128192
const startRun = () => {
193+
if (!trained) return;
129194
setRunning(true);
130195
send({ cmd: 'run' });
131196
};
@@ -146,7 +211,7 @@ export default function PirateShip() {
146211
{state && <Kraken position={state.kraken.pos} health={state.kraken.health} />}
147212

148213
{state && state.tentacles.map((tentacle, i) => (
149-
<Tentacle key={i} position={tentacle} />
214+
<Tentacle key={i} start={state.kraken.pos} end={tentacle} />
150215
))}
151216

152217
<EffectComposer>
@@ -160,7 +225,9 @@ export default function PirateShip() {
160225
<Text h1>Pirate Ship vs Kraken</Text>
161226
<Button type="secondary" disabled={training || trained} onClick={startTraining}>Train</Button>
162227
<Button type="success" disabled={!trained || running} onClick={startRun}>Run</Button>
228+
{trained && <Button type="error" onClick={resetTraining}>Reset</Button>}
163229
</div>
230+
<InfoPanel logs={logs} chartState={chartState} />
164231

165232
{state && (
166233
<div style={{ position: 'absolute', top: 10, right: 10, zIndex: 1 }}>

0 commit comments

Comments
 (0)