Merck's Blog

Locating Reasoning Errors with Active Sampling

Chris Merck (chrismerck@gmail.com) -- 2nd North East Mechanistic Interpretability Workshop, August 22, 2025, Northeastern University, Boston


Abstract

Recent advances in chain-of-thought reasoning allow large language models (LLMs) to solve mathematical problems that are beyond the computational power of a single forward pass. This suggests the existence of learned mechanisms operating at the level of reasoning steps in the chain of thought (CoT), yet few techniques exist for identifying these mechanisms. Furthermore, chains of thought can be deceptively interpretable, resembling human internal monologues closely enough that we risk being misled about their causal structure, heightening the need for rigorous interpretability methods. In this work in progress, we develop an algorithm for locating reasoning errors in incorrect base solutions via targeted resampling. This exploration should improve our understanding of chain-of-thought reasoning, particularly how it goes wrong, allowing us to more safely and efficiently operate reasoning models.

Introduction

One frame through which to interpret reasoning traces is to treat them as the work of a math student, with the researcher as tutor. Correct answers to unguessable problems may be rewarded without inspecting the work, while the incorrect solutions are more interesting: the tutor scans for the first mistake, circles it in red pen, and leaves a helpful comment. We aim ultimately to automate this process, with this poster presenting elements of that work in progress. We see this work as part of a broad effort to take advantage of the present opportunity for CoT monitoring and interpretability (Korbak et al., 2025) and an application of counterfactual techniques from physical systems (Merck & Kleinberg, 2016).

Foundational Work. In their recent groundbreaking preprint, Bogdan et al. (2025) employ several techniques for finding causal structure among the sentences of a CoT reasoning trace, cross-validating counterfactual sampling against attention tracing and masking. They develop the math-rollouts dataset, containing 10 incorrect base solutions to problems from the MATH dataset as generated by DeepSeek R1-Distill Qwen-14B (DeepSeek, 2025), selected for intermediate difficulty (having between a 25% and 75% probability of being solved correctly). The dataset also contains 100 rollouts for each sentence in each reasoning trace, allowing us to explore the probability trajectory and counterfactual scenarios.

Active Bayesian Changepoint Detection

Although math-rollouts provides a useful starting point, we would like to eventually scale up to using state-of-the-art reasoning models where exhaustively generating many rollouts for each sentence would be prohibitively expensive. So we apply an active sampling algorithm to efficiently find the sentence containing the most prominent error, termed a changepoint after Adams & MacKay (2007), resulting in a ~100X reduction in the number of rollouts required, at least when the trace contains a clear error.

We propose a Bernoulli process model with a single changepoint τ at which the probability of a correct solution drops from p₁ to p₂:

\[ p(\text{correct}_t) = \begin{cases} p_1 & \text{if } t < \tau \\ p_2 & \text{if } t \geq \tau \end{cases} \]

Bayesian Inference. We maintain a posterior distribution over changepoint locations τ ∈ {1, …, T} using a uniform prior. For each hypothesis τ, we model the probabilities with Beta priors: p₁ ~ Beta(2, 2), reflecting our belief that the initial probability lies between approximately 25% and 75%, and p₂ ~ Beta(1, 19), a strong prior on a low chance of recovery.
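The actual implementation is not shown here, but the conjugate update is simple enough to sketch in C, in the spirit of the hand-rolled code elsewhere on this blog. The names are hypothetical: succ[t] and fail[t] count correct and incorrect rollouts resampled at sentence t, and the uniform prior over τ cancels in the normalization:

#include <math.h>

#define T 40  /* number of sentences in the trace (arbitrary for illustration) */

/* log of the Beta function, via lgamma */
static double lbeta(double a, double b) {
    return lgamma(a) + lgamma(b) - lgamma(a + b);
}

/* Posterior over changepoint location tau, given succ[t]/fail[t] counts of
 * correct/incorrect rollouts resampled at sentence t.
 * Priors: p1 ~ Beta(2, 2) before tau, p2 ~ Beta(1, 19) from tau onward. */
void changepoint_posterior(const int succ[T], const int fail[T], double post[T]) {
    double logp[T], max = -INFINITY, Z = 0;
    for (int tau = 0; tau < T; tau++) {
        int s1 = 0, f1 = 0, s2 = 0, f2 = 0;
        for (int t = 0; t < T; t++) {
            if (t < tau) { s1 += succ[t]; f1 += fail[t]; }
            else         { s2 += succ[t]; f2 += fail[t]; }
        }
        /* Beta-Bernoulli marginal likelihood of the two segments */
        logp[tau] = lbeta(2 + s1, 2 + f1) - lbeta(2, 2)
                  + lbeta(1 + s2, 19 + f2) - lbeta(1, 19);
        if (logp[tau] > max) max = logp[tau];
    }
    for (int tau = 0; tau < T; tau++) Z += exp(logp[tau] - max);
    for (int tau = 0; tau < T; tau++) post[tau] = exp(logp[tau] - max) / Z;
}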

Active Sampling. We select the next sample location with replacement to maximize expected information gain about the changepoint location:

\[ t^* = \arg\max_{t} \mathbb{E}_{y_t}[H(\tau) - H(\tau | y_t)] \]

where H(τ) is the entropy of the current posterior over changepoint locations and y_t ∈ {0, 1} is the correctness of the hypothetical resampled rollout. This strategy efficiently focuses sampling around the uncertain changepoint region.
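Building on the sketch above, the greedy selection can be illustrated by scoring each candidate sentence with one hypothetical success and one hypothetical failure, weighted by the posterior-predictive probability of each outcome (entropies in nats; again a sketch, not the exact code):

static double entropy(const double p[T]) {
    double H = 0;
    for (int t = 0; t < T; t++) if (p[t] > 0) H -= p[t] * log(p[t]);
    return H;
}

/* Posterior-predictive probability that a fresh rollout at sentence t is
 * correct, marginalized over tau. */
static double predictive(const int succ[T], const int fail[T],
                         const double post[T], int t) {
    double p = 0;
    for (int tau = 0; tau < T; tau++) {
        double a = (t < tau) ? 2 : 1, b = (t < tau) ? 2 : 19;
        int s = 0, f = 0;
        for (int u = 0; u < T; u++)
            if ((u < tau) == (t < tau)) { s += succ[u]; f += fail[u]; }
        p += post[tau] * (a + s) / (a + b + s + f);
    }
    return p;
}

/* Choose the sentence t* maximizing expected information gain about tau. */
int next_sample(int succ[T], int fail[T]) {
    double post[T], p1[T], p0[T];
    changepoint_posterior(succ, fail, post);
    double H0 = entropy(post), best_gain = -1;
    int best_t = 0;
    for (int t = 0; t < T; t++) {
        double ph = predictive(succ, fail, post, t);
        succ[t]++; changepoint_posterior(succ, fail, p1); succ[t]--;
        fail[t]++; changepoint_posterior(succ, fail, p0); fail[t]--;
        double gain = H0 - (ph * entropy(p1) + (1 - ph) * entropy(p0));
        if (gain > best_gain) { best_gain = gain; best_t = t; }
    }
    return best_t;
}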

Probability Trajectories. With the choice of a strong prior on p₂, we find that the algorithm tends to find reasonable changepoints with just 100 rollouts.
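The detected changepoint and the interval drawn in the figure below can be read directly off the posterior; a small sketch, taking the 80% interval centrally:

/* MAP changepoint and a central 80% credible interval */
void summarize(const double post[T], int * map, int * lo, int * hi) {
    double cdf = 0, best = -1;
    *map = 0; *lo = -1; *hi = -1;
    for (int t = 0; t < T; t++) {
        if (post[t] > best) { best = post[t]; *map = t; }
        cdf += post[t];
        if (*lo < 0 && cdf >= 0.10) *lo = t;
        if (*hi < 0 && cdf >= 0.90) *hi = t;
    }
}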

Probability trajectories for all 10 incorrect base solutions. Each subplot shows how chances of success evolve along the CoT. Red vertical lines indicate detected changepoints (80% CI).

Failure Analysis

We now examine and discuss selected failures based on the change in answer distribution at the changepoint and a reading of the trace.

Nested Multiplication (#330) — wrong base case in recurrence relation

Wait, wait, no. Let me think. Wait, in the original expression, each time it's 3(1 + ...). So, actually, each level is 3(1 + previous level). So, maybe I should approach it as a recursive computation.

The trace was on the right track, only two computations away from the correct answer. But when the model used the ellipsis (...), many rollouts failed to converge to a boxed answer, recursing until the token limit. The base solution does converge, but fails due to an error in the base case: 4 instead of 0 or 12. The probability trajectory does visually contain a clear changepoint, and the identified sentence is indeed, by inspection, the first introduction of the mistake that persists into the incorrect solution.

Now we examine the starkest failure in the Digit Counting problem:

Digit Counting (#2236) — non sequitur

The ones digits here are 1–8, so 4 doesn't appear in the units place in this partial set.

Here the model makes a statement that is obviously false to us, but is generated with approximately a 25% probability at this step. To better understand the structure of this failure, we collect all 100 rollouts resampled at this changepoint sentence and visualize them in a dendrogram. Following Bogdan et al. (2025), we join rollouts if the sentence embeddings have cosine similarity greater than the median similarity between all sentence pairs (0.8).
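Concretely, this joining rule is threshold-based single-linkage clustering, which can be sketched with a union-find. The embedding dimension below is a placeholder; any sentence-embedding model works:

#include <math.h>

#define R 100   /* rollouts resampled at the changepoint sentence */
#define D 384   /* embedding dimension (a placeholder) */

static int parent[R];
static int find(int i) { return parent[i] == i ? i : (parent[i] = find(parent[i])); }
static void join(int i, int j) { parent[find(i)] = find(j); }

static double cosine(const float a[D], const float b[D]) {
    double dot = 0, na = 0, nb = 0;
    for (int d = 0; d < D; d++) {
        dot += a[d] * b[d];
        na  += a[d] * a[d];
        nb  += b[d] * b[d];
    }
    return dot / (sqrt(na) * sqrt(nb));
}

/* Join any two rollouts whose sentence embeddings exceed the threshold
 * (the median pairwise similarity, ~0.8 for this data). */
void cluster(const float emb[R][D], double threshold) {
    for (int i = 0; i < R; i++) parent[i] = i;
    for (int i = 0; i < R; i++)
        for (int j = i + 1; j < R; j++)
            if (cosine(emb[i], emb[j]) > threshold)
                join(i, j);
}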

A visualization of the 100 resamplings at sentence 24. Color indicates the probability of a correct solution (blue = 1.0, red = 0.0). We observe that the resampled rollouts fall into three clusters, the bottom of which contains the path taken in the incorrect base solution. The clear separation of the incorrect cluster suggests that the similarity metric is a reasonable operationalization of "sameness" for sentences.

Finally, we examine a case where there is no clear changepoint. Inspecting the model's solution summary, we find that the model was quite knowingly vacillating between two possible interpretations of the problem:

Hex to Binary (#4682) — trick question

While converting each hexadecimal digit to 4 bits results in a 20-bit binary number, the minimal number of bits required to represent the number 419,430₁₀ without leading zeros is 19.

However, since the problem asks for the number of bits when the hexadecimal number is written in binary, and considering that each hexadecimal digit is typically converted to 4 bits, the correct answer is 20 bits.

Limitations and Future Directions

  • Although we prefer to investigate reasoning errors in a counterfactual context, there is a body of existing work that studies errors in chain-of-thought prompting within and beyond mathematical reasoning (Lightman et al., 2023; Tyen et al., 2024), as well as work that examines reasoning structure beyond its influence on accuracy (Xia et al., 2025).

  • As we saw with the trick question, the failure of these models to deliver mundane utility often stems from issues interpreting the user's prompt rather than logical errors within the chain of thought. The relationship of user intent and model interpretation can be fully formalized (Wu et al., 2022) for mathematical problems. The broader problem of prompt interpretation is of high importance to alignment, especially as models become capable of achieving more work with a single operator instruction.

  • The rollouts in math-rollouts are drawn from a relatively small 14B model distilled from a larger one (DeepSeek R1 671B). It is possible that reasoning failures in stronger models are qualitatively different, or that the structure of the reasoning behavior differs after distillation. The changepoint detector introduced here is intended to facilitate testing on larger models.

  • The changepoint model assumes a single error. However, realistic chains of thought may contain multiple errors and backtracking leading to a more complex probability trajectory. It is possible to develop a changepoint detector for these more complex structures, but validation requires much more than the 10 examples available in the math-rollouts dataset.

  • In particular, the case of no clear changepoint is important to avoid spurious error detections. The algorithm could be updated to track confidence in a null hypothesis.

  • State-of-the-art reasoning models and agents use parallel execution and subagents (subroutines with chains of thought that are not fully attended to by the top-level agent, as in Anthropic, 2025). Investigating failure modes of these new architectures is a challenge due to their complexity; however, it appears the methods of Bogdan et al. (2025) could be adapted. Active sampling may be of even greater importance as the compute per rollout increases.

Acknowledgements

I am grateful for discussions of this work with Anuj Apte. This work is generously supported by a grant from the Cosmos Institute (https://cosmos-institute.org/) and the Foundation for Individual Rights and Expression (FIRE) (https://www.thefire.org/).

References

  • Adams, R. P., & MacKay, D. J. (2007). Bayesian online changepoint detection. arXiv preprint arXiv:0710.3742.
  • Anthropic (2025). Claude 3.5 Sonnet with Computer Use. Available at: https://www.anthropic.com/
  • Bogdan, P., et al. (2025). Thought Experiments: Counterfactual Reasoning in Chain-of-Thought. arXiv preprint.
  • DeepSeek (2025). DeepSeek R1: Reasoning at the Edge of AGI. arXiv preprint.
  • Korbak, T., et al. (2025). Chain-of-thought reasoning as a monitoring opportunity. arXiv preprint.
  • Lightman, H., et al. (2023). Let's verify step by step. arXiv preprint.
  • Merck, C., & Kleinberg, S. (2016). Causal explanation under indeterminacy: A sampling approach. In Proceedings of the AAAI Conference on Artificial Intelligence.
  • Tyen, G., et al. (2024). LLMs cannot correct reasoning errors yet. arXiv preprint.
  • Wu, Y., et al. (2022). Autoformalization with large language models. arXiv preprint.
  • Xia, M., Li, X., Liu, F., Wu, B., & Liu, P. (2025). Reasoning structure in chain-of-thought beyond accuracy. arXiv preprint.

Reasoning Failures in DeepSeek V3 Base

Wherein we observe arithmetic errors, infinite loops, handwaving, and wishful thinking. -- Coursework for Frontier Language Models (Summer 2025, Fractal U).

Intro

In this post, I manually analyze specific reasoning failures of DeepSeek V3 Base (1). Why the base model? Because if we are to understand how reasoning emerges during RL, we should start by studying the nascent reasoning capabilities of the base models that are the starting point for the likes of R1 (2) and presumably o3, and how their failures come to be avoided. The next post will explore reasoning failures in the R1 model.

We will investigate a simple verifiable math problem which the base model solves ~75% of the time and which R1 always solves. This way we can see some of the failures that GRPO RL is able to optimize away.

  1. DeepSeek V3 is a 671B parameter Mixture-of-Experts model released in December 2024.
  2. Just a month after announcing the base model, DeepSeek published a reasoning model, DeepSeek R1, competitive with the then state-of-the-art OpenAI o1. R1 was trained mostly by RL against a rule-based reward: math problems with known solutions, and coding challenges that can be checked by compilation and unit tests.

Method

For the purposes of this post I use a math problem of a kind likely used (1) to train R1, a 3x3 linear system of equations:

  1. We do not know exactly which math problems with deterministic results were used by the DeepSeek team.
\[ \left\{ \begin{aligned} 4x - 3y + 6z &= 6 \\ -3x + 4y + 5z &= 5 \\ 3x + 9y - 7z &= -7 \end{aligned} \right. \]

This problem is hard enough that we observe frequent and varied failures in V3 Base, but easy enough that perfect performance is achieved by R1.
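For reference, the system can be solved by hand (the coefficient matrix is non-singular, so the solution is unique). Writing \(E_1, E_2, E_3\) for the three equations and eliminating \(x\) and then \(y\):

\[ \begin{aligned} 3E_1 + 4E_2:&\quad 7y + 38z = 38 \\ E_2 + E_3:&\quad 13y - 2z = -2 \\ (3E_1 + 4E_2) + 19(E_2 + E_3):&\quad 254y = 0, \end{aligned} \]

giving \(y = 0\), then \(z = 1\) and \(x = 0\), so the expected answer is \([0, 0, 1]\).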


I use the exact prompt template given in the DeepSeek R1 paper:

A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {problem}. Assistant:

And then in this post I'm using just this problem:

Solve the system of equations: \(4x -3y +6z = 6\), \(-3x +4y +5z = 5\), \(3x +9y -7z = -7\). Give your solution as a JSON array of integers.
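Because the requested output is a JSON array of integers, completions can be checked mechanically. A sketch of such a rule-based check (my illustration, not DeepSeek's actual reward code):

#include <stdio.h>

/* Substitute a candidate (x, y, z) into the three equations. */
int verify(int x, int y, int z) {
    return ( 4*x - 3*y + 6*z ==  6)
        && (-3*x + 4*y + 5*z ==  5)
        && ( 3*x + 9*y - 7*z == -7);
}

int main(void) {
    printf("%d\n", verify(0, 0, 1));  /* 1: the correct solution */
    printf("%d\n", verify(1, 0, 1));  /* 0: an incorrect answer we will see below */
    return 0;
}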

Failures

I ran the above prompt a number of times and present several of the failure cases below with commentary.

Handwaving

Sometimes V3 Base will describe how to solve the problem rather than actually solving it. I hypothesize that in the handwaving example below, the failure is mediated by the use of the auxiliary verbs can and will.

To solve the system of equations, we can use the method of substitution or elimination. Here, we will use the method of elimination. First, we will eliminate the variable x from the first and second equations. To do this, we will multiply the first equation by 3 and the second equation by 4, and then add the two equations together. This will give us a new equation with only y and z. Next, we will eliminate the variable x from the first and third equations. To do this, we will multiply the first equation by 3 and the third equation by 4, and then subtract the two equations. This will give us another new equation with only y and z. Now, we have two equations with only y and z. We can solve these two equations simultaneously to find the values of y and z. Once we have the values of y and z, we can substitute them into any of the original equations to find the value of x. Finally, we will have the values of x, y, and z, which we can represent as a JSON array [x, y, z].

This would be a great preliminary plan, but the model did not actually work through the computation steps, and so it then gives the final answer:

[1, 2, 3]

This looks like an example rather than a guess. It is akin to behavior I've seen from OpenAI models where, rather than returning a functional program satisfying my request, they return an "example" of how I could do it.

Arithmetic Errors & Infinite Loops

It is well-known that base models tend to end up in repeating patterns. In the completion below, we see such a repeating sequence emerge, but looking closely we can see that an arithmetic error occurred prior: in simplifying \(\frac{45}{4}\cdot\frac{38}{7}\), the model eliminated the factor of 2 from the numerator (yielding \(45 \cdot 38 / 2 = 855\)) but did not eliminate it from the denominator (keeping \(4 \cdot 7 = 28\) rather than \(4 \cdot 7 / 2 = 14\)); the correct simplification is \(\frac{855}{14}\), not \(\frac{855}{28}\).

<think> To solve the system of equations, we can use the method of elimination or substitution. However, since the coefficients of the variables are not the same in all equations, it might be easier to use the method of substitution. First, we can solve the first equation for x: \(4x = 3y - 6z + 6\) \(x = \frac{3y - 6z + 6}{4}\) Next, we can substitute this expression for x into the second and third equations: \(-3(\frac{3y - 6z + 6}{4}) + 4y + 5z = 5\) \(3(\frac{3y - 6z + 6}{4}) + 9y - 7z = -7\) Simplifying these equations, we get: \(-\frac{9}{4}y + \frac{9}{2}z - \frac{9}{2} + 4y + 5z = 5\) \(\frac{9}{4}y - \frac{9}{2}z + \frac{9}{2} + 9y - 7z = -7\) Combining like terms, we get: \(\frac{7}{4}y + \frac{19}{2}z = \frac{19}{2}\) \(\frac{45}{4}y - \frac{23}{2}z = -\frac{23}{2}\) Now, we can solve the first equation for y: \(\frac{7}{4}y = \frac{19}{2} - \frac{19}{2}z\) \(y = \frac{38}{7} - \frac{38}{7}z\) Substituting this expression for y into the second equation, we get: \(\frac{45}{4}(\frac{38}{7} - \frac{38}{7}z) - \frac{23}{2}z = -\frac{23}{2}\) Simplifying this equation, we get: \(\frac{855}{\color{red}{28}} - \frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2}\) Combining like terms, we get: \(-\frac {855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}\) \(\color{red}{-\frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}}\) \(-\frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}\) \(-\frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}\) \(-\frac{855}{28}z - \frac{23}{2 }z = -\frac{23}{2} - \frac{855}{28}\) \(-\frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}\) \(-\frac{855}{28}z - \frac{23}{2}z = -\frac{23}{2} - \frac{855}{28}\)

Wishful Thinking

Lastly, I see a case of what I'll call wishful thinking. V3 Base succeeds at eliminating \(x\) in the first step, but then it tries to eliminate \(y\) in the same way, neglecting that one equation must first be multiplied by \(6\) to achieve the elimination.

Presumably the model at this point had a choice: either actually do the correct computation, leading to the unintended result that \(y\) was not in fact eliminated, or copy the pattern from the first step even though the computation is wrong. It chooses the latter, which, though it of course reduces to statistical pattern matching, has the appearance of wishful thinking.

<think> To solve the system of equations, we can use the method of elimination or substitution. However, since the equations are already in a form that allows for easy elimination, we will use that method. First, we will eliminate \(x\) from the second and third equations by adding the first equation to the second and third equations. This gives us the system of equations: \(4x -3y +6z = 6\), \(y +11z = 11\), \(6y -z = -1\). Next, we will eliminate \(y\) from the third equation by adding the second equation to the third equation. This gives us the system of equations: \(4x -3y +6z = 6\), \(y +11z = 11\), \(\color{red}{10z = 10}\). Finally, we can solve for \(z\) by dividing both sides of the third equation by 10. This gives us \(z = 1\). Substituting this value of \(z\) into the second equation, we can solve for \(y\). This gives us \(y = 0\). Substituting these values of \(y\) and \(z\) into the first equation, we can solve for \(x\). This gives us \(x = 1\). Therefore, the solution to the system of equations is \(x = 1\), \(y = 0\), \(z = 1\). </think> <answer> [1, 0, 1]

Next Steps

I see two directions to take this exploration:

  1. Explore reasoning failures in a reasoning model (R1): does it make the same mistakes only less often? Or does it fail in new ways? This will require harder math problems.

  2. Automate the discovery of the location of the reasoning failure in the chain of thought. One way to do this could be to look at the probability of success as a function of position in a particularly faulty output. I expect that, at least in the case of arithmetic errors, we could automatically discover the exact token at which the error occurred, but this requires a lot of inference. I'd like to explore statistical and heuristic methods for efficiently deriving failure explanations.

Demonstrating Superposition

Wherein I implement a toy model of feature superposition by hand in C as a remedial exercise, and create a video showing the model learning a suboptimal representation.


In the unlikely scenario where all of this makes total sense and you feel like you're ready to make contributions, [...]
- Scott Alexander Nov 2023


In this post, I will manually reproduce the intro figure from Toy Models of Superposition without using anything but the C standard library, so as not to hide any of the details. The paper comes with a PyTorch implementation, but autograds do so much work I feel I need to earn the right to use them by working out the toy model math and code myself.

The basic result is this little animation showing how the model learns the pentagonal representation from the paper's intro:

The Data

First off, we need to generate the synthetic data. We want samples with dimension \(n=5\), where features are sparse (being non-zero with probability \(1 - S\)) and uniformly distributed on the unit interval when they do appear, which we can write down as a mixture of two distributions:

\[ x_i \sim S\,\delta(0) + (1 - S)\,\text{U}(0, 1) \]

Here \(\delta\) is the Dirac delta function, i.e. the point mass distribution.

Synthesizing Data in C

The C stdlib doesn't have a uniform random function so I wrote one (1) and used it to generate the data:

  1. #include <stdlib.h>
    
    /// get a float uniformly distributed on U[0, 1)
    float frand() {
        // divide by RAND_MAX + 1.0 so the result stays strictly below 1
        return (float) (random() / (RAND_MAX + 1.0));
    }
    
    macOS manpages implore us to use the cryptographically secure RNG arc4random(), but I think the polynomial PRNG is good enough for this application, and I like that we can use srandom(0) to force reproducibility.
void synthesize(int n, long count, float X[n][count], float S_) {
    // sparsity S in [0, 1), S_ is 1-S
    for (long c = 0; c < count; c++) {
        for (int i = 0; i < n; i++) {
            if (frand() < S_) {
                X[i][c] = frand();
            }
        }
    }
}

Now we can generate some samples with \(1-S = 0.1\) (each feature non-zero 10% of the time), using a little printmat function (1) to check our work. Below we see the result for four 5-D samples.

  1. void printmat(char * tag, int rows, int cols, float A[rows][cols]) {
        printf("%s: [\n", tag);
        for (int m = 0; m < rows; m++) {
            for (int n = 0; n < cols; n++) {
                if (A[m][n]) {
                    printf("  %1.03f ", A[m][n]);
                } else {
                    printf("  0     ");
                }
            }
            printf("\n");
        }
        printf("]\n");
    }
    
    const int count = 4;
    float X[N][count];
    srandom(0);
    memset(X, 0, sizeof(X));
    synthesize(N, count, X, 0.1);
    printmat("X", N, count, X);
    
\[ X=\begin{bmatrix} 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0.522 & 0 & 0 & 0 \\ 0 & 0.568 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ \end{bmatrix} \]

Here we see that only ~2 of the 20 elements are non-zero, as expected with this sparsity level.

The Model

The model is a 2-layer feedforward network, where the hidden layer maps down from \(n=5\) to \(m=2\) dimensions without any activation function, and then the output layer uses the transpose of the hidden-layer weights plus a bias term and a ReLU activation function. This is, as far as I can tell, basically an autoencoder.

In matrix notation we have:

\[y = \verb|ReLU|(W^T W x + b).\]

The Forward Pass

Breaking it down into steps with indices, we have:

\[ \begin{aligned} h_k &= \sum_{i=1}^n w_{ki} x_i \\ a_j &= b_j + \sum_{k=1}^m h_k w_{kj} \\ y_j &= \max(0, a_j), \end{aligned} \]

from which follows a first C implementation of the forward pass:

void forward(params_t * p, float x[N], float * y) {
    float hk[M];
    memset(hk, 0, sizeof(hk));
    // hidden layer: h = W x
    for (int k = 0; k < M; k++) {
        for (int i = 0; i < N; i++) {
            hk[k] += p->W[k][i] * x[i];
        }
    }
    // output layer: y = ReLU(W^T h + b)
    for (int j = 0; j < N; j++) {
        y[j] = p->b[j];  // start from the bias; y may be uninitialized
        for (int k = 0; k < M; k++) {
            y[j] += p->W[k][j] * hk[k];
        }
        // ReLU activation
        y[j] = y[j] > 0 ? y[j] : 0;
    }
}

Importance and Loss

We want the model to prioritize representation of certain dimensions, so we assign an importance \(I_i\) to each dimension, which we make decrease geometrically: \(I_i = 0.7^i\). A weighted least-squares loss is then:

\[L = \frac{1}{2} \sum_i I_i (y_i - x_i)^2 + \frac{\alpha}{2} \sum_{k,j} w_{kj}^2.\]

And our goal is to optimize the parameters \(W\) and \(b\) to minimize this loss. We then should be able to visualize the weights and see feature superposition emerging as a function of sparsity.

Note that the paper does not specify any regularization. I threw in the L2 regularization term because I saw that a weight-decay optimizer was used in the paper's code example on Colab, but it turns out to be totally unnecessary if we pick the learning rate right.
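For later use, the importance weighting \(I_i = 0.7^i\) is just a one-liner; the gradient code below calls this importance() helper:

#include <math.h>

/* importance weight I_i = 0.7^i */
float importance(int i) {
    return powf(0.7f, (float) i);
}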

Training

Finding the Gradient

As I'm a bit rusty on my calculus, I'll go step by step through the gradient computation. Taking the derivative with respect to an arbitrary weight and pushing the derivative inside the sums as far as it will go, applying the chain and power rules, and using \(\delta_j\) to denote the error in the \(j\)th output, we have:

\[ \begin{align*} \frac{\partial L}{\partial w_{kj}} &= \frac{1}{2} \sum_i I_i \frac{\partial}{\partial w_{kj}} \bigl(y_i - x_i\bigr)^2 + \frac{\alpha}{2} \sum_{k'}\sum_{j'} \frac{\partial}{\partial w_{kj}} w_{k'j'}^2, \\[1ex] &= \sum_i I_i\,\delta_i\,\frac{\partial y_i}{\partial w_{kj}} + \alpha\,w_{kj}. \end{align*} \]

Note that in the regularization term we've used the fact that the only summand that depends on \(w_{kj}\) is the one where \(k' = k\) and \(j' = j\), so the primes drop off the indices.

Now focusing on the derivative of the output layer, for the case where \(y_j\) is non-zero, we work out the diagonal term \(i = j\) of the sum above (the only term the implementation below keeps; the cross terms \(w_{ki}\,x_j\) for \(i \ne j\) are dropped):

\[ \begin{align*} \frac{\partial y_{j}}{\partial w_{kj}} &= \frac{\partial}{\partial w_{kj}} \sum_{k'}\sum_{i'} w_{k'j}\,w_{k'i'}\,x_{i'} \\[1ex] &= \sum_{i'} x_{i'}\, \frac{\partial}{\partial w_{kj}} \bigl(w_{kj}\,w_{ki'}\bigr) \qquad \text{(only } k'=k \text{ survives)} \\[1ex] &= \sum_{i'} x_{i'} \begin{cases} 2\,w_{kj} & i'=j,\\ w_{ki'} & i' \ne j \end{cases} \\[1ex] &= \sum_{i'} w_{ki'}\,x_{i'} \;+\; w_{kj}\,x_{j} \\[1ex] &= h_{k} \;+\; w_{kj}\,x_{j}. \end{align*} \]

Let's do an intuition check on this derivative. The weight \(w_{kj}\) appears in two places: once in the hidden layer as \(x_j\)'s contribution to \(h_k\), and once in the output layer as \(h_k\)'s contribution to \(y_j\). So increasing \(w_{kj}\) will increase the output proportionally to the value of \(h_k\), but then we need to add in the fact that \(h_k\) itself is also increased proportional to both the \(j\)th input and the current value of the weight. So our calculation seems intuitively correct.

Computing the Gradient in C

To compute the gradient in C, we implement a gradient function that adds to a gradient accumulator:

float gradient(const params_t * p, const float x[N], float alpha, params_t * grad);

The simplest way is just to take the forward pass and keep track of temporary variables that appear in the gradient expression above and then add them together as prescribed. For example the hidden layer computation now looks like:

for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
        wkj_xj[m][n] = p->W[m][n] * x[n];
        hk[m] += wkj_xj[m][n];
    }
}

And so on (1) as we compute the gradient, add to the accumulator, and return the loss.

  1. float gradient(const params_t * p, const float x[N], float alpha, params_t * grad) {
        // unlike the forward pass, we keep track of intermediate
        // values that appear in the gradient
        // our toy model is so small that all this fits comfortably
        // in the thread stack
        // alpha = L2 regularization coefficient
        // grad is a pointer to the gradient accumulator
        // returns loss
        float wkj_xj[M][N];
        float hk[M];
        float y[N];
        float delta[N];
        float dL_wkj[M][N];
        memset(wkj_xj, 0, sizeof(wkj_xj));
        memset(hk, 0, sizeof(hk));
        memset(y, 0, sizeof(y));
        memset(delta, 0, sizeof(delta));
        memset(dL_wkj, 0, sizeof(dL_wkj));
        // hidden layer
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                wkj_xj[m][n] = p->W[m][n] * x[n];
                hk[m] += wkj_xj[m][n];
            }
        }
        // output layer
        for (int n = 0; n < N; n++) {
            for (int m = 0; m < M; m++) {
                y[n] += p->W[m][n] * hk[m];
            }
            y[n] += p->b[n];
            // ReLU activation
            y[n] = y[n] > 0 ? y[n] : 0;
            // compute delta
            delta[n] = y[n] - x[n];
        }
        // compute loss, including the L2 penalty on W and b
        float L = 0;
        for (int n = 0; n < N; n++) {
            float Ij = importance(n);
            L += Ij * delta[n] * delta[n];
        }
        for (int n = 0; n < N; n++) {
            for (int m = 0; m < M; m++) {
                L += alpha * p->W[m][n] * p->W[m][n];
            }
            L += alpha * p->b[n] * p->b[n];
        }
        L /= 2;
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                if (y[n] > 0) {
                    // the diagonal (i = j) chain-rule term derived above
                    dL_wkj[m][n] = importance(n) * delta[n] * (hk[m] + wkj_xj[m][n]);
                }
                // the L2 term applies whether or not the ReLU is active
                dL_wkj[m][n] += alpha * p->W[m][n];
            }
        }
        // add to gradient accumulator (negated, so the update step can add)
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                grad->W[m][n] -= dL_wkj[m][n];
            }
        }
        for (int n = 0; n < N; n++) {
            if (y[n] > 0) {
                grad->b[n] -= importance(n) * delta[n];
            }
            grad->b[n] -= alpha * p->b[n];
        }
        return L;
    }
    

The Training Loop

Now we put it all together (1), adding in a random batch of size 1024, which provides some stochasticity to the gradient descent. Note that I'm not using any optimizer (just plain SGD), and I've got regularization turned off (\(\alpha = 0\)).

  1. params_t p;
    memset(&p, 0, sizeof(p));
    // initialize with random weights and biases
    for (int j = 0; j < N; j++) {
        for (int k = 0; k < M; k++) {
            p.W[k][j] = frand() * 0.001;
        }
        p.b[j] = frand() * 0.001;
    }
    params_t grad;
    for (int r = 0; r < runs; r++) {
        memset(&grad, 0, sizeof(grad));
        float L = 0;
        long batch[batch_size];
        batch_indices(batch_size, batch);
        for (long c = 0; c < batch_size; c++) {
            L += gradient(&p, X[batch[c]], alpha, &grad);
        }
        update(&p, &grad, eta / batch_size);
        printf("run: %d\n", r);
        printf("L: %1.04f\n", L / batch_size);
        if (r % 100 == 99) {
            // print b
            printmat("b", 1, N, p.b);
            // print W
            printmat("W", M, N, p.W);
            // print grad w and b
            printmat("grad w", M, N, grad.W);
            printmat("grad b", 1, N, grad.b);
        }
        fflush(stdout);
    }
    

This C program outputs a long log of the weights and loss during the training run. Runs take about 10 seconds for 10000 batches, which is enough to fully converge.
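The update() helper is not shown above; since gradient() subtracts into the accumulator, plain SGD is just an addition. A sketch, with batch_indices() assumed to fill batch with random sample indices:

/* plain SGD step: p += scale * grad, where grad holds the accumulated
 * negative gradient and scale = eta / batch_size */
void update(params_t * p, const params_t * grad, float scale) {
    for (int k = 0; k < M; k++)
        for (int j = 0; j < N; j++)
            p->W[k][j] += scale * grad->W[k][j];
    for (int j = 0; j < N; j++)
        p->b[j] += scale * grad->b[j];
}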

It took only a little bit of trial and error to get the learning rate right.

I then asked o3 (1) to take the outputted log and make an animation, resulting in the video shared at the top. Here was the prompt that got me 90% of the way to a working animation:

  1. One of the hardest parts of learning to use LLMs I find is knowing when and when not to use them. For building visualizations, I find LLMs incredibly helpful, while for learning, it's best to battle through the details oneself.

Please write a Python program that takes a file called log with this format: (pasted example log snippet) and uses matplotlib to render loss as a function of the run, and the W matrix showing how each unit input vector is mapped to the hidden dimensions (2d) which should be a plot with one scatter dot for each of the 5 input unit vectors. Make this an animation showing how the points migrated over time, keeping the xy limits fixed so it is stable. Include a moving average of the loss as a line plot.

Intuition about Local Minima

I'll close with an animation of the same model but with \(n=32\) features, and importance decaying as \(0.9^i\). Notice how it converges to a suboptimal solution!

You can see that the pentagon quickly forms out of features 0 through 4, and then features 5, 6, and 7 start to emerge, but then it is the less important feature 7 which pushes its way out to form a stable hexagon. Why is that? It seems to be because features 5 and 6 were unlucky enough to be on the side of the pentagon shared with higher-importance features while 7 had the good fortune of being near the relatively weaker feature 4 which it could push out of the way.

Bottom line: there are non-trivial local minima even in simple models and we can actually have some hope of gaining intuition about them.