mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Add support for GPTQ Marlin kernels. GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations: - bits: 4 or 8 - groupsize: -1, 32, 64, or 128 - desc_act: true/false Using the GPTQ Marlin kernels requires repacking the parameters into the Marlin quantizer format. The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
95 lines
3.1 KiB
Python
95 lines
3.1 KiB
Python
import torch
|
|
import torch.distributed
|
|
|
|
from opentelemetry import trace
|
|
from transformers import AutoTokenizer, AutoConfig
|
|
from typing import Optional, List
|
|
import json
|
|
import os
|
|
|
|
from huggingface_hub import hf_hub_download
|
|
from text_generation_server.models import FlashCausalLM
|
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
|
FlashSantacoderForCausalLM,
|
|
)
|
|
from text_generation_server.utils import (
|
|
initialize_torch_distributed,
|
|
weight_files,
|
|
Weights,
|
|
)
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
# Module-level OpenTelemetry tracer used to instrument spans in this module.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
class FlashSantacoderSharded(FlashCausalLM):
    """Sharded (tensor-parallel) Flash-Attention SantaCoder model.

    Loads a SantaCoder checkpoint across the ranks of an initialized
    torch.distributed process group and wraps it in the FlashCausalLM
    runtime. Only CUDA and XPU devices are supported.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Initialize the sharded model.

        Args:
            model_id: Hugging Face model id or local path of the checkpoint.
            revision: Optional model revision (branch, tag, or commit hash).
            quantize: Optional quantization scheme (e.g. "gptq", "marlin").
            speculator: Optional speculative-decoding configuration.
            dtype: Parameter dtype; defaults to float16 on supported devices.
            trust_remote_code: Whether the tokenizer may run remote code.

        Raises:
            NotImplementedError: If neither CUDA nor XPU is available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
        elif SYSTEM == "xpu":
            device = torch.device(f"xpu:{rank}")
        else:
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
        # Both supported devices default to float16 when no dtype is given.
        dtype = torch.float16 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        # NOTE(review): trust_remote_code is hard-coded to True here while the
        # tokenizer honors the constructor argument — confirm this asymmetry
        # is intentional before changing it.
        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=True,
        )
        config.quantize = quantize
        config.speculator = speculator
        # GPT2-style checkpoints store transposed weight matrices.
        config.transpose = config.architectures[0].startswith("GPT2")

        # Synchronize ranks before and after weight loading so every shard
        # sees a consistent view of the checkpoint files.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            # The embedding matrix doubles as the LM head in this checkpoint.
            aliases={"transformer.wte.weight": ["lm_head.weight"]},
        )
        if config.quantize in ["gptq", "marlin"]:
            # Pull GPTQ metadata (bits, groupsize, ...) from the checkpoint.
            weights._set_gptq_params(model_id, revision)

        model = FlashSantacoderForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            num_layers=len(model.transformer.h),
            # SantaCoder uses multi-query attention: a single KV head.
            num_kv_heads=1,
            head_size=model.transformer.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, keeping special tokens.

        Special tokens are preserved because downstream parsing rules of the
        generated text rely on them.
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
|