feat(server): Using quantize_config.json instead of GPTQ_BITS env variables.

- The current PR is not great because we're side-stepping
  `Weights.__init__`, but `Weights` shouldn't require anything related
  to the config or the model_id, as it aims to be a simple wrapper
  over multi-file loading.
- The ideal solution would be to use something like a Rust enum
  ```
  enum Quantize {
      Bitsandbytes(Bitsandbytes),
      GPTQ { bits: usize, groupsize: usize },
  }
  ```
  and pass that around during load. Unfortunately we don't have
  access to this, so for now, side-stepping seems easier (a rough
  Python equivalent of the idea is sketched below).
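
Since Rust-style enums aren't available here, a rough Python sketch of the same idea, purely illustrative, with hypothetical names (`Quantize`, `load_weights`) that are not part of this PR:

```
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Bitsandbytes:
    """Marker variant: bitsandbytes quantization needs no extra parameters."""


@dataclass
class Gptq:
    """GPTQ variant: parameters that would otherwise come from quantize_config.json."""
    bits: int
    groupsize: int


# None plays the role of the un-quantized case.
Quantize = Optional[Union[Bitsandbytes, Gptq]]


def load_weights(filenames, device, dtype, process_group, quantize: Quantize = None):
    # Hypothetical loader: the quantization variant travels with the call
    # instead of being attached to the Weights object after construction.
    ...
```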
Nicolas Patry 2023-07-21 10:12:28 +00:00
parent 37df6df38e
commit c07ee68b60
12 changed files with 49 additions and 19 deletions

View File

@@ -76,6 +76,8 @@ class BLOOMSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = BloomForCausalLM(config, weights)

View File

@@ -76,7 +76,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_qparams()
bits, groupsize = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@@ -130,17 +130,17 @@ class OPTAttention(nn.Module):
process_group=None,
):
super().__init__()
embed_dim = config.embed_dim
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = embed_dim // num_heads
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
@@ -153,7 +153,7 @@ class OPTAttention(nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()
self.hidden_size = self.hidden_size // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
@@ -300,9 +300,9 @@ class OPTAttention(nn.Module):
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)
attn_output = self.out_proj(attn_output)
@@ -313,7 +313,7 @@ class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.embed_dim = config.hidden_size
self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
@@ -352,7 +352,7 @@ class OPTDecoderLayer(nn.Module):
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size

View File

@@ -55,13 +55,15 @@ class FlashLlama(FlashCausalLM):
config = LlamaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
config.quantize = quantize
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)

View File

@@ -52,6 +52,8 @@ class FlashNeoXSharded(FlashCausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashGPTNeoXForCausalLM(config, weights)

View File

@@ -58,6 +58,9 @@ class FlashRWSharded(FlashCausalLM):
)
config.quantize = quantize
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashRWForCausalLM(config, weights)

View File

@@ -4,7 +4,10 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
@@ -59,6 +62,8 @@ class FlashSantacoderSharded(FlashCausalLM):
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = FlashSantacoderForCausalLM(config, weights)

View File

@@ -191,6 +191,8 @@ class GalacticaSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = OPTForCausalLM(config, weights)

View File

@@ -56,6 +56,8 @@ class GPTNeoxSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = GPTNeoxForCausalLM(config, weights)

View File

@@ -78,6 +78,8 @@ class MPTSharded(CausalLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
config.quantize = quantize
model = MPTForCausalLM(config, weights)

View File

@@ -54,6 +54,8 @@ class OPTSharded(CausalLM):
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
model = OPTForCausalLM(config, weights)

View File

@@ -128,7 +128,7 @@ class Weights:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
bits, groupsize = self._get_gptq_qparams()
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -138,7 +138,7 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_qparams()
bits, groupsize = self._get_gptq_params()
if bits != 4:
use_exllama = False
@@ -208,17 +208,25 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_qparams(self) -> Tuple[int, int]:
def _get_gptq_params(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
bits = self.gptq_bits
groupsize = self.gptq_groupsize
except Exception:
raise e
return bits, groupsize
def _set_gptq_params(self, model_id):
try:
filename = hf_hub_download(model_id, filename="quantize_config.json")
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
except Exception:
pass
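
For context, a minimal standalone sketch of what `_set_gptq_params` does with `quantize_config.json`. The model id below is a placeholder, and the example JSON shape is an assumption based on the two keys (`bits`, `group_size`) read above:

```
import json

from huggingface_hub import hf_hub_download

# Placeholder repo id: any GPTQ checkpoint that ships a quantize_config.json.
MODEL_ID = "some-org/some-gptq-model"

# A quantize_config.json written by common GPTQ tooling typically looks like:
#   {"bits": 4, "group_size": 128, ...}
# Only "bits" and "group_size" are consumed here.
filename = hf_hub_download(MODEL_ID, filename="quantize_config.json")
with open(filename, "r") as f:
    data = json.load(f)

bits = data["bits"]             # e.g. 4
groupsize = data["group_size"]  # e.g. 128
```

The `gptq_bits` / `gptq_groupsize` tensors baked into the safetensors files still take precedence; the attributes set from the JSON are only the fallback path in `_get_gptq_params`.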