text-generation-inference/server/text_generation_server/layers/exl2.py
Daniël de Kok e52be9bba2
Add support for Deepseek V2 (#2224)
Deepseek V2 is a MoE model from Deepseek. Relevant variations
compared to other models:

- Grouped top-K in expert selection.
- mscale in yarn is calculated using the `mscale` and `mscale_all_dim`
  configuration options.
- `mscale_all_dim` is also used in scaling attention softmax.
- Permuting of the query/key representations before applying rotary
  embeddings.
- Some projections cannot be sharded (`q_a_proj`, `kv_a_proj_with_mqa`).
  So, we need weight loads that supports quantized weights. To this
  end `{Weights,WeightLoader}.get_weight` was added.
- The query/key head dimensionality differs from that of the value,
  so we need to pad during attention.
- Heads with size 192, needs an extension to our paged attention
  fork and we need to ensure that the KV cache is allocated with the
  correct size.
- Shared experts.
2024-07-19 17:23:20 +02:00

79 lines
2.4 KiB
Python

from dataclasses import dataclass
from typing import List, Union
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class Exl2Weight(Weight):
"""
Exllama2 exl2 quantized weights.
"""
q_weight: torch.Tensor
q_scale: torch.Tensor
q_invperm: torch.Tensor
q_scale_max: torch.Tensor
q_groups: torch.Tensor
def __post_init__(self):
self.q_scale_max /= 256
self.q_invperm = self.q_invperm.short()
@property
def device(self) -> torch.device:
return self.q_weight.device
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.gptq import ExllamaQuantLinear
return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)