New metrics for v5 and planning v9
Passed on MuJoCo Warp (too much JAX rewrite for the time left), added action and tracking-lag metrics, and used them to characterise v5 as "noisy but on-time" versus classical's "clean but late". Set the stage for v9 with an action delta penalty.
- research
What I’m trying to achieve
Clean up the codebase, move training to the GPU (if I can), and find a way to quantify jerkiness so I can reduce it in the RL approach.
Context
Yesterday, I spent a long time training new models. I believe I spent 7 hours total training 7 different models. It was also my intro to tuning RL hyperparameters, which was very mentally taxing yet rewarding. After some rest and taking some time to digest what I learned, I feel much more confident in what I’m doing.
Today, I was walking around campus and met my teacher, Miti, in the Engineering Physics project lab. I spoke with him about this project and he highly encouraged me to look into a recent update to MuJoCo which allows you to host physics on your GPU. This could potentially allow me to train tens of thousands of times faster so it’s worth looking into.
Research
MuJoCo Warp (MJWarp) seems to be what Miti was talking about. He mentioned the motivation was Google DeepMind needed it, so they built it. What exactly is it?
(Doing some research…)
I see what’s happening. I can’t use this. MuJoCo Warp is coupled with JAX. This means my current PyTorch setup won’t work. I’d also have to rewrite:
- env.py to remove NumPy and use JAX functions
- trajectory.py to precompute paths instead of generating them per reset
- ik.py to port some of the linear algebra over to JAX
Although it would save me tens of hours in compute time when I explore more complicated RL options, right now it’s simply not smart to port over.
Instead, I want to benchmark what I have and see what happens when I push the limits. (Maybe sneak in one or two more new models, too.)
Improving jitter
So thankfully I had the foresight to push my changes to GitHub before I started messing around with v6, v7, and v8 termination models. (More like terminator models.)
git add best_model_v6.zip best_model_v7.zip best_model_v8.zip
git commit -m "archiving v6-v8 collision-avoidance experiments (failed)"
git restore ../../env.py ../../train.py
And we’re back to our best policy.
Re-measuring performance
I want to improve how we track the performance of the v5 model.
Currently I’m logging jerk with a third-order finite difference, whose coefficients (1, -3, 3, -1) come from Pascal’s triangle:
(p[i+3] - 3*p[i+2] + 3*p[i+1] - p[i]) / dt³
## in the code base this is:
j = (ee[3:] - 3 * ee[2:-1] + 3 * ee[1:-2] - ee[:-3]) / dt**3
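To make the metric concrete, here is a minimal sketch of how the line above could be turned into a single "Jerk RMS" number. It assumes `ee` is a (T, D) NumPy array of end-effector positions sampled every `dt` seconds; the function name `jerk_rms` and the choice to take the RMS of the jerk magnitude are assumptions for illustration, not necessarily what the code base does.

```python
import numpy as np

def jerk_rms(ee: np.ndarray, dt: float) -> float:
    """RMS of the jerk magnitude along a sampled path.

    ee: (T, D) array of positions, one row per step (assumed layout).
    dt: time between samples, in seconds.
    """
    # Third forward difference with coefficients 1, -3, 3, -1.
    j = (ee[3:] - 3 * ee[2:-1] + 3 * ee[1:-2] - ee[:-3]) / dt**3
    # Jerk magnitude per sample, then root-mean-square over the run.
    return float(np.sqrt(np.mean(np.sum(j**2, axis=1))))

# Sanity check on paths with known jerk:
# p(t) = t^3 has constant jerk 6, p(t) = t^2 has zero jerk.
dt = 0.01
t = np.arange(50) * dt
zeros = np.zeros_like(t)
cubic = np.stack([t**3, zeros, zeros], axis=1)
quadratic = np.stack([t**2, zeros, zeros], axis=1)

print("cubic jerk RMS:    ", jerk_rms(cubic, dt))
print("quadratic jerk RMS:", jerk_rms(quadratic, dt))
```

Dividing by dt³ amplifies any sample-to-sample noise a lot, which is part of why this metric is so sensitive to small wobbles.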
I want to also log:
- The per-step lead vector coming from RL
- The average magnitude of the lead per step (will be zero for ikpd)
- The average change of the lead (hopefully will give us insight into the jitter)
- The 95th percentile of the action delta (I want to see if the mean is hiding outliers)
- The tracking lag (time shift between target and EE)
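As a rough sketch of the lead statistics in this list, assuming the per-step lead vectors are stacked into a (T, D) NumPy array. The function name `lead_metrics` and the dictionary keys are placeholders for illustration, not the names used in collect.py.

```python
import numpy as np

def lead_metrics(leads: np.ndarray) -> dict:
    """Summarise per-step lead vectors.

    leads: (T, D) array, one lead vector per step (assumed layout).
           A controller with no lead would pass all zeros.
    """
    norms = np.linalg.norm(leads, axis=1)                     # |lead_t|
    deltas = np.linalg.norm(np.diff(leads, axis=0), axis=1)   # |lead_t - lead_(t-1)|
    return {
        "lead_norm_mean": float(norms.mean()),
        "lead_delta_norm_mean": float(deltas.mean()),
        "lead_delta_norm_p95": float(np.percentile(deltas, 95)),
    }

# Toy data: a lead that is zero except for a spike every 10th step.
leads = np.zeros((100, 3))
leads[::10, 0] = 0.2

m = lead_metrics(leads)
print(m)
```

On this toy data the mean delta is much smaller than the 95th percentile, which is the situation the p95 metric is meant to expose: a low average that hides occasional large jumps.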
For the last item, tracking lag, I’ll explain a little more. We scan over candidate values of τ and keep the one that minimizes:
mean over the whole run of ||target(t) - ee(t+τ)||
Then we will log:
- tracking_lag_ms (the τ itself)
- min_error_at_lag_cm (the mean error we found after applying the time shift)
This allows us to learn about the shape of tracking and the timing of the lead. This table may be helpful to see why:
| τ | min_error | Meaning |
|---|---|---|
| low | low | Great. Right path, right time. |
| low | high | Wrong shape. Time alignment didn’t save it. |
| high | low | Right path, just late. Pure lag problem. |
| high | high | Bad on both axes. |
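The τ scan can be sketched as a brute-force search over integer step shifts. This is an illustration under assumptions: `target` and `ee` are (T, D) arrays sampled at the same fixed `dt`, shifts are whole steps, and negative shifts (EE ahead of the target) are scanned too. The name `tracking_lag` and the `max_shift` parameter are hypothetical.

```python
import numpy as np

def tracking_lag(target: np.ndarray, ee: np.ndarray, dt: float, max_shift: int = 50):
    """Return (lag_ms, error_at_lag) for the shift k that minimises
    mean ||target[t] - ee[t + k]||. Positive lag means the EE is late.
    """
    T = len(target)
    best_k, best_err = 0, np.inf
    for k in range(-max_shift, max_shift + 1):
        if abs(k) >= T:
            continue  # no overlapping samples left at this shift
        if k >= 0:
            tgt, pos = target[:T - k], ee[k:]
        else:
            tgt, pos = target[-k:], ee[:T + k]
        err = float(np.linalg.norm(tgt - pos, axis=1).mean())
        if err < best_err:
            best_k, best_err = k, err
    return best_k * dt * 1000.0, best_err

# Toy check: the EE follows a 0.5 Hz circle exactly, but 5 steps (100 ms) late.
dt = 0.02
t = np.arange(200) * dt

def circle(times):
    return np.stack([np.cos(np.pi * times), np.sin(np.pi * times)], axis=1)

target = circle(t)
ee = circle(t - 5 * dt)

lag_ms, err = tracking_lag(target, ee, dt)
print(f"lag = {lag_ms:.1f} ms, error at best lag = {err:.2e}")
```

This toy case lands in the "high τ, low min_error" row of the table: the shape is perfect and all of the error is timing. Keeping `max_shift` below half a period avoids matching a periodic path against its next lap.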
uv run python collect.py --seeds 100-299 --model archive/old_models/best_model_v5.zip
Here are the results:
=== ALL (n=200) ===
────────────────────────────────────────────────────────────────────────────
Metric Classical (μ±σ) RL v5 (μ±σ) Δμ
────────────────────────────────────────────────────────────────────────────
Mean error (cm) 12.26 ± 6.66 8.28 ± 4.50 -3.97
RMSE error (cm) 17.27 ± 8.40 14.39 ± 6.94 -2.88
Max error (cm) 76.66 ± 27.25 76.62 ± 27.42 -0.04
Time in 5cm band (%) 24.06 ± 30.22 54.17 ± 26.26 +30.11
Err in-reach (cm) 12.27 ± 6.66 8.26 ± 4.47 -4.01
Err out-of-reach (cm) (22/200) 12.44 ± 9.58 10.47 ± 8.00 -1.97
Jerk RMS (m/s³) 520.68 ± 317.94 773.06 ± 388.80 +252.38
Lead norm (mean) 0.00 ± 0.00 0.65 ± 0.25 +0.65
Lead Δ norm (mean) 0.00 ± 0.00 0.04 ± 0.02 +0.04
Lead Δ norm (p95) 0.00 ± 0.00 0.13 ± 0.08 +0.13
Tracking lag (ms) 113.45 ± 11.90 17.65 ± 33.79 -95.80
Err at best lag (cm) 4.67 ± 4.80 7.86 ± 4.20 +3.18
=== CIRCLE (n=64) ===
────────────────────────────────────────────────────────────────────────────
Metric Classical (μ±σ) RL v5 (μ±σ) Δμ
────────────────────────────────────────────────────────────────────────────
Mean error (cm) 15.03 ± 9.55 9.79 ± 6.53 -5.24
RMSE error (cm) 19.61 ± 11.41 15.99 ± 9.00 -3.61
Max error (cm) 75.15 ± 25.62 75.25 ± 25.44 +0.11
Time in 5cm band (%) 16.10 ± 31.94 50.68 ± 30.54 +34.58
Err in-reach (cm) 15.00 ± 9.53 9.73 ± 6.49 -5.27
Err out-of-reach (cm) (4/64) 24.75 ± 6.14 18.77 ± 8.07 -5.98
Jerk RMS (m/s³) 630.96 ± 400.27 844.86 ± 458.67 +213.90
Lead norm (mean) 0.00 ± 0.00 0.72 ± 0.31 +0.72
Lead Δ norm (mean) 0.00 ± 0.00 0.04 ± 0.02 +0.04
Lead Δ norm (p95) 0.00 ± 0.00 0.14 ± 0.09 +0.14
Tracking lag (ms) 114.38 ± 11.84 14.22 ± 36.18 -100.16
Err at best lag (cm) 6.49 ± 7.19 9.23 ± 6.04 +2.74
=== FIG8 (n=70) ===
────────────────────────────────────────────────────────────────────────────
Metric Classical (μ±σ) RL v5 (μ±σ) Δμ
────────────────────────────────────────────────────────────────────────────
Mean error (cm) 8.71 ± 3.73 6.36 ± 2.70 -2.35
RMSE error (cm) 13.44 ± 5.78 12.07 ± 5.74 -1.37
Max error (cm) 68.93 ± 25.85 69.15 ± 26.34 +0.21
Time in 5cm band (%) 39.69 ± 34.74 68.17 ± 23.22 +28.48
Err in-reach (cm) 8.65 ± 3.63 6.28 ± 2.56 -2.37
Err out-of-reach (cm) (7/70) 14.89 ± 8.91 12.37 ± 8.56 -2.51
Jerk RMS (m/s³) 424.87 ± 206.90 669.21 ± 341.14 +244.34
Lead norm (mean) 0.00 ± 0.00 0.51 ± 0.21 +0.51
Lead Δ norm (mean) 0.00 ± 0.00 0.03 ± 0.02 +0.03
Lead Δ norm (p95) 0.00 ± 0.00 0.11 ± 0.08 +0.11
Tracking lag (ms) 114.14 ± 13.89 21.43 ± 38.92 -92.71
Err at best lag (cm) 2.98 ± 2.43 6.06 ± 2.68 +3.08
=== FLY (n=66) ===
────────────────────────────────────────────────────────────────────────────
Metric Classical (μ±σ) RL v5 (μ±σ) Δμ
────────────────────────────────────────────────────────────────────────────
Mean error (cm) 13.33 ± 3.04 8.86 ± 2.40 -4.47
RMSE error (cm) 19.06 ± 5.30 15.29 ± 4.83 -3.77
Max error (cm) 86.33 ± 27.31 85.87 ± 27.73 -0.46
Time in 5cm band (%) 15.19 ± 9.97 42.71 ± 16.19 +27.52
Err in-reach (cm) 13.46 ± 3.10 8.93 ± 2.44 -4.53
Err out-of-reach (cm) (11/66) 6.41 ± 4.90 6.24 ± 3.55 -0.17
Jerk RMS (m/s³) 515.35 ± 289.82 813.56 ± 335.53 +298.22
Lead norm (mean) 0.00 ± 0.00 0.74 ± 0.12 +0.74
Lead Δ norm (mean) 0.00 ± 0.00 0.04 ± 0.01 +0.04
Lead Δ norm (p95) 0.00 ± 0.00 0.15 ± 0.07 +0.15
Tracking lag (ms) 111.82 ± 9.20 16.97 ± 23.61 -94.85
Err at best lag (cm) 4.72 ± 2.70 8.44 ± 2.23 +3.72
What does this mean?
Here’s the most important stuff:
| Metric | Classical | RL v5 | Δ |
|---|---|---|---|
| Mean error | 12.26 cm | 8.28 cm | -3.97 cm |
| Time in 5 cm band | 24% | 54% | +30 pp |
| Tracking lag | 113 ms | 18 ms | -96 ms |
| Err at best lag | 4.67 cm | 7.86 cm | +3.18 cm |
| Jerk RMS | 521 m/s³ | 773 m/s³ | +252 m/s³ |
- Classical: 113 ms of trailing, but if you time-align it, residual error drops to 4.67 cm. The shape is excellent, it’s just slow.
- RL v5: only 18 ms of lag (it actually learned to lead, as designed). But at best lag, error is still 7.86 cm, barely below its unaligned 8.28 cm, so time alignment hardly helps. What remains is shape error: RL’s path is noisier.
In other words, the classical approach traces a beautiful but late path. RL traces a wobbly but on-time path. If I can reduce wobble, I can blow classical out of the water. I’m going to add that action delta penalty I’ve been talking about.
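For reference, one common way to write an action delta penalty is to subtract a weighted norm of the change in action from the existing reward. The sketch below is only that general form: the function name `penalised_reward`, the use of the plain L2 norm (rather than a squared norm), and the `weight` value are all assumptions, and the right weight is exactly what still needs tuning.

```python
import numpy as np

def penalised_reward(base_reward: float, action: np.ndarray,
                     prev_action: np.ndarray, weight: float) -> float:
    """Subtract weight * ||a_t - a_(t-1)|| from the existing reward."""
    delta = float(np.linalg.norm(action - prev_action))
    return base_reward - weight * delta

# Hypothetical numbers: the same base reward, with and without an action jump.
prev = np.zeros(3)
steady = penalised_reward(1.0, prev, prev, weight=0.5)
jumpy = penalised_reward(1.0, np.array([0.3, 0.4, 0.0]), prev, weight=0.5)
print(steady, jumpy)
```

With a weight of zero this reduces to the original reward, so the penalty can be swept upward from there while watching both jerk and tracking error.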
Driving questions
- How big should the penalty be for large action deltas?
- Will the action delta penalty reduce jitter without trading off tracking accuracy?
Next
- Train our final class of RL models: the no-jitter models, beginning with v9