Poda de modelos LLM: eliminar sin amputar
TL;DR
Un modelo de 7B parámetros tiene decenas de miles de millones de conexiones neuronales. Muchas de ellas contribuyen tan poco que podrías eliminarlas sin que ningún benchmark razonable lo notase. Eso es la poda (pruning): identificar los pesos irrelevantes y suprimirlos para obtener un modelo más pequeño, más rápido o que consuma menos memoria. Las técnicas modernas (SparseGPT, Wanda, 2:4 structured sparsity) hacen esto sin reentrenamiento, en pocas horas de GPU, y con menos de 1 punto de perplexity de penalización. No reemplaza a la cuantización, se combina con ella.
La analogía
Un árbol de roble con cien ramas. Cuando llega el invierno, el árbol poda sus ramas débiles: redirige los recursos hacia los troncos principales. Un podador experto no corta al azar, observa cuáles ramas tienen poco follaje, cuáles están secas, cuáles crecen en dirección equivocada, y corta sólo esas.
Un modelo de lenguaje es ese árbol. Sus “ramas” son los pesos que conectan neuronas. Después del entrenamiento, muchas de esas conexiones son vestigios del proceso de optimización: existían para que el gradiente descendiera con suavidad, pero en producción apenas modifican la salida. El podador que las elimina con precisión es SparseGPT o Wanda. El que corta al azar es magnitude pruning sin calibración. Ambos dan un árbol más pequeño; sólo el experto da uno que sigue produciendo el mismo fruto.
Qué es la poda realmente
Un modelo de lenguaje transformer almacena su conocimiento en matrices de pesos. Una capa de atención tiene cuatro matrices: $W_Q, W_K, W_V, W_O$. Una capa FFN tiene al menos dos ($W_{up}, W_{down}$, más $W_{gate}$ en SwiGLU). Para un modelo de 7B con 32 capas, el número de parámetros individuales supera los 7.000 millones.
Poda es el proceso de fijar a cero un subconjunto de esos parámetros de forma que:
- El modelo resultante ocupe menos memoria (si se almacena en formato disperso) o compute menos operaciones.
- La calidad de las respuestas no caiga de forma apreciable.
Hay dos dimensiones de clasificación que importan:
Granularidad: qué unidad se elimina.
- Poda no estructurada: pesos individuales, dispersos por toda la matriz. Alta compresión, difícil de acelerar en hardware convencional.
- Poda estructurada: cabezas de atención completas, neuronas FFN enteras, o capas completas. Menor compresión, pero el modelo resultante es denso y compatible con cualquier hardware.
- Semi-estructurada N:M: para cada grupo de M pesos consecutivos, exactamente N son cero. El caso 2:4 (2 zeros de cada 4) es el que soportan los Tensor Cores de NVIDIA Ampere y posteriores.
Momento: cuándo se elimina.
- Post-entrenamiento (PTQ de pesos): no requiere gradient, es el estándar en LLMs grandes.
- Durante entrenamiento (gradual/iterativa): más precisa, incompatible con modelos de 70B+ por coste.
Por qué existen tantos pesos redundantes
La respuesta está en cómo se entrenan los modelos. El descenso de gradiente estocástico con millones de pasos y learning rate decreciente produce redes sobre-parametrizadas por diseño: los parámetros extra no representan conocimiento adicional, sino margen de maniobra para que la optimización converja más fácilmente.
La Hipótesis del Ticket de Lotería (Frankle & Carlin, ICLR 2019) formalizó esta intuición: dentro de cualquier red densa entrenada existe una subred que, entrenada desde cero en aislamiento, alcanza la misma calidad. La red original es esa subred envuelta en ruido paramétrico generado por el proceso de entrenamiento.
Para LLMs, la evidencia empírica es consistente: modelos de 7B–70B toleran hasta el 50% de sparsidad no estructurada sin degradación observable en tareas conversacionales. En modelos más grandes, el umbral de tolerancia aumenta.
Las matemáticas que importan
¿Qué pesos son seguros eliminar?
Magnitude pruning: el criterio ingenuo
$$\text{importance}(w_{ij}) = |w_{ij}|$$
Se eliminan los pesos con menor valor absoluto. Intuitivo, pero incompleto: un peso pequeño conectado a una activación muy grande sigue contribuyendo significativamente a la salida.
Wanda: magnitud × activación
$$\text{importance}(w_{ij}) = |w_{ij}| \cdot |x_j|_2$$
Donde $x_j$ es el vector de activación de entrada correspondiente al peso $j$, calculado sobre un dataset de calibración de ~128 samples. El producto captura ambas dimensiones: un peso es seguro eliminar sólo si él es pequeño y su neurona de entrada está poco activa.
Ejemplo numérico:
- Peso A: $|w| = 0.001$, $|x|_2 = 500$ → importancia = 0.5
- Peso B: $|w| = 0.01$, $|x|_2 = 10$ → importancia = 0.1
Magnitude pruning eliminaría A (valor absoluto menor). Wanda elimina B (importancia menor). B es más seguro suprimir.
Wanda no requiere gradientes ni inversas de matriz hessiana. Corre en minutos sobre un modelo de 70B en una sola GPU. En benchmarks de perplexity WikiText-2 con 50% de sparsidad no estructurada, Wanda alcanza resultados comparables a SparseGPT con 10–100× menos coste computacional.
SparseGPT: compensación hessiana
SparseGPT aplica el mismo marco matemático que GPTQ (cuantización capa a capa), pero para poda. Cuando elimina un peso $w_p$, calcula una corrección $\delta w$ sobre los pesos restantes de la misma fila para minimizar el cambio en la salida de la capa:
$$\min_{\delta w} |W x - (W + \delta W) x|_2^2 \quad \text{s.t.} \quad w_p + \delta w_p = 0$$
La solución usa la inversa de la matriz Hessiana de segundo orden $H = X X^T$. El coste extra justifica la mayor precisión cuando la sparsidad objetivo es alta (>70%) o el modelo es pequeño (<7B, donde la redundancia es menor).
| Método | Criterio | Coste | Sparsidad 50% (7B, ppl WikiText-2) |
|---|---|---|---|
| Magnitude | |w| | Instantáneo | +2–5 puntos |
| Wanda | |w| · |x| | Minutos | ~+0.5 puntos |
| SparseGPT | Hessiana | 1–4h GPU | ~+0.4 puntos |
2:4 Structured Sparsity: el caso especial de NVIDIA
NVIDIA Ampere (A100) y posteriores (H100, Ada Lovelace) incluyen hardware dedicado para el patrón 2:4: exactamente 2 de cada 4 pesos consecutivos son cero. Esto permite al hardware omitir las multiplicaciones por cero de forma eficiente, obteniendo hasta 2× speedup en matmul sobre modelos con pesos 2:4.
La restricción es que la sparsidad tiene que ser exactamente 2:4, no un patrón arbitrario. Las herramientas NVIDIA (APEX Sparse, cuSPARSELt) y frameworks como PyTorch 2.x soportan esto nativamente:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
# Convertir pesos densos a 2:4 sparse
sparse_weight = to_sparse_semi_structured(dense_weight)
# Forward pass automáticamente usa sparse tensor cores
output = F.linear(input, sparse_weight)
Qué esperar en la práctica con 2:4:
- RTX 4090 (Ada Lovelace): soporta 2:4 sparse tensor cores para FP16/BF16. Speedup teórico 2×, real 1.3–1.6× dependiendo del tamaño de batch y secuencia.
- H100 (Hopper): ídem con mejoras adicionales en FP8 + 2:4 combinados.
- A100: soportado, sin FP8.
- GPUs consumer anteriores a Ada (3090, etc.): sin soporte de hardware. 2:4 sparsity da un modelo más pequeño en disco pero no acelera la inferencia.
Poda estructurada: eliminar cabezas y capas enteras
Poda de cabezas de atención
Un transformer de 32 capas con 32 cabezas por capa tiene 1.024 cabezas de atención. Estudios sistemáticos en modelos Llama-2 y Qwen muestran que entre el 20–40% de las cabezas tienen una influencia marginal en la salida final: su salida puede fijarse a cero sin que el benchmark cambie dentro del margen de error.
La métrica más usada es la Taylor importance: el producto del gradiente de la pérdida respecto a la salida de la cabeza por el valor de esa salida, sumado sobre un dataset de calibración:
$$\text{I}_{head} = \left| \sum_t \frac{\partial \mathcal{L}}{\partial o_t} \cdot o_t \right|$$
Las cabezas con $I_{head}$ más bajo se eliminan primero. Después de eliminar el 25% de cabezas en Llama-3-8B, la degradación en MMLU es <1% y el tiempo de inferencia de la atención cae ~20% porque los matmuls de atención son más pequeños.
Layer dropping: el atajo más agresivo
Eliminar una capa transformer completa suprime su bloque de atención y su FFN. El criterio más robusto es la Block Influence (BI), introducida en ShortGPT (2024):
$$\text{BI}(l) = 1 - \cos(\text{input}_l, \text{output}_l)$$
Una capa cuya salida es casi idéntica a su entrada (coseno próximo a 1, BI próximo a 0) actúa como función identidad: eliminarla no cambia el flujo de información. Las capas del centro del transformer suelen tener BI más bajo que las capas iniciales y finales.
Ejemplo numérico en LLaMA-2-70B:
- Capas 0–5 (early): BI > 0.3 → no eliminar
- Capas 20–45 (mid): BI < 0.05 → candidatas a eliminar
- Capas 76–80 (final): BI > 0.2 → no eliminar
Eliminando 8 capas de 80 (10%): el modelo pasa de ~140 GB a ~126 GB en BF16. Speedup de inferencia: ~10% (proporcional al número de capas eliminadas). Degradación en benchmarks de razonamiento: 1–3%.
Implicaciones para inferencia on-premise
La poda no estructurada (50% sparsidad) produce modelos con el mismo número de parámetros pero con la mitad a cero. Sin kernels sparse especializados, eso no da speedup: la GPU sigue ejecutando las multiplicaciones, sólo que multiplica por cero muy eficientemente. El beneficio real es de almacenamiento y transferencia (el modelo ocupa menos en disco y en RAM de sistema).
Con 2:4 structured sparsity sobre hardware Ada/Hopper, el speedup es real pero moderado (1.3–1.7×) y requiere herramientas adicionales (cuSPARSELt o PyTorch sparse).
La poda estructurada (cabezas, capas) sí acelera en cualquier hardware porque reduce el tamaño real del modelo. Es la opción correcta si el objetivo es throughput en hardware sin tensor cores sparse.
Combinación con cuantización: poda + cuantización son ortogonales. Un modelo 50% sparse a INT4 ocupa aproximadamente un octavo del original en FP32. Es el punto de llegada de muchos pipelines de compresión agresiva para edge inference.
Aplicado a hardware on-premise genérico
RTX 4090 (24 GB, Ada Lovelace)
Soporta 2:4 sparse tensor cores para FP16/BF16. Con Wanda + 2:4 sparsity sobre un Qwen2.5-14B:
# Pipeline de poda: Wanda 2:4 + quantización INT4
# 1. Ejecutar Wanda con calibración sobre 128 muestras
python wanda/main.py \
--model Qwen/Qwen2.5-14B \
--sparsity_ratio 0.5 \
--sparsity_type 2:4 \
--save pruned_model/
# 2. Cuantizar el modelo podado (opcional pero complementario)
python -m awq.entry --model_path pruned_model/ \
--w_bit 4 --output_path pruned_awq_model/
Resultado esperado: ~13 GB BF16 → ~6.5 GB tras poda 2:4 en sparse format → ~3.2 GB con AWQ INT4. El modelo 14B cabrá en la RTX 4090 con margen para KV cache.
4× H100 SXM (320 GB total, Hopper)
En este hardware la poda estructurada (layer dropping) tiene más sentido que 2:4 para inferencia de alta concurrencia: reduces el número de operaciones FLOPs por token de forma proporcional, lo que beneficia al throughput bajo batch grande donde el cuello es compute, no memoria.
# Aplicar layer dropping con ShortGPT BI metric
from shortgpt import compute_block_influence, drop_layers
bi_scores = compute_block_influence(model, calibration_data)
# Eliminar el 15% de capas con BI más bajo
model = drop_layers(model, bi_scores, drop_ratio=0.15)
Un Llama-3-70B podado al 15% de capas cabe en 3 H100 en vez de 4, liberando una GPU para otra tarea.
Ver también
- https://blog.lo0.es/posts/quantization-fundamentos-inferencia/ — la palanca complementaria: cuantizar reduce la precisión de los pesos que la poda ha decidido conservar; combinadas dan compresión máxima
- https://blog.lo0.es/posts/kv-cache-fundamentos/ — la poda reduce el tamaño del modelo, pero el KV cache sigue creciendo con el contexto; son costes separados en VRAM
- https://blog.lo0.es/posts/speculative-decoding-fundamentos/ — los drafters de speculative decoding son a menudo versiones podadas del modelo base, no modelos entrenados desde cero
- https://blog.lo0.es/posts/decode-optimizaciones-vllm/ — cómo el modelo podado se sirve en vLLM: los parámetros de throughput cambian con un modelo estructuralmente más pequeño
- https://blog.lo0.es/posts/knowledge-distillation-fundamentos/ — alternativa conceptual a la poda: en vez de eliminar partes del modelo grande, entrenar uno pequeño para que imite su comportamiento
Referencias
- SparseGPT: Massive Language Models Can be Accurately Pruned in One Shot — Frantar & Alistarh, 2023
- A Simple and Effective Pruning Approach for Large Language Models (Wanda) — Sun et al., ICLR 2024
- The Lottery Ticket Hypothesis — Frankle & Carlin, ICLR 2019
- ShortGPT: Layers in Large Language Models are More Redundant Than You Expect — Men et al., 2024
- NVIDIA 2:4 Sparsity in PyTorch — PyTorch Blog
- SparseForge: Efficient Semi-Structured LLM Sparsification — 2025