Compute Where It Counts: High Quality Sparsely Activated LLMs

[Figure: a chat completion from a CWIC model, with each token shaded by how much compute it received (less compute → more compute).]

We present CWIC (Compute Where It Counts), a new method for creating efficient transformers that automatically decide when to use more or less compute. CWIC makes models faster, more cost-efficient, and more interpretable.

Summary:

  1. CWIC yields a 3x increase in CPU throughput with only a 10% reduction in benchmark performance.

  2. CWIC uses a different amount of compute for each token, making task difficulty interpretable.

  3. CWIC directly optimizes compute as a loss function, and learns to budget compute without labelled data or hand-crafted heuristics.

  4. The CWIC architecture uses learned activation thresholds and expressive sparsity patterns to enable adaptive computation.

Background

Large language models have become ubiquitous tools for natural language tasks. However, LLM inference requirements have grown beyond consumer devices and drive massive industry hardware expenses. For many applications, especially agentic ones, inference speed and cost are critical bottlenecks for real world deployment.

Therefore, many methods have been proposed to improve LLM inference efficiency. These include quantization, pruning, and sparse Mixture of Experts (MoE). Activation sparsity, the category in which CWIC falls, is another such approach. It focuses on removing small and inconsequential activations from the inputs of matrix multiplications, allowing some computations to be skipped without affecting the model's output.

One of the earliest activation sparsity methods for LLMs was Relufication, which inserted ReLU activation functions into LLMs to induce sparsity. ProSparse further increased sparsity by adding an L1 penalty to the ReLU activations. Deja Vu and ShadowLLM predicted sparsity on the fly by training small auxiliary MLPs. Q-Sparse discarded all but the top-K largest activations, and demonstrated a sparse scaling law where larger models are more robust to sparsity.

Most similar to our work are CATS, TEAL, and R-SPARSE. These methods all remove activations with smaller magnitude than a threshold. However, none of these methods directly learn activation thresholds. Furthermore, these methods suffer from performance collapse at high sparsity levels. CWIC addresses both limitations.

Motivating Insights

  1. Learned parameters perform better than heuristically chosen ones. The often-quoted "bitter lesson" states that general learning methods have historically outperformed hand-crafted approaches. We noticed that previous activation sparsity methods like TEAL (block-wise greedy optimization) and R-Sparse (search algorithm) used heuristics to determine activation thresholds. We hypothesized that learning thresholds directly through backpropagation would lead to better results.

  2. Adaptive computation methods with higher combinatorial expressiveness perform better. This was theorized and demonstrated by DeepSeekMoE, which improved over previous MoE methods by increasing the number of experts to choose from. We posited that the same principle would apply to activation sparsity: sparsity patterns with higher flexibility than the standard column pattern would yield better performance.

  3. Different parameters should have different sparsity levels. This insight was drawn from our own preliminary experiments. We found that, among other patterns, the Q attention matrix was more robust to sparsity than the K and V matrices. This shows a limitation in methods like CATS and Q-Sparse that use the same sparsity level for every parameter. Furthermore, while the sparsity level of each parameter could be manually tuned, we wanted to automate this by making sparsity thresholds learnable.

Methods

Sparsity Patterns

[Figure: column sparsity masks an entire column at a time, while granular sparsity masks individual stripes within columns, enabling many more sparsity patterns.]

Previous sparsity methods exploit column sparsity to make matrix multiplication more efficient.

Formally, consider the multiplication of the matrix $W \in \mathbb{R}^{m \times n}$ with the vector $x \in \mathbb{R}^n$, resulting in the vector $y \in \mathbb{R}^m$ such that $y = Wx$. This operation can be represented as the sum of the column vectors of $W$ (with $v_i \in \mathbb{R}^m$ denoting the $i$-th column of $W$) weighted by the elements $x_i$ of $x$:

$$y = \sum_{i=1}^n x_i v_i$$

Column-wise sparse matrix multiplication, which we denote as $\text{SMM}(x, W, \mathcal{M})$, uses a mask $\mathcal{M} \in \{0,1\}^n$ to zero out some elements of $x$. We use $S_\mathcal{M} = \{ 1 \leq i \leq n \mid \mathcal{M}_i = 1 \}$ to denote the set of indices where the elements of $x$ are not zeroed out.

$$y_{\text{sparse}} = \text{SMM}(x, W, \mathcal{M}) = \sum_{i=1}^n \mathcal{M}_i x_i v_i = \sum_{i \in S_\mathcal{M}} x_i v_i$$

When $|S_\mathcal{M}| < n$, SMM requires fewer multiply-adds than dense matrix multiplication.
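To make the bookkeeping concrete, here is a minimal NumPy sketch of column-sparse matrix multiplication as defined above. The function name `smm` and the test values are our own illustration, not the kernels used in CWIC.

```python
import numpy as np

def smm(x: np.ndarray, W: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Column-sparse matvec: y = sum of x_i * v_i over the unmasked indices S_M.

    x:    (n,)   input vector
    W:    (m, n) weight matrix
    mask: (n,)   boolean mask; column i contributes only if mask[i] is True
    """
    active = np.flatnonzero(mask)        # S_M: indices with M_i = 1
    return W[:, active] @ x[active]      # |S_M| * m multiply-adds instead of n * m

# Sanity check: with nothing masked out, SMM matches the dense product.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))
x = rng.standard_normal(16)
assert np.allclose(smm(x, W, np.ones(16, dtype=bool)), W @ x)
```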

We introduce granular sparsity, which partitions columns into sub-units called stripes. Each stripe can be independently activated or deactivated, creating a rich set of available sparsity configurations. This is visualized in the diagram above, which shows a matrix with 2 stripes per column. This scheme is such that stripes align with SIMD registers during CPU inference, which maximizes hardware efficiency.

First, we partition the output vector $y$ into $k$ equally sized stripes, such that $y = [y^{(1)}, y^{(2)}, \ldots, y^{(k)}]$. We similarly partition the column vectors $v_i$ of $W$ such that $v_i = [v_i^{(1)}, v_i^{(2)}, \ldots, v_i^{(k)}]$. With some abuse of notation, we use $\Xi$ as a concatenation analogue of the summation operator $\sum$, such that $y = \Xi_{r=1}^k y^{(r)}$. This means that the original matrix multiplication $y = Wx$ can be written as:

$$y = \Xi_{r=1}^k \sum_{i=1}^n x_i v_i^{(r)}$$

The granular sparsity operation, which we denote as $\text{GMM}(x, W, \mathcal{G})$, uses a mask $\mathcal{G} \in \{0,1\}^{k \times n}$ that zeroes out some of the stripes in $W$. We define $S_\mathcal{G} = \{ (r, i) : 1 \leq r \leq k,\ 1 \leq i \leq n \mid \mathcal{G}_{r,i} = 1 \}$ to be the set of stripe indices that are not zeroed out.

$$y_{\text{granular}} = \text{GMM}(x, W, \mathcal{G}) = \Xi_{r=1}^k \sum_{i=1}^n \mathcal{G}_{r,i} x_i v_i^{(r)} = \Xi_{r=1}^k \sum_{(r,i) \in S_\mathcal{G}} x_i v_i^{(r)}$$

When $|S_\mathcal{G}| < kn$, GMM requires fewer multiply-adds than dense matrix multiplication. Also note that when $k = 1$, granular sparsity is equivalent to column sparsity.
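The striped version only changes which slice of each column participates. Below is a minimal NumPy sketch of GMM, assuming stripes are contiguous row blocks of equal size; the layout and the name `gmm` are illustrative, not the optimized SIMD kernels described above.

```python
import numpy as np

def gmm(x: np.ndarray, W: np.ndarray, G: np.ndarray) -> np.ndarray:
    """Granular sparse matvec with k stripes per column.

    x: (n,)    input vector
    W: (m, n)  weight matrix, with m divisible by k
    G: (k, n)  boolean stripe mask; G[r, i] keeps stripe r of column i
    """
    k, n = G.shape
    m = W.shape[0]
    s = m // k                                   # stripe size
    stripes = W.reshape(k, s, n)                 # stripes[r][:, i] is v_i^{(r)}
    y = np.zeros(m)
    for r in range(k):                           # the concatenation (Xi) over stripes
        active = np.flatnonzero(G[r])            # columns whose r-th stripe survives
        y[r * s:(r + 1) * s] = stripes[r][:, active] @ x[active]
    return y

# With every stripe active (k = 2 here), GMM reduces to the dense product.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))
x = rng.standard_normal(16)
assert np.allclose(gmm(x, W, np.ones((2, 16), dtype=bool)), W @ x)
```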

Note that comparing against the thresholds $kn$ times adds $kn$ comparison operations on top of the multiply-adds. With stripe size $s = m/k$, these comparisons alone amount to $kn = mn/s$ operations, i.e. a $1/s$ fraction of the dense multiply-adds. For this reason, one must use a stripe size that is larger than the maximum FLOP reduction ratio one wants to be able to achieve.

Activation Thresholds

To determine the conditional sparsity mask $\mathcal{G}(x)$, we use the magnitudes of each element in $x$. Specifically, we learn a grid of thresholds, $\theta \in \mathbb{R}_+^{k \times n}$, such that $\mathcal{G}_{r,i}(x; \theta)$ is $1$ if and only if $|x_i| \geq \theta_{r,i}$. In other words, we only activate the stripe at position $(r, i)$ if the magnitude of $x_i$ is at least the threshold at position $(r, i)$. We define this relation using the Heaviside step function $H(z)$:

$$H(z) = \begin{cases} 1 & z \geq 0 \\ 0 & z < 0 \end{cases} \qquad \mathcal{G}_{r,i} = H(|x_i| - \theta_{r,i})$$

Plugging this into our previous GMM definition, we get the expression for granular matrix multiplication parameterized by thresholds, $\text{GMM}(x, W; \theta)$:

$$y_{\text{granular}} = \text{GMM}(x, W; \theta) = \Xi_{r=1}^k \sum_{i=1}^n H(|x_i| - \theta_{r,i})\, x_i v_i^{(r)}$$

Note that when $k = 1$, this is equivalent to the inference strategy of TEAL, which uses a single threshold for each column of $W$.
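Computing the mask from the thresholds is a single broadcasted comparison. A small sketch, reusing the hypothetical `gmm` helper from above:

```python
import numpy as np

def threshold_mask(x: np.ndarray, theta: np.ndarray) -> np.ndarray:
    """G[r, i] = H(|x_i| - theta[r, i]): activate stripe (r, i) iff |x_i| >= theta[r, i]."""
    return np.abs(x)[None, :] >= theta           # shape (k, n)

# Inference-time use, assuming the gmm sketch from earlier:
# y = gmm(x, W, threshold_mask(x, theta))
```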

For optimal performance, we directly learn these thresholds with backpropagation. Unfortunately, the step function $H(z)$ is not differentiable, so we can only approximate its gradient. To overcome this, we draw ideas from sparse autoencoders (SAEs). Specifically, JumpReLU uses a straight-through estimator with a pseudo-derivative that approximates the true derivative of $H(z)$. This pseudo-derivative is defined as follows, with $K$ representing a kernel function and $\epsilon$ representing a tunable bandwidth:

$$\frac{\partial}{\partial z}\big[H(z)\big] \approx \frac{1}{\epsilon} K\!\left(\frac{z}{\epsilon}\right)$$

Using this approximation, both the sparsity level and model outputs are fully differentiable with respect to the thresholds. Furthermore, activations receive gradients related to crossing thresholds, so that they can adjust their magnitudes to increase sparsity or improve outputs.
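The sketch below shows one way to implement such a straight-through estimator in PyTorch, using a rectangular kernel; the kernel choice, bandwidth, and class name are assumptions for illustration rather than the exact formulation used in CWIC or JumpReLU.

```python
import torch

class StepSTE(torch.autograd.Function):
    """Heaviside step with a straight-through pseudo-derivative.

    Forward:  H(z) = 1 if z >= 0 else 0
    Backward: dH/dz ~= K(z / eps) / eps, with the rectangular kernel
              K(u) = 1/2 for |u| <= 1 and 0 otherwise.
    """

    @staticmethod
    def forward(ctx, z: torch.Tensor, eps: float) -> torch.Tensor:
        ctx.save_for_backward(z)
        ctx.eps = eps
        return (z >= 0).to(z.dtype)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (z,) = ctx.saved_tensors
        kernel = 0.5 * (z.abs() <= ctx.eps).to(z.dtype)    # K(z / eps)
        return grad_out * kernel / ctx.eps, None           # no gradient w.r.t. eps

# Example: a differentiable stripe gate G = H(|x| - theta) with k = 4 stripes.
x = torch.randn(16, requires_grad=True)
theta = torch.full((4, 16), 0.3, requires_grad=True)
gate = StepSTE.apply(x.abs().unsqueeze(0) - theta, 0.1)
gate.sum().backward()    # both the activations x and the thresholds theta receive gradients
```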

Loss Functions

To train our network, we use a combination of two loss functions.

First is a knowledge distillation loss that minimizes the difference between the output distributions of the sparse student model and a fixed dense teacher model. In all of our experiments, the teacher model is the same model from which the sparse model was initialized. We use a combination of the forward KL (FKL) and reverse KL (RKL) losses:

$$\mathcal{L}_{\text{distill}} = \sum_{t=1}^T \Big[ \text{KL}\big(p_\text{teacher}(x_t \mid x_{<t}) \,\|\, p_\text{student}(x_t \mid x_{<t})\big) + \text{KL}\big(p_\text{student}(x_t \mid x_{<t}) \,\|\, p_\text{teacher}(x_t \mid x_{<t})\big) \Big]$$
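In code, the symmetric distillation objective is a few lines of log-softmax arithmetic. A minimal PyTorch sketch, assuming student and teacher logits of shape (T, vocab) and a frozen teacher:

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """FKL + RKL between teacher and student next-token distributions, summed over positions."""
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits.detach(), dim=-1)   # teacher is fixed
    p_s, p_t = log_p_s.exp(), log_p_t.exp()
    fkl = (p_t * (log_p_t - log_p_s)).sum()    # KL(teacher || student)
    rkl = (p_s * (log_p_s - log_p_t)).sum()    # KL(student || teacher)
    return fkl + rkl
```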

Second is a FLOPs loss that controls the sparsity of the model. We define this loss using the overall batch FLOP usage $\text{FLOPs}(x)$ of the sparse model, the FLOP usage of an equivalent dense model $\text{FLOPs}_\text{dense}(x)$, and a target FLOP reduction ratio $R_\text{target}$. The FLOPs loss is defined as follows:

$$\mathcal{L}_{\text{FLOPs}} = \Big[ \min\Big( \frac{\text{FLOPs}_\text{dense}(x)}{\text{FLOPs}(x)} - R_\text{target},\ 0 \Big) \Big]^2$$

Since the FLOPs loss is defined in terms of the overall batch FLOP usage, the model can budget different amounts of compute to different tokens, different sequences, and different parameters. This is key to the success of CWIC, as it enables the model to allocate compute where it counts.
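A hedged sketch of this loss, assuming the sparse model exposes a differentiable batch-level FLOP count built from its stripe gates (each active stripe of size s contributes s multiply-adds):

```python
import torch

def flops_loss(flops_sparse: torch.Tensor, flops_dense: float, r_target: float) -> torch.Tensor:
    """Penalize the model only while its FLOP reduction ratio is below the target.

    flops_sparse: differentiable scalar, total FLOPs of the sparse model over the batch
    flops_dense:  FLOPs of the equivalent dense model on the same batch
    r_target:     target reduction ratio, e.g. 4.0 for a 4x reduction
    """
    achieved_ratio = flops_dense / flops_sparse
    return torch.clamp(achieved_ratio - r_target, max=0.0) ** 2
```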

Notably, we found it important that the target FLOP reduction ratio starts near 1x and increases throughout training. This helps the model smoothly transition into a sparse state without destroying its existing knowledge. Throughout our work, we use checkpoints from intermediate $R_\text{target}$ values for evaluation at different FLOP levels, so models with larger FLOP reductions have usually been trained for longer.
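The exact schedule is not spelled out here, but any smooth ramp captures the idea; a hypothetical linear schedule might look like:

```python
def r_target_schedule(step: int, total_steps: int, r_final: float, warmup_frac: float = 0.1) -> float:
    """Hold R_target near 1x for a short warmup, then ramp linearly to the final ratio."""
    ramp_start = int(warmup_frac * total_steps)
    if step <= ramp_start:
        return 1.0
    progress = (step - ramp_start) / max(1, total_steps - ramp_start)
    return 1.0 + (r_final - 1.0) * min(1.0, progress)
```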

We combine the individual losses with a sum weighted by the hyperparameter $\lambda_\text{FLOPs}$:

$$\mathcal{L} = \mathcal{L}_{\text{distill}} + \lambda_\text{FLOPs}\, \mathcal{L}_{\text{FLOPs}}$$
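Putting the pieces together, a single training step might look like the following sketch, which reuses the hypothetical helpers above and assumes the student returns its logits together with its FLOP counts:

```python
import torch

def training_step(student, teacher, batch, optimizer, r_target, lambda_flops=1.0):
    """One CWIC-style distillation step (illustrative interface, not the released code)."""
    with torch.no_grad():
        teacher_logits = teacher(batch)                          # frozen dense teacher
    student_logits, flops_sparse, flops_dense = student(batch)   # hypothetical return signature

    loss = distill_loss(student_logits, teacher_logits) \
         + lambda_flops * flops_loss(flops_sparse, flops_dense, r_target)

    optimizer.zero_grad()
    loss.backward()    # gradients reach the weights and the learned thresholds
    optimizer.step()
    return loss.item()
```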

Further Details

More details about our methodology, including normalization and the special treatment of FFN activations, can be found in our paper.

Experiments

Setup

  1. meta-llama/Llama-3.2-1B-Instruct was used as both the base model and the teacher model.

  2. The corpus over which we performed distillation was a compilation of chat and question answering datasets, which included teknium/OpenHermes-2.5, facebook/natural_reasoning, HuggingFaceTB/smoltalk, microsoft/orca-agentinstruct-1M-v1, lmsys/lmsys-chat-1m, cais/mmlu/auxiliary_train, allenai/ai2_arc Challenge and Easy Training Splits, and allenai/winogrande XL Training Split. The full corpus can be found here.

  3. Each of the QKV, O, and FFN_GATE matrices was partitioned into stripes of size 1024. The LM_HEAD matrix was partitioned into stripes of size 8192.

  4. Training was conducted over 1.3 billion tokens. This took 18 hours on a cluster of 8 H100 GPUs (for a total of 144 GPU hours). This equates to only 0.04% of the 370K H100 hours originally used to train Llama-3.2-1B.

Benchmarks

CWIC vs TEAL - Evals Across FLOP Reductions

| FLOP Reduction | Method | Average | MMLU | ARC-C | ARC-E | Hellaswag | WinoGrande | BoolQ | SciQ |
|---|---|---|---|---|---|---|---|---|---|
| -- | Random | 25.0 | 25.0 | 25.0 | 25.0 | 25.0 | 50.0 | 50.0 | 25.0 |
| 1x | Base Model | 57.1 | 48.3 | 54.8 | 72.4 | 41.1 | 52.5 | 53.3 | 77.6 |
| 2x | TEAL / CWIC | 50.7 / 54.5 | 41.3 / 43.6 | 44.7 / 47.5 | 60.7 / 66.8 | 35.0 / 38.8 | 50.0 / 51.9 | 54.8 / 60.1 | 68.3 / 72.6 |
| 3x | TEAL / CWIC | 36.4 / 51.5 | 28.3 / 40.3 | 27.6 / 44.1 | 31.9 / 59.9 | 27.0 / 34.1 | 49.6 / 49.5 | 54.7 / 60.1 | 35.6 / 72.3 |
| 4x | TEAL / CWIC | 32.9 / 48.4 | 25.6 / 37.9 | 24.0 / 38.6 | 20.9 / 57.5 | 25.1 / 33.4 | 48.9 / 50.3 | 60.2 / 55.3 | 25.8 / 65.5 |
| 5x | TEAL / CWIC | 33.5 / 46.4 | 23.8 / 35.7 | 23.2 / 37.3 | 24.4 / 55.6 | 25.0 / 30.7 | 49.6 / 50.3 | 62.6 / 54.8 | 25.8 / 60.7 |
| 6x | TEAL / CWIC | 32.9 / 44.0 | 23.2 / 33.8 | 23.0 / 32.1 | 24.0 / 49.8 | 25.1 / 29.4 | 49.7 / 53.0 | 62.2 / 53.6 | 22.9 / 56.4 |

We measure the performance of models at different FLOP reduction levels on MMLU, ARC, HellaSwag, WinoGrande, BoolQ, and SciQ. As expected, we observe a decrease in performance across most benchmarks as compute decreases. Performance on WinoGrande and BoolQ does not consistently decline because the base model's performance on them is already close to random chance. Comparing CWIC to the existing TEAL, we note the following observations:

  1. CWIC beats TEAL in average performance across all compute levels.

  2. The gap is largest at a 3x FLOP reduction, where CWIC achieves a 15 point average improvement over TEAL.

  3. CWIC at 3x FLOP reduction shows better average performance than TEAL at 2x FLOP reduction.

  4. CWIC at 6x FLOP reduction shows better average performance than TEAL at 3x FLOP reduction.

Real-World Acceleration

The graphic above shows the wall-clock speedup of CWIC and TEAL at different FLOP reduction levels. Consistent with TEAL's reporting, we tested single batch vector-matrix multiplications with a matrix of size 4096x4096. We used stripes of size 1024, corresponding to 4 stripes per column in this case.

On CPU, CWIC achieves near-ideal speedups.

On GPU, CWIC achieves similar speedups to TEAL. This means that, to get the best quality at a given GPU throughput, CWIC is always preferred over TEAL.

Analysis

Token-Level Compute Allocation

This isn't just about speed; it's about targeted computation. By letting the model decide where its FLOPs are most valuable, we get systems that not only run faster but also exhibit interpretable allocation of effort.

[Figure: a chat completion shaded by per-token compute (less compute → more compute). The user asks two questions: write out the alphabet, and find the slope of the line through (5, 8) and (9, 9); the assistant answers both.]

The graphic above shows a chat completion from our sparse model with a 2x FLOP reduction. Boldness corresponds to the amount of compute allocated for a given token. We see that the model allocates very little compute to formatting tokens like <|start_header_id|>user<|end_header_id|>, since they are effectively memorized.

Other effects are also visible, such as quoted text and repeated information receiving significantly less compute.

Sequence-Level Compute Allocation


We discovered a sequence-level trend in our model's FLOP allocation: less compute is used on problems that humans deem easier. This was found by comparing the number of FLOPs that our model uses on ARC-Easy and ARC-Challenge. Our model uses fewer FLOPs per token on ARC-Easy than on ARC-Challenge.

The finding is notable because we did not explicitly train the model to do this. Instead, it is an emergent behavior that arises from giving the model flexibility over where it allocates FLOPs. This aligns with our original insight that easier problems should use less compute.

Parameter-Level Compute Allocation

The graph above shows the average sparsity of different parameters across layers in a model with a 6x overall FLOP reduction. We note the following:

  1. The most dense parameters (in order) are the V, K, and O attention matrices. This is consistent with our preliminary experiments showing that those parameters were the most sensitive to sparsity. We hypothesize that the K and V matrices require the most density because they must compress information to be recalled by future tokens, and it is hard to determine which information will be important in advance.

  2. The least dense parameters are the FFN_UP and FFN_DOWN matrices. Llama-3.2 uses a GLU FFN, which passes the output of the FFN_GATE matrix through a silu activation function before multiplying it with the output of the FFN_UP matrix, with the result then passed through the FFN_DOWN matrix. Previous work has shown that intermediate FFN activations are naturally sparse, so it makes sense that the FFN_DOWN exhibits high sparsity. The FFN_UP matrix is handled specially by CWIC in a way that leverages the sparsity of the intermediate activations (more information can be found in our paper), so it is also expected to be especially sparse.

  3. Later layers are the most sparse. Density peaks in the middle layers of the model, then decreases for later layers. This is consistent with previous work showing that later layers are less impactful to the performance of deep transformers.

Attention Head Sparsity

[Figure: stripe activation frequency in the O attention matrices. X-axis: channels grouped by head; y-axis: stripes grouped by layer.]

The diagram above shows the average frequency of stripe activations in the O attention matrices of a sparse model with a 6x average FLOP reduction. Lighter color corresponds to more frequent activation. The x-axis represents different channels going into the O matrix, segmented into attention heads. The y-axis represents each layer of the model, with vertical segments within each square representing different stripes corresponding to the same channel (2 stripes per channel in this case). This lets us make a few observations:

  1. The model learns to prune attention heads. We see that some heads are almost entirely dark, meaning that the model almost never uses them. This mirrors the hand-crafted attention head pruning strategies that researchers have developed to improve LLM efficiency.

  2. Some attention heads have low output rank. We see that some heads are dominated by channels that almost always activate, while the other channels are rarely used. These heads appear mostly black with a few white vertical stripes. This shows that the output of such an attention head can be compressed into only a few dimensions. These patterns also imply that our method is effective at aligning the model's computational bases with its channels, which is a key property for effective sparsity but not naturally present in transformers (which theoretically operate with unprivileged bases).

Adaptive Computation

Beyond the Numbers

We see this as a stepping stone toward fully adaptive LLMs that tailor their computation to the difficulty of the task in real time - much like humans do when skimming an easy text versus solving a tough math problem.

To recap:

Learned parameters perform better than heuristically chosen ones. The often-quoted "bitter lesson" states that general learning methods have historically outperformed hand-crafted approaches. With CWIC we can learn computation budgets directly through backpropagation.

The last few years of AI advancement have been driven by scaling quality with ever-increasing training compute.

Now we can scale speed with training compute.

If you're currently using TEAL or other activation sparsity methods, Granular Sparsity with Learned Thresholds can offer better performance, more flexibility, and FLOP reductions well past 2×, paving the way for CPU-ready efficiency.

We’ll be releasing code to train CWIC at https://github.com/crystal-ai-org/cwic as well as example pretrained models soon.


If you are interested in fully adaptive LLMs, stay tuned... Our next blog post will be on latent space reasoning models.

And if you are in the Bay Area...

Come play with us in Palo Alto! →
