Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
feat(server): Using quantize_config.json instead of GPTQ_BITS env variables.

- The current PR is not great because we're side-stepping `Weights.__init__`, but `Weights` shouldn't require anything related to the config or the model_id, as it aims to be a simple wrapper over multi-file loading.
- The ideal solution would be to use something like a Rust enum:

```
enum Quantize {
    Bitsandbytes(Bitsandbytes),
    GPTQ { bits: usize, groupsize: usize },
}
```

and pass that around during load. Unfortunately we don't have access to this, so for now, side-stepping seems easier.
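As a rough illustration of that idea, here is a hypothetical Python sketch of the same tagged-union approach on the server side; the names (`QuantizeMode`, `load_weights`) are made up for this example and are not part of the PR.

```python
# Hypothetical sketch of the "Rust enum" idea in Python: the quantization
# parameters travel with a value passed into the loader, so Weights never
# needs the config or the model_id. All names here are illustrative only.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Bitsandbytes:
    pass


@dataclass
class GPTQ:
    bits: int
    groupsize: int


QuantizeMode = Union[Bitsandbytes, GPTQ]


def load_weights(filenames: List[str], quantize: QuantizeMode):
    # The GPTQ parameters are available directly on the value; no Hub lookup
    # or environment variables are needed at this point.
    if isinstance(quantize, GPTQ):
        print(f"GPTQ weights: bits={quantize.bits}, groupsize={quantize.groupsize}")
    else:
        print("bitsandbytes weights")
```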
parent 37df6df38e
commit c07ee68b60
@@ -76,6 +76,8 @@ class BLOOMSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         model = BloomForCausalLM(config, weights)
@@ -76,7 +76,7 @@ def _load_multi_mqa_gptq(

         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         g_idx = g_idx.to(device=weights.device)
-        bits, groupsize = weights._get_gptq_qparams()
+        bits, groupsize = weights._get_gptq_params()

         from text_generation_server.utils.layers import HAS_EXLLAMA
@@ -130,17 +130,17 @@ class OPTAttention(nn.Module):
         process_group=None,
     ):
         super().__init__()
-        embed_dim = config.embed_dim
+        hidden_size = config.hidden_size
         num_heads = config.num_attention_heads

-        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.dropout = config.dropout
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = hidden_size // num_heads

-        if (self.head_dim * num_heads) != self.embed_dim:
+        if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {num_heads})."
             )
         self.scaling = self.head_dim**-0.5
@@ -153,7 +153,7 @@ class OPTAttention(nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // process_group.size()
-        self.embed_dim = self.embed_dim // process_group.size()
+        self.hidden_size = self.hidden_size // process_group.size()

         self.q_proj = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
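A quick worked example of the tensor-parallel sizing above, with illustrative OPT-125m-like dimensions (the concrete numbers are not taken from the diff):

```python
# Worked example of the sharding arithmetic above; values are illustrative.
hidden_size = 768
num_heads = 12
world_size = 2  # number of shards / GPUs

head_dim = hidden_size // num_heads           # 64, independent of sharding
assert head_dim * num_heads == hidden_size    # the divisibility check in __init__

num_heads_per_shard = num_heads // world_size       # 6 heads per GPU
hidden_size_per_shard = hidden_size // world_size   # 384 columns per GPU
print(head_dim, num_heads_per_shard, hidden_size_per_shard)
```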
@@ -300,9 +300,9 @@ class OPTAttention(nn.Module):
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)

-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
         # partitioned aross GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)

         attn_output = self.out_proj(attn_output)
@@ -313,7 +313,7 @@ class OPTDecoderLayer(nn.Module):
     def __init__(self, layer_id: int, config: OPTConfig, weights):
         super().__init__()
         self.process_group = weights.process_group
-        self.embed_dim = config.hidden_size
+        self.hidden_size = config.hidden_size
         prefix = f"model.decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
@@ -352,7 +352,7 @@ class OPTDecoderLayer(nn.Module):
     ]:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
@@ -55,13 +55,15 @@ class FlashLlama(FlashCausalLM):
         config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        config.quantize = quantize

         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

-        config.quantize = quantize
         model = FlashLlamaForCausalLM(config, weights)

         torch.distributed.barrier(group=self.process_group)
@@ -52,6 +52,8 @@ class FlashNeoXSharded(FlashCausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         model = FlashGPTNeoXForCausalLM(config, weights)
@@ -58,6 +58,9 @@ class FlashRWSharded(FlashCausalLM):
         )

         config.quantize = quantize
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)
+

         model = FlashRWForCausalLM(config, weights)
@@ -4,7 +4,10 @@ import torch.distributed
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, List
+import json
+import os

+from huggingface_hub import hf_hub_download
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
@@ -59,6 +62,8 @@ class FlashSantacoderSharded(FlashCausalLM):
             process_group=self.process_group,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
+        if config.quantize == "gptq":
+            weights.set_gptq_params(model_id)

         model = FlashSantacoderForCausalLM(config, weights)
@@ -191,6 +191,8 @@ class GalacticaSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         model = OPTForCausalLM(config, weights)
@@ -56,6 +56,8 @@ class GPTNeoxSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         model = GPTNeoxForCausalLM(config, weights)
@@ -78,6 +78,8 @@ class MPTSharded(CausalLM):

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         config.quantize = quantize
         model = MPTForCausalLM(config, weights)
@@ -54,6 +54,8 @@ class OPTSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
+        if config.quantize == "gptq":
+            weights._set_gptq_params(model_id)

         model = OPTForCausalLM(config, weights)
@@ -128,7 +128,7 @@ class Weights:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]

-            bits, groupsize = self._get_gptq_qparams()
+            bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -138,7 +138,7 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_qparams()
+            bits, groupsize = self._get_gptq_params()

             if bits != 4:
                 use_exllama = False
@@ -208,17 +208,25 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

-    def _get_gptq_qparams(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
         except (SafetensorError, RuntimeError) as e:
             try:
-                import os
-
-                bits = int(os.getenv("GPTQ_BITS"))
-                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                bits = self.gptq_bits
+                groupsize = self.gptq_groupsize
             except Exception:
                 raise e

         return bits, groupsize
+
+    def _set_gptq_params(self, model_id):
+        try:
+            filename = hf_hub_download(model_id, filename="quantize_config.json")
+            with open(filename, "r") as f:
+                data = json.load(f)
+            self.gptq_bits = data["bits"]
+            self.gptq_groupsize = data["group_size"]
+        except Exception:
+            pass
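For context, a minimal sketch of how the two new helpers fit together; the quantize_config.json contents shown are illustrative (only `bits` and `group_size` are read by the code above), and the import path is assumed:

```python
# Sketch of the lookup order implemented by the new helpers (illustrative).
# A GPTQ checkpoint is assumed to ship a quantize_config.json on the Hub, e.g.
#   {"bits": 4, "group_size": 128}
from text_generation_server.utils.weights import Weights  # path assumed


def resolve_gptq_params(weights: Weights, model_id: str):
    # Cache bits/group_size from quantize_config.json (silently skipped if the
    # file is missing, matching the except/pass above).
    weights._set_gptq_params(model_id)
    # Prefer gptq_bits/gptq_groupsize tensors baked into the safetensors files,
    # then fall back to the attributes cached by _set_gptq_params.
    return weights._get_gptq_params()
```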