mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Add support for GPTQ Marlin kernels. GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations:
- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the Marlin quantizer format. The kernels were contributed by Neural Magic to vLLM. We vendor them here for convenience.
119 lines
3.5 KiB
Python
119 lines
3.5 KiB
Python
import torch
|
|
import torch.distributed
|
|
|
|
from typing import Optional, Type
|
|
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoConfig,
|
|
PreTrainedTokenizerBase,
|
|
)
|
|
|
|
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
|
BloomForCausalLM,
|
|
)
|
|
from text_generation_server.models import CausalLM
|
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
|
from text_generation_server.pb import generate_pb2
|
|
from text_generation_server.utils import (
|
|
initialize_torch_distributed,
|
|
weight_files,
|
|
Weights,
|
|
)
|
|
|
|
|
|
class BloomCausalLMBatch(CausalLMBatch):
    """Causal-LM batch used by the BLOOM model.

    Behaves exactly like its parent batch class, except that every batch built
    through :meth:`from_pb` has ``keys_head_dim_last`` set to ``False``.
    """

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build a batch from its protobuf description.

        All construction is delegated to the parent class's ``from_pb``; the
        only BLOOM-specific step is flipping ``keys_head_dim_last`` off on the
        result.

        Args:
            pb: Protobuf batch message to build from.
            tokenizer: Tokenizer forwarded unchanged to the parent builder.
            dtype: Tensor dtype forwarded unchanged to the parent builder.
            device: Target device forwarded unchanged to the parent builder.

        Returns:
            The batch produced by the parent ``from_pb``, with
            ``keys_head_dim_last`` set to ``False``.
        """
        new_batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        # NOTE(review): presumably BLOOM's cached keys do not keep the head
        # dimension last, unlike the generic layout — confirm against
        # bloom_modeling / CausalLMBatch.
        new_batch.keys_head_dim_last = False
        return new_batch
|
|
|
|
|
|
class BLOOMSharded(CausalLM):
    """BLOOM served through the generic causal-LM code path.

    Loads the tokenizer, config and safetensors weights for ``model_id``.
    Builds a :class:`BloomForCausalLM` from ``Weights`` bound to the process
    group returned by ``initialize_torch_distributed``. Produces
    :class:`BloomCausalLMBatch` batches.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the tokenizer, config and weights, then build the model.

        Args:
            model_id: Model identifier passed to the tokenizer, config and
                weight-file loaders.
            revision: Optional revision passed to the same loaders.
            quantize: Quantization method name stored on the config. For
                ``"gptq"`` and ``"marlin"`` the GPTQ parameters are also
                loaded onto the ``Weights`` object.
            speculator: Speculator identifier stored on the config.
            dtype: Weight dtype. Defaults to ``torch.float16`` when CUDA is
                available and ``torch.float32`` otherwise.
            trust_remote_code: Forwarded to the tokenizer and config loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # One GPU per rank when CUDA is present, otherwise CPU. The default
        # dtype follows the chosen device.
        on_gpu = torch.cuda.is_available()
        if on_gpu:
            device = torch.device(f"cuda:{rank}")
        else:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float16 if on_gpu else torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            slow_but_exact=False,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        # NOTE(review): hard-coded pad id 3 — presumably BLOOM's pad token;
        # confirm against the tokenizer vocabulary.
        config.pad_token_id = 3
        config.quantize = quantize
        config.speculator = speculator

        # Synchronize ranks before any of them starts reading weight files.
        torch.distributed.barrier(group=self.process_group)
        weights = Weights(
            weight_files(model_id, revision=revision, extension=".safetensors"),
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            prefix="transformer",
        )
        if config.quantize in ("gptq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        model = BloomForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        # ``super(CausalLM, self)`` deliberately skips ``CausalLM.__init__``
        # and runs the next initializer in the MRO, handing it the model
        # built above.
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class this model consumes: :class:`BloomCausalLMBatch`."""
        return BloomCausalLMBatch

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the model once, with the KV cache enabled.

        Returns:
            A tuple ``(logits, speculative_logits, past_key_values)``.
            ``logits`` and ``past_key_values`` are read from the model's output
            object. ``speculative_logits`` is the second value returned by the
            model's ``forward``.
        """
        model_out, spec_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return model_out.logits, spec_logits, model_out.past_key_values
|