
Building a Real-Time Neural Network Visualizer with React Three Fiber

8 min read

How I built an interactive 3D visualization of neural network inference using WebGL, ONNX Runtime, and React Three Fiber.



This has been one of my favorite projects to work on. Neural networks always get called "black boxes," and I wanted to actually see what happens when one processes an image. Not some simplified diagram, but real data from a real model running in the browser.

What It Does

Upload an image and watch MobileNetV2 process it in real-time. You'll see:

  • A glowing signal packet traveling through the network layers
  • Each layer lighting up based on how strongly it activates
  • The top predictions with confidence scores
  • Feature maps showing what patterns each layer detected
  • An attention heatmap showing where the network focused

The whole thing runs client-side with no backend. Your images never leave your browser.

How It Actually Works

Here's the real flow when you upload an image (there's a rough code sketch after the list):

  1. Preprocess - Resize to 224x224, normalize with ImageNet stats
  2. Run Inference - ONNX Runtime executes MobileNetV2 via WebAssembly
  3. Extract Activations - Pull intermediate outputs from each layer (not just the final prediction)
  4. Compute Metrics - Calculate energy and sparsity for each layer from the raw activation data
  5. Animate - Play the forward pass animation with the signal packet weaving through layers
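
Stitched together, the pipeline looks roughly like the sketch below. The helper names (preprocessImage, extractLayerOutputs, playForwardPassAnimation) are illustrative stand-ins for the real utilities shown later in this post, not the project's actual API:

// Rough end-to-end sketch of the five steps above (helper names are illustrative).
async function runVisualization(image: HTMLImageElement, session: ort.InferenceSession) {
  const inputTensor = preprocessImage(image);                    // 1. resize + normalize to NCHW
  const outputs = await session.run({ input: inputTensor });     // 2. ONNX Runtime (WASM) inference
  const layerTensors = extractLayerOutputs(outputs);             // 3. intermediate activations
  const metrics = layerTensors.map((t) => calculateMetrics(t.data, t.layerId)); // 4. energy + sparsity
  playForwardPassAnimation(metrics);                             // 5. drive the 3D scene
}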

The key insight is that modern ONNX models can expose intermediate layer outputs, not just the final classification. I modified MobileNetV2 to output activations from every major block.

The Tech Stack

ONNX Runtime Web handles the actual neural network inference. It runs the model via WebAssembly, which is surprisingly fast for browser-based ML. The model loads once when the page loads; after that, inference takes around 100-200ms per image, depending on your device.

React Three Fiber renders the 3D visualization. Each network layer is a grid of instanced sphere meshes. Instancing is important here: we're rendering hundreds of nodes, and instanced meshes share geometry, cutting down on draw calls.
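
As a rough sketch of what one instanced layer can look like in React Three Fiber (illustrative only; the component name, props, and material settings here are my simplification, not the project's actual code):

import { useLayoutEffect, useRef } from 'react';
import * as THREE from 'three';

type LayerGridProps = {
  positions: THREE.Vector3[]; // one entry per displayed node
  energy: number;             // 0-1, drives emissive brightness
  color: string;
};

function LayerGrid({ positions, energy, color }: LayerGridProps) {
  const meshRef = useRef<THREE.InstancedMesh>(null!);

  // Write one transform matrix per instance; geometry and material are shared
  useLayoutEffect(() => {
    const dummy = new THREE.Object3D();
    positions.forEach((p, i) => {
      dummy.position.copy(p);
      dummy.updateMatrix();
      meshRef.current.setMatrixAt(i, dummy.matrix);
    });
    meshRef.current.instanceMatrix.needsUpdate = true;
  }, [positions]);

  return (
    // A single draw call renders the whole grid of spheres
    <instancedMesh ref={meshRef} args={[undefined, undefined, positions.length]}>
      <sphereGeometry args={[0.08, 12, 12]} />
      <meshStandardMaterial color={color} emissive={color} emissiveIntensity={0.2 + energy * 1.5} />
    </instancedMesh>
  );
}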

Custom Animation System controls the signal packet that travels through the network. It uses requestAnimationFrame for smooth 60fps animation. The packet's path through each layer is seeded by the inference result, so different images produce different paths.
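
A minimal sketch of the idea (not the project's actual hook): a requestAnimationFrame loop advances a 0-1 progress value, and a simple hash seeded by the inference result picks which node the packet passes through in each layer:

// Deterministic "random" node index per layer, seeded by the top prediction index,
// so different images produce different (but repeatable) packet paths.
function seededNodeIndex(seed: number, layerIndex: number, nodeCount: number): number {
  const x = Math.sin(seed * 374761 + layerIndex * 668265) * 43758.5453;
  return Math.floor((x - Math.floor(x)) * nodeCount);
}

// Drive the forward-pass animation with requestAnimationFrame.
function animateForwardPass(onProgress: (t: number) => void, durationMs = 4000): void {
  let start: number | null = null;
  const frame = (now: number) => {
    if (start === null) start = now;
    const t = Math.min(1, (now - start) / durationMs); // 0 at the input, 1 at the classifier
    onProgress(t);
    if (t < 1) requestAnimationFrame(frame);
  };
  requestAnimationFrame(frame);
}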

Extracting Data from the Model

Image Preprocessing

Before feeding an image to MobileNetV2, I need to transform it into the format the model expects. This means resizing to 224x224, normalizing with ImageNet statistics using z-score normalization, and converting to NCHW tensor format:

utils/imagePreprocess.ts
const INPUT_SIZE = 224;
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];
 
// Get pixel data from canvas
const imageData = ctx.getImageData(0, 0, INPUT_SIZE, INPUT_SIZE);
const pixels = imageData.data;
 
// Convert to normalized NCHW Float32Array
const float32Data = new Float32Array(1 * 3 * INPUT_SIZE * INPUT_SIZE);
 
for (let c = 0; c < 3; c++) {
  const mean = IMAGENET_MEAN[c];
  const std = IMAGENET_STD[c];
 
  for (let h = 0; h < INPUT_SIZE; h++) {
    for (let w = 0; w < INPUT_SIZE; w++) {
      // Source: RGBA interleaved
      const srcIdx = (h * INPUT_SIZE + w) * 4 + c;
      // Destination: NCHW format (channel-first)
      const dstIdx = c * INPUT_SIZE * INPUT_SIZE + h * INPUT_SIZE + w;
 
      // Normalize: (pixel/255 - mean) / std
      float32Data[dstIdx] = (pixels[srcIdx] / 255 - mean) / std;
    }
  }
}
 
return new ort.Tensor('float32', float32Data, [1, 3, INPUT_SIZE, INPUT_SIZE]);
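
For context, the ctx above comes from drawing the uploaded image into a 224×224 offscreen canvas. Roughly like this (a sketch; imageElement stands for whatever element holds the upload, and the project's actual canvas handling may differ):

// Resize happens via drawImage into a 224x224 canvas (assumed setup).
const canvas = document.createElement('canvas');
canvas.width = INPUT_SIZE;
canvas.height = INPUT_SIZE;
const ctx = canvas.getContext('2d')!;
ctx.drawImage(imageElement, 0, 0, INPUT_SIZE, INPUT_SIZE);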

Running Inference and Extracting Layer Outputs

The key to this whole visualization is that I modified MobileNetV2 to output intermediate activations, not just the final classification. When inference runs, I get back tensors from every major layer:

hooks/useOnnxInference.ts
// Create inference session
const session = await ort.InferenceSession.create(MODEL_PATH, {
  executionProviders: ['wasm'],
  graphOptimizationLevel: 'all',
});
 
// Run inference
const outputs = await session.run({ input: inputTensor });
 
// Extract activations from each layer output
const activations: LayerActivation[] = [];
const featureMaps: FeatureMap[] = [];
 
// Layers to extract raw feature maps from
const featureMapLayers = ['layer_conv0', 'layer_block2', 'layer_block3', 'layer_conv_final'];
 
for (const config of LAYER_CONFIGS) {
  const outputTensor = outputs[config.onnxOutputName];
  if (outputTensor) {
    const data = outputTensor.data as Float32Array;
    const metrics = calculateMetrics(data, config.id);
    activations.push(metrics);
 
    // Store raw feature map data for visualization layers
    if (featureMapLayers.includes(config.onnxOutputName)) {
      const [height, width, channels] = config.outputShape;
      featureMaps.push({
        layerId: config.id,
        data: new Float32Array(data),
        width,
        height,
        channels,
      });
    }
  }
}

Getting Predictions with Softmax

The final layer outputs raw logits (unnormalized scores). I apply the softmax function to convert these to probabilities and grab the top 5 predictions. Subtracting the max value before computing the exponentials prevents numerical overflow.
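
In equation form, with logits z_1, ..., z_K and m = max_j z_j (the subtraction cancels in the ratio, so the probabilities are unchanged):

p_i = \frac{e^{z_i - m}}{\sum_{k=1}^{K} e^{z_k - m}}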

utils/tensorOps.ts
function softmaxTopK(logits: Float32Array, k: number) {
  // Find max for numerical stability
  let maxLogit = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    if (logits[i] > maxLogit) maxLogit = logits[i];
  }
 
  // Compute exp and sum
  const expValues = new Float32Array(logits.length);
  let expSum = 0;
  for (let i = 0; i < logits.length; i++) {
    const expVal = Math.exp(logits[i] - maxLogit);
    expValues[i] = expVal;
    expSum += expVal;
  }
 
  // Create {index, probability} array and sort
  const indexed = [];
  for (let i = 0; i < expValues.length; i++) {
    indexed.push({
      index: i,
      probability: expValues[i] / expSum,
    });
  }
 
  indexed.sort((a, b) => b.probability - a.probability);
  return indexed.slice(0, k);
}

Computing Real Metrics

The energy and sparsity values you see come straight from the activation tensors. Here's the actual calculation:

utils/tensorOps.ts
function calculateMetrics(data: Float32Array, layerId: string, sparsityThreshold = 0.01) {
  const length = data.length;
 
  let sumSquares = 0;
  let sum = 0;
  let nearZeroCount = 0;
  let maxVal = 0;
 
  // Single pass through data for efficiency
  for (let i = 0; i < length; i++) {
    const val = data[i];
    const absVal = Math.abs(val);
 
    sumSquares += val * val;
    sum += val;
 
    if (absVal < sparsityThreshold) {
      nearZeroCount++;
    }
    if (absVal > maxVal) {
      maxVal = absVal;
    }
  }
 
  // RMS normalized to roughly 0-1 range
  const rms = Math.sqrt(sumSquares / length);
  const normalizedEnergy = Math.min(1, rms / 2);
 
  return {
    layerId,
    energy: normalizedEnergy,
    sparsity: nearZeroCount / length,
    maxActivation: maxVal,
    meanActivation: sum / length,
  };
}

Energy is the normalized RMS (root mean square) of all activation values. High energy means lots of neurons firing strongly.

Sparsity measures the fraction of near-zero activations. High sparsity means most neurons aren't firing, which is typical for later layers that filter down to specific features.
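
In formula form, for a layer with activation values a_1, ..., a_N and the near-zero threshold \tau = 0.01 used as sparsityThreshold above:

\text{energy} = \min\!\left(1, \tfrac{1}{2}\sqrt{\tfrac{1}{N}\sum_{i=1}^{N} a_i^2}\right), \qquad \text{sparsity} = \frac{\#\{\, i : |a_i| < \tau \,\}}{N}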

The Feature Maps

When you look at the Feature Maps panel, you're seeing actual 2D slices of the activation tensors from early convolutional layers. I find the most activated channels and render them as heatmaps with a viridis colormap:

utils/featureMapRenderer.ts
// Find top N channels by mean activation
function findTopChannels(featureMap: FeatureMap, topN: number = 6): number[] {
  const { width, height, channels, data } = featureMap;
  const channelSize = width * height;
 
  const channelMeans = [];
  for (let c = 0; c < channels; c++) {
    let sum = 0;
    const offset = c * channelSize;
    for (let i = 0; i < channelSize; i++) {
      sum += Math.abs(data[offset + i]);
    }
    channelMeans.push({ index: c, mean: sum / channelSize });
  }
 
  channelMeans.sort((a, b) => b.mean - a.mean);
  return channelMeans.slice(0, topN).map((c) => c.index);
}
 
// Render a channel to canvas with viridis colormap
function renderChannelToImageData(channelData: Float32Array, width: number, height: number) {
  // Find min/max for normalization
  let min = Infinity, max = -Infinity;
  for (let i = 0; i < channelData.length; i++) {
    if (channelData[i] < min) min = channelData[i];
    if (channelData[i] > max) max = channelData[i];
  }
  const range = max - min || 1;
 
  const imageData = new ImageData(width, height);
  for (let i = 0; i < channelData.length; i++) {
    const normalized = (channelData[i] - min) / range;
    const [r, g, b] = getViridisColor(normalized);
 
    imageData.data[i * 4] = r;
    imageData.data[i * 4 + 1] = g;
    imageData.data[i * 4 + 2] = b;
    imageData.data[i * 4 + 3] = 255;
  }
  return imageData;
}
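
The snippet above assumes a getViridisColor helper that maps a 0-1 value to an RGB triple. A minimal stand-in that linearly interpolates a few viridis anchor colors would look like this (the real utility likely uses a finer lookup table):

// Viridis anchor colors at t = 0, 0.25, 0.5, 0.75, 1.
const VIRIDIS_STOPS: [number, number, number][] = [
  [68, 1, 84],
  [59, 82, 139],
  [33, 145, 140],
  [94, 201, 98],
  [253, 231, 37],
];

function getViridisColor(t: number): [number, number, number] {
  const clamped = Math.min(1, Math.max(0, t));
  const pos = clamped * (VIRIDIS_STOPS.length - 1);
  const i = Math.min(VIRIDIS_STOPS.length - 2, Math.floor(pos));
  const f = pos - i;
  const [r0, g0, b0] = VIRIDIS_STOPS[i];
  const [r1, g1, b1] = VIRIDIS_STOPS[i + 1];
  // Linear interpolation between the two nearest anchor colors
  return [
    Math.round(r0 + (r1 - r0) * f),
    Math.round(g0 + (g1 - g0) * f),
    Math.round(b0 + (b1 - b0) * f),
  ];
}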

Early layers detect simple stuff like edges and gradients. Deeper layers pick up more complex patterns.

The Attention Map

The attention heatmap uses a simplified Class Activation Mapping (CAM) approach. It takes the final convolutional layer's activations and creates a weighted sum based on channel importance:

utils/camCompute.ts
function computeSimplifiedCAM(finalConvFeatures: FeatureMap): Float32Array {
  const { width, height, channels, data } = finalConvFeatures;
  const spatialSize = width * height;
 
  const heatmap = new Float32Array(spatialSize);
  const channelWeights = new Float32Array(channels);
 
  // Weight channels by mean activation
  for (let c = 0; c < channels; c++) {
    let sum = 0;
    const offset = c * spatialSize;
    for (let i = 0; i < spatialSize; i++) {
      sum += Math.max(0, data[offset + i]); // ReLU output
    }
    channelWeights[c] = sum / spatialSize;
  }
 
  // Normalize weights and compute weighted spatial sum
  // ... (normalization code)
 
  for (let c = 0; c < channels; c++) {
    const weight = channelWeights[c];
    if (weight < 0.1) continue; // Skip low-weight channels
 
    const offset = c * spatialSize;
    for (let i = 0; i < spatialSize; i++) {
      heatmap[i] += Math.max(0, data[offset + i]) * weight;
    }
  }
 
  // Normalize heatmap to 0-1 and return
  return heatmap;
}

The heatmap is then upsampled with bilinear interpolation to match the input image size and rendered with a jet colormap: red areas are where the network focused, blue areas were mostly ignored.
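
The bilinear upsampling itself is straightforward. Here's a sketch of the idea (illustrative; the function name and signature are mine, not the project's actual util):

// Bilinearly upsample a small 2D grid (e.g. the 7x7 CAM) to display resolution.
function upsampleBilinear(
  src: Float32Array, srcW: number, srcH: number, dstW: number, dstH: number
): Float32Array {
  const dst = new Float32Array(dstW * dstH);
  for (let y = 0; y < dstH; y++) {
    const fy = (y * (srcH - 1)) / Math.max(1, dstH - 1);
    const y0 = Math.floor(fy);
    const y1 = Math.min(srcH - 1, y0 + 1);
    const ty = fy - y0;
    for (let x = 0; x < dstW; x++) {
      const fx = (x * (srcW - 1)) / Math.max(1, dstW - 1);
      const x0 = Math.floor(fx);
      const x1 = Math.min(srcW - 1, x0 + 1);
      const tx = fx - x0;
      // Interpolate horizontally on the two nearest rows, then vertically
      const top = src[y0 * srcW + x0] * (1 - tx) + src[y0 * srcW + x1] * tx;
      const bottom = src[y1 * srcW + x0] * (1 - tx) + src[y1 * srcW + x1] * tx;
      dst[y * dstW + x] = top * (1 - ty) + bottom * ty;
    }
  }
  return dst;
}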

This is genuinely useful for understanding predictions. If the network classifies something wrong, the attention map often shows it was looking at the wrong part of the image.

Understanding the UI

Predictions Panel - Top 5 ImageNet classes with confidence percentages. The sparkline graph shows activation energy across all layers.

Activation Log - Shows energy and sparsity bars for each layer as the signal passes through. You can see how information transforms at each stage.

Feature Maps - 2D heatmaps from early convolutional layers. These are the patterns the network detected in your image.

Attention Map - Shows your original image alongside the CAM heatmap so you can see where the network looked.

3D Visualization - Each layer is assigned a distinct color (see table below). Node brightness indicates activation energy. The cyan signal packet shows the forward pass in action.

Limitations and Scale

Here's the thing nobody tells you about neural network visualizations: you literally cannot show the real thing. MobileNetV2 has millions of parameters and the actual neuron counts are insane:

Layer      | Real Neurons          | Displayed Nodes | Scale
Conv0      | 401,408 (112×112×32)  | 120             | 0.03%
Block 7    | 12,544 (14×14×64)     | 169             | 1.3%
Final Conv | 62,720 (7×7×1280)     | 324             | 0.5%
Classifier | 1,000                 | 270             | 27%

If I rendered every neuron in the first convolutional layer alone, you'd have 400,000+ spheres. Your GPU would catch fire and the visualization would be an unreadable mess. Even with instanced meshes, there's a practical limit.

So what you're actually seeing is a symbolic representation. Each glowing node represents thousands of real neurons. The activation energy and colors are real though. They come from aggregating the actual tensor data. When a layer lights up bright, it means the underlying neurons (all 12,000+ of them in some cases) are firing strongly on average.

This is the tradeoff with any neural network visualization. You can either show accurate scale (and see nothing useful) or show a digestible representation that preserves the important patterns. I went with the second option.

What I Learned

  1. ONNX Runtime is production-ready. The WebAssembly backend runs MobileNetV2 at interactive speeds in the browser.

  2. Instanced meshes matter. Without instancing, rendering hundreds of nodes would tank performance.

  3. Extracting intermediate activations is powerful. You can learn so much more about a model when you can see inside it, not just its final output.

  4. Responsive 3D is hard. Making this work on mobile meant completely rethinking the layout, not just scaling things down.


Try It Yourself

Upload any image and watch the network think. Some things to try:

  • Compare attention maps between correct and incorrect predictions
  • Watch how early vs late layers activate differently
  • Try images with multiple objects and see where it focuses

Check out the live demo and let me know what you think!


Got questions about how I built this? Hit me up on GitHub or LinkedIn.