text-generation-inference/server/text_generation_server/models/flash_starcoder2.py
Daniël de Kok 093a27c528
Add support for GPTQ Marlin (#2052)
Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the
Marlin quantizer format.

The kernels were contributed by Neural Magic to VLLM. We vendor them
here for convenience.
2024-06-14 09:45:42 +02:00

84 lines
2.6 KiB
Python

import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    """Starcoder2 model wrapper that subclasses ``BaseFlashMistral``.

    Construction loads the tokenizer, the configuration and the
    ``.safetensors`` weights for ``model_id``, then instantiates
    ``FlashStarcoder2ForCausalLM`` on the CUDA device whose index is this
    process' distributed rank.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load a Starcoder2 checkpoint and prepare it for GPU use.

        Args:
            model_id: Model identifier handed to the tokenizer, config and
                weight-file loaders.
            revision: Optional revision forwarded to those same loaders.
            quantize: Quantization scheme name stored on the config. For
                ``"gptq"``, ``"awq"`` and ``"marlin"`` the GPTQ parameters
                are additionally set on the weights.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Dtype used for the weights; ``torch.float16`` when omitted.
            trust_remote_code: Forwarded to the tokenizer and config
                ``from_pretrained`` calls.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Guard clause: this model has no non-CUDA code path.
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashStarcoder2 is only available on GPU")

        device = torch.device(f"cuda:{rank}")
        if dtype is None:
            dtype = torch.float16

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        # The modeling code reads these two settings from the config object.
        config.quantize = quantize
        config.speculator = speculator

        # Register the sliding-window size only when the config defines one.
        window = config.sliding_window
        if window is not None:
            set_sliding_window(window)

        # Wait for every rank in the group before touching the weight files.
        torch.distributed.barrier(group=self.process_group)

        safetensors_paths = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            safetensors_paths,
            device,
            dtype,
            process_group=self.process_group,
        )
        needs_gptq_params = config.quantize in ("gptq", "awq", "marlin")
        if needs_gptq_params:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        # Initialised empty here; nothing is captured at construction time.
        self.cuda_graphs = {}

        # Second rendezvous once the model has been built on every rank.
        torch.distributed.barrier(group=self.process_group)

        # Two-argument super() starts the lookup *after* BaseFlashMistral in
        # the MRO, so BaseFlashMistral.__init__ is deliberately not run and
        # the next initializer in line receives the already-built model.
        backbone = model.model
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )