Splat Trainer — Reconstructing 3D Gaussian Splats from Taos Renders
Run the sample · splat_trainer.ts
Most of the engine takes 3D geometry and produces pixels. This sample runs that arrow backwards: it takes pixels — a handful of rendered views of a scene — and reconstructs the 3D scene as a cloud of 3D Gaussian splats by optimization. It is a from-scratch 3D Gaussian Splatting trainer built inside Taos, the inverse of the existing splat viewer (Gaussian Splatting).
The trick that makes it tractable: because Taos renders the training views itself, the camera poses are known exactly. Real-world 3DGS spends a whole Structure-from-Motion (COLMAP) stage recovering poses from photos; here that stage simply doesn't exist. The dataset is clean by construction — image + the exact view/projection that produced it — so all that's left is the optimization.
 record (known poses)     optimize (per step)                    export / view
┌───────────────────┐    ┌─────────────────────────────────┐    ┌─────────────┐
│ fly cam → testbed │    │ project → tile-bin → forward    │    │ params →    │
│ render → image +  │──► │ composite → L1+D-SSIM loss      │──► │ SplatData → │
│ exact view/proj   │    │ → backward scatter → Adam       │    │ Gaussian-   │
│ → FrameDataset    │    │ + densify (prune/grow/resample) │    │ SplatPass   │
└───────────────────┘    └─────────────────────────────────┘    └─────────────┘
What a splat is, and what we optimize
Each splat is an anisotropic 3D Gaussian with a handful of trainable parameters,
packed at stride 14 floats
(splat_diff_aniso.ts):
| Param | Floats | Stored as | Activation |
|---|---|---|---|
| Position | 3 | world xyz | — |
| Scale | 3 | log(scale) | exp |
| Rotation | 4 | raw quaternion | normalized |
| Opacity | 1 | logit | sigmoid |
| Colour (SH-DC) | 3 | DC term | SH_C0·dc + 0.5 |
Everything is stored in its unconstrained form (log-scale, logit-opacity, raw
quaternion) so plain Adam can step it without ever violating a constraint —
scales stay positive, opacity stays in [0,1], the quaternion is normalized at
use. Only the DC band of spherical harmonics is trained: the testbed scenes are
Lambertian, so view-dependent colour buys nothing and the higher SH bands are
deliberately deferred.
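To make the table concrete, here is a minimal sketch of decoding one 14-float record into its activated values. The document gives the stride and per-parameter float counts but not the field order, so the order used here (position, log-scale, quaternion, logit-opacity, DC colour) and the quaternion layout (w, x, y, z) are assumptions. activateSplat is a hypothetical helper, not the trainer's API.

```typescript
// Assumed field order within the 14-float stride: pos(3) logScale(3) quat(4, w-first) logitOpacity(1) dc(3).
const STRIDE = 14;
const SH_C0 = 0.28209479177387814; // 1 / (2·sqrt(pi)), the degree-0 real SH constant

interface ActivatedSplat {
  position: [number, number, number];
  scale: [number, number, number];
  rotation: [number, number, number, number];
  opacity: number;
  colour: [number, number, number];
}

// Hypothetical helper: apply each parameter's activation from the table.
function activateSplat(params: Float32Array, i: number): ActivatedSplat {
  const o = i * STRIDE;
  const p = (k: number): number => params[o + k];
  const [qw, qx, qy, qz] = [p(6), p(7), p(8), p(9)];
  const qn = Math.hypot(qw, qx, qy, qz) || 1; // guard against an all-zero quaternion
  return {
    position: [p(0), p(1), p(2)],
    scale: [Math.exp(p(3)), Math.exp(p(4)), Math.exp(p(5))], // exp keeps scales positive
    rotation: [qw / qn, qx / qn, qy / qn, qz / qn], // normalized at use
    opacity: 1 / (1 + Math.exp(-p(10))), // sigmoid keeps opacity inside (0, 1)
    colour: [SH_C0 * p(11) + 0.5, SH_C0 * p(12) + 0.5, SH_C0 * p(13) + 0.5],
  };
}
```

Because the activations are applied at read time, Adam can add any real-valued step to the raw floats and the decoded splat is still valid.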
The differentiable core, verified on the CPU first
The heart of the trainer is an analytic backward pass: given the loss gradient on the rendered image, scatter it back through the alpha-composite, the 2D conic, the 2D covariance, the 3D covariance, and finally onto each splat's scale and quaternion. That chain is easy to get subtly wrong, so it was written twice.
The reference is a plain-TypeScript implementation in
splat_diff_aniso.ts: the
forward mirrors the viewer's splat_preprocess.wgsl, and the backward is the
hand-derived analytic gradient. It is checked against finite differences for
every parameter in tests/splats/ — perturb one float,
measure the loss delta, compare to the analytic gradient. The D-SSIM image
gradient (splat_ssim.ts) is
FD-verified the same way.
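The finite-difference check described above follows a standard pattern. The sketch below is a generic version of it, not the code in tests/splats/. It uses central differences in float64, so truncation error is O(eps²) and round-off stays small. fdCheck and its parameters are hypothetical names.

```typescript
type LossFn = (params: Float64Array) => number;

// Perturb one float at a time, measure the loss change with a central difference,
// and report the worst relative disagreement with the analytic gradient.
function fdCheck(
  loss: LossFn,
  analyticGrad: Float64Array,
  params: Float64Array,
  eps = 1e-4,
): { maxRelErr: number; worstIndex: number } {
  let maxRelErr = 0;
  let worstIndex = -1;
  for (let i = 0; i < params.length; i++) {
    const saved = params[i];
    params[i] = saved + eps;
    const up = loss(params);
    params[i] = saved - eps;
    const down = loss(params);
    params[i] = saved; // restore before moving to the next parameter
    const fd = (up - down) / (2 * eps);
    // Floor the denominator so near-zero gradients don't blow up the ratio.
    const denom = Math.max(Math.abs(fd), Math.abs(analyticGrad[i]), 1e-6);
    const relErr = Math.abs(fd - analyticGrad[i]) / denom;
    if (relErr > maxRelErr) {
      maxRelErr = relErr;
      worstIndex = i;
    }
  }
  return { maxRelErr, worstIndex };
}
```

worstIndex points straight at the offending float, which maps back to a specific splat field through the stride.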
The GPU trainer
(splat_trainable_aniso.ts and
splat_train_aniso_shaders.ts) mirrors that CPU reference exactly. To keep them honest, there are two paths by design:
- gradCheck() runs the non-tiled forward/backward with a global depth sort and pure L1 loss. This is geometry-identical to the CPU computeLossAndGrad, so any divergence is a real bug.
- step() runs the tiled training path with the full 0.8·L1 + 0.2·D-SSIM loss. This is what actually trains.
The split matters: the math is validated on the simple path; the fast path only adds binning, never new gradient math.
Atomic gradient accumulation
The backward pass has many pixels scattering into the same splat's gradient slot,
so it needs atomic adds — but WebGPU has no native atomic<f32>. Two strategies
are implemented as a selectable shader variant
(grad_atomic.ts): a compare-and-swap
loop on the bit-pattern, or a fixed-point integer accumulator. Both produce
the same gradients; the variant is chosen at construction.
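Both strategies can be sketched on the CPU with JavaScript's Atomics, which, like WGSL, only offers integer atomics. This is an illustration of the two techniques and not the contents of grad_atomic.ts. In WGSL the CAS loop would use atomicCompareExchangeWeak on an atomic<u32>. FIXED_SCALE is a hypothetical constant.

```typescript
const scratch = new DataView(new ArrayBuffer(4));
const f32ToBits = (f: number): number => { scratch.setFloat32(0, f, true); return scratch.getUint32(0, true); };
const bitsToF32 = (b: number): number => { scratch.setUint32(0, b, true); return scratch.getFloat32(0, true); };

// Strategy 1: compare-and-swap loop on the f32 bit pattern stored in a u32 slot.
function atomicAddF32Cas(slots: Uint32Array, i: number, v: number): void {
  for (;;) {
    const oldBits = Atomics.load(slots, i);
    const newBits = f32ToBits(bitsToF32(oldBits) + v);
    // Succeeds only if no other writer changed slot i since the load; otherwise retry.
    if (Atomics.compareExchange(slots, i, oldBits, newBits) === oldBits) return;
  }
}

// Strategy 2: fixed-point accumulation with a plain integer atomicAdd.
// Hypothetical scale: 2^20 gives ~1e-6 resolution but wraps once |sum| exceeds 2048.
const FIXED_SCALE = 1 << 20;
function atomicAddFixed(slots: Int32Array, i: number, v: number): void {
  Atomics.add(slots, i, Math.round(v * FIXED_SCALE));
}
const readFixed = (slots: Int32Array, i: number): number => slots[i] / FIXED_SCALE;
```

The trade-off: the CAS loop keeps full float range but retries under contention, and its result depends on the order of additions. The fixed-point sum is order-independent and never retries, but it quantizes each contribution and needs a scale chosen against overflow.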
Tile rasterization: the resolution unlock
Early versions composited every splat against every pixel after a single global depth sort. That works but caps training resolution at ~160×120 before it's too slow, and training resolution is the hard ceiling on detail — splats can only get as fine as the images that supervise them. When the recon was magnified into a ~1500px 3D view, the coarseness showed as fuzz.
The fix is the standard 3DGS tile rasterizer, ported to the trainer
(splat_train_aniso_shaders.ts):
- project: map every splat to its screen-space mean, 2D conic, depth, and bbox.
- tileScatter: each splat writes one (tileID, depth) pair per 16×16 tile its bbox touches, reserving an exact contiguous block with a single atomicAdd into a shared pair counter.
- sort: order the pairs by key (tileID << 20) | (depth bits) with the GPU radix sort, so each tile's splats end up contiguous and depth-ordered.
- tileRanges: find each tile's [start, end) span in the sorted pairs.
- tiledForward / tiledBackward: each pixel composites only the ~150 splats in its tile, not the whole cloud.
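Here is a CPU sketch of the binning passes above (tileScatter, sort, tileRanges). It is a model of the technique and not a port of the shaders. The following are all assumptions:

- depth is quantized to 20 bits over a [near, far] range;
- a JavaScript sort stands in for the GPU radix sort;
- a splat that would overflow the shared pool drops all of its pairs.

Keys are built arithmetically (tile × 2^20 + depth) rather than with <<, because JavaScript shifts are signed 32-bit.

```typescript
const TILE = 16;
const DEPTH_LEVELS = 1 << 20; // 20 depth bits below the tile ID

interface ProjectedSplat { minX: number; minY: number; maxX: number; maxY: number; depth: number }

function binSplats(splats: ProjectedSplat[], width: number, height: number,
                   near: number, far: number, pairBudget: number) {
  const tilesX = Math.ceil(width / TILE), tilesY = Math.ceil(height / TILE);
  const keys: number[] = [], ids: number[] = [];
  let counter = 0; // shared pair counter (an atomicAdd target on the GPU)
  splats.forEach((s, id) => {
    const tx0 = Math.max(0, Math.floor(s.minX / TILE)), tx1 = Math.min(tilesX - 1, Math.floor(s.maxX / TILE));
    const ty0 = Math.max(0, Math.floor(s.minY / TILE)), ty1 = Math.min(tilesY - 1, Math.floor(s.maxY / TILE));
    const n = (tx1 - tx0 + 1) * (ty1 - ty0 + 1);
    if (tx0 > tx1 || ty0 > ty1 || counter + n > pairBudget) return; // off-screen or pool exhausted
    counter += n; // reserve one contiguous block of n pairs from the shared pool
    const t = Math.min(Math.max((s.depth - near) / (far - near), 0), 1);
    const dq = Math.round(t * (DEPTH_LEVELS - 1));
    for (let ty = ty0; ty <= ty1; ty++)
      for (let tx = tx0; tx <= tx1; tx++) {
        keys.push((ty * tilesX + tx) * DEPTH_LEVELS + dq); // same ordering as (tileID << 20) | dq
        ids.push(id);
      }
  });
  // Sort pair indices by key: tile-major, then front-to-back within each tile.
  const order = keys.map((_, i) => i).sort((a, b) => keys[a] - keys[b]);
  const ranges = new Int32Array(tilesX * tilesY * 2); // [start, end) per tile; empty tiles stay [0, 0)
  order.forEach((p, i) => {
    const tile = Math.floor(keys[p] / DEPTH_LEVELS);
    if (i === 0 || Math.floor(keys[order[i - 1]] / DEPTH_LEVELS) !== tile) ranges[2 * tile] = i;
    ranges[2 * tile + 1] = i + 1;
  });
  return { sortedIds: order.map((p) => ids[p]), ranges, tilesX };
}
```

A pixel in tile t then composites sortedIds[ranges[2t]] through sortedIds[ranges[2t+1] − 1], already in depth order.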
This made 256×192 training affordable, and the cubes went from fuzzy to
crisp. One subtlety worth recording: the pair buffer is sized at
PAIR_BUDGET=16 pairs per splat but handed out from one shared allocator, so a
large splat can claim more tiles than its average share. This fixed an earlier
"dark tiles at the bottom" bug, where big close-up ground splats were clipped
by a fixed per-splat cap.
The loss
loss = 0.8 · L1(C, target) + 0.2 · D-SSIM(C, target)
L1 drives raw colour accuracy; D-SSIM drives local structure (it compares
windowed means, variances, and covariance, so it rewards getting edges and
gradients right, not just average colour). The D-SSIM gradient is computed in two
passes — ssimPqs builds the windowed statistics maps, ssimGrad turns them
into a per-pixel image gradient — which is then added to the L1 image gradient
before the backward scatter. Both contributions flow through the same analytic
backward, so the geometry gradients are shared.
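The sketch below evaluates the loss value only (not the two-pass gradient) on a single-channel image. The following are assumptions, not taken from the trainer:

- a (2r+1)² box window clamped at the borders, with r = 3 as a placeholder;
- the conventional constants C1 = 0.01² and C2 = 0.03²;
- D-SSIM taken as 1 − mean SSIM, as in the original 3DGS code (some sources use (1 − SSIM)/2).

For RGB, one would average over channels.

```typescript
const C1 = 0.01 ** 2, C2 = 0.03 ** 2;

// Mean SSIM with a box window: compares windowed means, variances and covariance.
function meanSsimBox(a: Float32Array, b: Float32Array, w: number, h: number, r: number): number {
  let total = 0;
  for (let y = 0; y < h; y++)
    for (let x = 0; x < w; x++) {
      let n = 0, sa = 0, sb = 0, saa = 0, sbb = 0, sab = 0;
      for (let yy = Math.max(0, y - r); yy <= Math.min(h - 1, y + r); yy++)
        for (let xx = Math.max(0, x - r); xx <= Math.min(w - 1, x + r); xx++) {
          const pa = a[yy * w + xx], pb = b[yy * w + xx];
          n++; sa += pa; sb += pb; saa += pa * pa; sbb += pb * pb; sab += pa * pb;
        }
      const ma = sa / n, mb = sb / n;
      const va = saa / n - ma * ma, vb = sbb / n - mb * mb, cov = sab / n - ma * mb;
      total += ((2 * ma * mb + C1) * (2 * cov + C2)) / ((ma * ma + mb * mb + C1) * (va + vb + C2));
    }
  return total / (w * h);
}

// loss = 0.8 · L1 + 0.2 · D-SSIM, with L1 as the mean absolute error.
function trainingLoss(render: Float32Array, target: Float32Array, w: number, h: number, r = 3): number {
  let l1 = 0;
  for (let i = 0; i < render.length; i++) l1 += Math.abs(render[i] - target[i]);
  l1 /= render.length;
  const dssim = 1 - meanSsimBox(render, target, w, h, r);
  return 0.8 * l1 + 0.2 * dssim;
}
```

A flat grey render of a black target is a useful sanity check: L1 is 0.5, and SSIM collapses to nearly 0 because the means disagree and neither image has any structure.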
Adaptive densification
A fixed splat cloud can't reconstruct a scene it didn't start with the right
splats for, so the topology adapts. Buffers are fixed-capacity (30 000) with
a mutable active count; the GPU accumulates a per-splat signal each step and
densify() periodically reads it back on the CPU, rewrites the cloud, and
re-uploads. Doing the heuristics in JS keeps them debuggable while the GPU work
stays simple — and densify never touches the gradient path, so the FD checks
stay valid.
Each cycle does, in order:
- Prune splats that are transparent, oversized, or too needle-like (aspect cap).
- Clone / split the top-gradient fraction: a small splat in an
under-reconstructed region is cloned (duplicated in place); a large one is
split into two smaller, offset copies. The signal is the running mean of the
screen-space position gradient (gradAccum), and selection takes a fixed top fraction of it. Because that is a quantile, it is scale-invariant, so there's no magic absolute threshold to tune.
- Resample (brush-style): if pruning dropped the count, refill the freed budget by duplicating survivors sampled ∝ opacity × visibility, so the recovered budget lands on splats that are opaque and actually on-screen rather than bleeding away. Visibility is accumulated on the GPU as opacity · √det (opacity times projected footprint). Each resampled copy gets a sub-σ positional jitter plus the opacity-preserving split a′ = 1 − √(1−a). Two overlapping copies at a′ composite to 1 − (1−a′)² = a, so a duplicated pair doesn't read as twice as dense.
- Decay regularization (brush-style): nudge every kept splat slightly more transparent and slightly smaller, with strength annealed 1 → 0 over training (1 − iter/total). Floaters starve and over-large splats contract; the optimizer regrows whatever genuinely contributes. scaleDecay is kept low (0.01) so splats don't shrink into visible gaps.
- Opacity reset: periodically clamp opacity back down toward ~0.01, the classic 3DGS escape from bad local minima where the scene hides behind a few opaque blobs.
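The sketch below covers three of the densify decisions above: quantile selection, resampling ∝ opacity × visibility, and the opacity-preserving split. All names and signatures are hypothetical. Resampling uses inverse-CDF sampling with an injectable random source so it can be tested. The sub-σ positional jitter is left out.

```typescript
// Clone/split candidates: the top `fraction` of splats by mean screen-space gradient.
// Scaling every gradient by a constant selects the same splats (scale-invariant).
function topGradientIndices(gradMean: Float32Array, fraction: number): number[] {
  const k = Math.floor(gradMean.length * fraction);
  return Array.from(gradMean.keys())
    .sort((a, b) => gradMean[b] - gradMean[a])
    .slice(0, k);
}

// Opacity-preserving split: two stacked copies at a' composite to 1 - (1 - a')^2 = a.
const splitOpacity = (a: number): number => 1 - Math.sqrt(1 - a);

// Pick `count` survivors with probability proportional to opacity × visibility.
function resample(opacity: Float32Array, visibility: Float32Array, count: number,
                  rand: () => number = Math.random): number[] {
  const cdf = new Float64Array(opacity.length);
  let acc = 0;
  for (let i = 0; i < opacity.length; i++) {
    acc += opacity[i] * visibility[i];
    cdf[i] = acc;
  }
  const picks: number[] = [];
  if (acc <= 0) return picks; // nothing opaque and on-screen to duplicate
  for (let n = 0; n < count; n++) {
    const u = rand() * acc; // u in [0, acc)
    let lo = 0, hi = cdf.length - 1;
    // Binary search for the first index whose cumulative weight exceeds u;
    // zero-weight splats can never be the first such index, so they are never picked.
    while (lo < hi) {
      const mid = (lo + hi) >> 1;
      if (cdf[mid] > u) hi = mid; else lo = mid + 1;
    }
    picks.push(lo);
  }
  return picks;
}
```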
Learning-rate decay
The optimizer uses an annealing schedule on the geometry learning rates only.
A public lrScale multiplies the position / log-scale / rotation LRs, driven from
the sample by lrScale = LR_DECAY_FINAL ^ (iter/LR_DECAY_ITERS) (i.e. 1× → 0.02×
exponentially). Colour and opacity LRs stay at full strength so appearance keeps
adapting late. The intuition: early on you want geometry to move freely to find
the rough shape; late, you want it to settle instead of jittering, because
residual geometry jitter on curved surfaces reads as fuzz.
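The schedule is a geometric decay, so every iteration multiplies the geometry LRs by the same ratio. In this sketch, LR_DECAY_FINAL = 0.02 is implied by the stated 1× → 0.02× range. The LR_DECAY_ITERS value, the clamp past the end, and the LearningRates shape are all assumptions.

```typescript
const LR_DECAY_FINAL = 0.02;
const LR_DECAY_ITERS = 30_000; // placeholder; the document does not give this value

// lrScale = LR_DECAY_FINAL ^ (iter / LR_DECAY_ITERS), clamped to [LR_DECAY_FINAL, 1].
function lrScaleAt(iter: number): number {
  const t = Math.min(Math.max(iter / LR_DECAY_ITERS, 0), 1);
  return Math.pow(LR_DECAY_FINAL, t);
}

interface LearningRates { position: number; logScale: number; rotation: number; colour: number; opacity: number }

// Only the geometry groups are scaled; colour and opacity keep their base rates.
function scheduledLrs(base: LearningRates, iter: number): LearningRates {
  const s = lrScaleAt(iter);
  return { ...base, position: base.position * s, logScale: base.logScale * s, rotation: base.rotation * s };
}
```

Halfway through, the scale is √0.02 ≈ 0.141. Most of the slowdown therefore happens early, while geometry is still allowed to move substantially in the first few thousand iterations.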
View coverage
The self-driving ?auto=<sceneId> mode synthesizes the dataset by orbiting the
scene: 5 elevation rings × 16 azimuths + a top-down cap = 81 views, spanning a
near-horizon ring (≈3°) up to a steep ≈75° ring. The angular spread is
deliberate — curved surfaces "fur" from the directions between and beyond the
training cameras, so the near-horizon ring constrains the undersides (where splats
were otherwise spiking down onto the ground) and the steep rings plus the cap pin
the tops.
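Here is a sketch of the orbit layout: 5 rings × 16 azimuths plus a cap gives 81 camera positions. The document gives only the counts and the ≈3° / ≈75° extremes. Even ring spacing, a fixed radius, a Y-up frame and cameras aimed at the origin are assumptions. orbitViews and viewAt are hypothetical names.

```typescript
interface OrbitView { eye: [number, number, number]; elevationDeg: number; azimuthDeg: number }

// Camera position on a sphere of `radius` around the origin (assumed Y-up).
function viewAt(radius: number, elevationDeg: number, azimuthDeg: number): OrbitView {
  const e = (elevationDeg * Math.PI) / 180, a = (azimuthDeg * Math.PI) / 180;
  const eye: [number, number, number] = [
    radius * Math.cos(e) * Math.sin(a),
    radius * Math.sin(e),
    radius * Math.cos(e) * Math.cos(a),
  ];
  return { eye, elevationDeg, azimuthDeg };
}

// `rings` evenly spaced elevation rings × `azimuths` views each, plus one top-down cap.
function orbitViews(radius: number, rings = 5, azimuths = 16, minElev = 3, maxElev = 75): OrbitView[] {
  const views: OrbitView[] = [];
  for (let r = 0; r < rings; r++) {
    const elev = minElev + ((maxElev - minElev) * r) / (rings - 1);
    for (let k = 0; k < azimuths; k++) views.push(viewAt(radius, elev, (360 * k) / azimuths));
  }
  views.push(viewAt(radius, 90, 0)); // top-down cap
  return views;
}
```

With these assumptions the rings sit at 3°, 21°, 39°, 57° and 75°. The 90° cap looks straight down, so a look-at built with a Y-up vector is degenerate there and needs a different up axis.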
How to verify
Because reconstruction quality is a visual judgement, the loop is: unit tests confirm the math, then a human confirms the picture.
- Math: npx vitest run tests/splats/ runs the FD gradient checks for the anisotropic forward/backward, the D-SSIM gradient, and the atomic-add variants.
- GPU, headless: drive splat_trainer.html?auto=primitives with the webgpu-inspector MCP tools and read dispatch sizes. The adam dispatch is ceil(count/64), which reveals the live splat count to within 64; forward/backward/tiled are ceil(W/8)×ceil(H/8). The validation-error count should stay zero.
- Quality: eyeball the target-vs-recon preview thumbnails and the orbitable 3D view, plus the per-cycle densify log (pruned / cloned / split / resampled).
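The dispatch arithmetic used in the headless check is shown below. The workgroup sizes (64 for adam, 8×8 for the per-pixel passes) are read off the formulas in the text. The helper names are hypothetical.

```typescript
const adamWorkgroups = (count: number): number => Math.ceil(count / 64);
const pixelWorkgroups = (w: number, h: number): [number, number] => [Math.ceil(w / 8), Math.ceil(h / 8)];

// An observed adam dispatch of g workgroups bounds the live count: 64·(g−1) < count ≤ 64·g.
function liveCountRange(g: number): [number, number] {
  return [64 * (g - 1) + 1, 64 * g];
}
```

At full capacity (30 000 splats) the adam dispatch is 469 workgroups. A 256×192 training frame dispatches 32×24 workgroups per pixel pass.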
Quality progression so far
The reconstruction went blurry → recognizable → sharp cubes + furry spheres. The drivers we found:
- View-3D fuzz was dominated by training resolution, which is why tiling (to afford 256×192) was the big unlock — the recon preview was already decent at low res; the coarseness only appeared under magnification.
- D-SSIM helps structure but only at training resolution.
- Denser, wider-elevation views reduce fur on curved surfaces.
- The remaining sphere "fur" is the classic 3DGS artifact: curved surfaces + under-constrained between-view angles + magnification produce radial spikes from elongated, lightly-supervised splats.
What we could try next to improve quality
Roughly in increasing order of effort. Several are lifted from brush, a cross-platform WebGPU/Burn 3DGS trainer solving the same problem.
Cheap tuning levers (no code, just constants)
- Higher training resolution again (e.g. 320×240) now that tiling makes it affordable — detail ceiling rises directly.
- More splats (raise capacity / densify fraction) for the curved surfaces.
- Tighter aspect cap (Adam maxAspect 3 → 2) to kill the elongated radial "spike" splats that are the fur.
- Balance the decay reg: scaleDecay too high opens speckle gaps, too low lets fur survive. It's the main dial for the sphere surfaces.
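An aspect cap is cheap to express in log-scale space, because the ratio max/min becomes a difference of logs. The document names maxAspect but not how it is applied. The sketch below assumes the cap runs after each Adam step and shortens the longest axes, which shrinks spikes rather than fattening them. Raising the shortest axis instead would be an equally valid design. clampAspect is a hypothetical helper.

```typescript
// Enforce max(scale) / min(scale) ≤ maxAspect by capping each log-scale at min + log(maxAspect).
function clampAspect(logScale: [number, number, number], maxAspect: number): [number, number, number] {
  const ceiling = Math.min(...logScale) + Math.log(maxAspect);
  return [Math.min(logScale[0], ceiling), Math.min(logScale[1], ceiling), Math.min(logScale[2], ceiling)];
}
```

For example, a 1 × 1 × 10 needle with maxAspect = 2 becomes 1 × 1 × 2, and splats already inside the cap are left untouched.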
Optimizer / refinement (moderate)
- True occlusion-aware visibility for resampling: accumulate the exact T·alpha blend weight per splat in the tiled backward (this needs a fixed-point atomic, like the gradient path) instead of the current footprint proxy. This matters only once scenes have real occlusion.
- Proper 11×11 Gaussian SSIM window. The current D-SSIM uses a box-window approximation; brush uses a true Gaussian 11×11 window with a haloed tile, which sharpens the structural term.
- SH higher bands — only worth it once scenes have view-dependent shading; the testbed is Lambertian, so DC is sufficient today.
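For the Gaussian-window item, the standard SSIM window is an 11-tap Gaussian with σ = 1.5, and it is separable: one horizontal pass and one vertical pass replace a full 11×11 convolution. The sketch below builds that kernel and a separable blur that could produce the μ, σ² and σ_xy maps. Clamp-to-edge borders are an assumption here; a GPU version would load a haloed tile into shared memory instead. The function names are hypothetical.

```typescript
// Normalized 1D Gaussian taps; the 2D window is the outer product of this with itself.
function gaussianWindow1D(size = 11, sigma = 1.5): Float64Array {
  const w = new Float64Array(size);
  const c = (size - 1) / 2;
  let sum = 0;
  for (let i = 0; i < size; i++) {
    w[i] = Math.exp(-((i - c) ** 2) / (2 * sigma * sigma));
    sum += w[i];
  }
  for (let i = 0; i < size; i++) w[i] /= sum; // weights sum to 1, so windowed means stay unbiased
  return w;
}

// Separable blur (rows, then columns) with clamp-to-edge sampling at the borders.
function blurSeparable(img: Float64Array, width: number, height: number, k: Float64Array): Float64Array {
  const r = (k.length - 1) / 2;
  const tmp = new Float64Array(img.length), out = new Float64Array(img.length);
  const clamp = (v: number, hi: number): number => Math.min(Math.max(v, 0), hi);
  for (let y = 0; y < height; y++)
    for (let x = 0; x < width; x++) {
      let s = 0;
      for (let i = -r; i <= r; i++) s += k[i + r] * img[y * width + clamp(x + i, width - 1)];
      tmp[y * width + x] = s;
    }
  for (let y = 0; y < height; y++)
    for (let x = 0; x < width; x++) {
      let s = 0;
      for (let i = -r; i <= r; i++) s += k[i + r] * tmp[clamp(y + i, height - 1) * width + x];
      out[y * width + x] = s;
    }
  return out;
}
```

Blurring a, b, a², b² and a·b with this kernel gives the five windowed statistics that SSIM needs. The box-window version differs only in using equal weights.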
Bigger architectural bets
- MCMC refinement (brush 0.3): replace gradient-threshold clone/split with a Markov-Chain-Monte-Carlo relocation scheme that explores the scene (gradient-free moves + noise injection), keeping the count budget fixed while still growing where needed. The current densification's main weakness is that it can only grow where the gradient already points; MCMC can discover under-served regions.
- Anti-aliased / EWA splatting with a per-frame mip-style low-pass, which reduces shimmer and the silhouette fur under magnification.
- LPIPS perceptual loss for a final polish pass — heavy (a VGG forward per step) and overkill for these simple scenes, but the standard last-mile quality lever for photoreal datasets.
Beyond the testbed
- Capture from a real engine scene (e.g.
geo_osm_buildings) instead of the Lambert testbed. This is the large deferred item: it needs the full engine render path, HDR canvas read-back, and streaming / floating-origin handling, so it was deprioritized in favour of getting the anisotropic math and quality right first.