Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00.
Mostly straightforward changes to existing code: * Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict. * Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that would OOM.
87 lines
2.9 KiB
Python
import torch
|
|
import torch.distributed
|
|
|
|
from opentelemetry import trace
|
|
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
|
from typing import Optional
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
|
FlashLlamaForCausalLM,
|
|
)
|
|
from text_generation_server.utils import (
|
|
initialize_torch_distributed,
|
|
weight_files,
|
|
Weights,
|
|
)
|
|
|
|
# Module-level OpenTelemetry tracer, keyed on this module's name so spans
# emitted here are attributed to it.
tracer = trace.get_tracer(__name__)
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
|
|
class FlashLlama(FlashCausalLM):
    """Llama causal language model served with flash attention.

    Loads the tokenizer, model config and ``.safetensors`` weights for
    ``model_id``, builds a ``FlashLlamaForCausalLM`` sharded over the
    ``torch.distributed`` process group, and hands everything to the
    ``FlashCausalLM`` base class.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Initialize the sharded model.

        Args:
            model_id: HF hub id or local path of the checkpoint.
            revision: Optional hub revision (branch / tag / commit).
            quantize: Optional quantization scheme name (e.g. "gptq").
            speculator: Optional speculative-decoding model id, stored on
                the config for downstream use.
            dtype: Weight dtype; defaults to float16 on supported devices.
            trust_remote_code: Forwarded to the transformers loaders.

        Raises:
            NotImplementedError: when neither CUDA nor XPU is available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Flash-attention kernels require a CUDA or XPU device; there is no
        # CPU fallback for this model class.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
        elif SYSTEM == "xpu":
            device = torch.device(f"xpu:{rank}")
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")
        # Both supported backends share the same float16 default, so the
        # default is resolved once instead of per-branch.
        if dtype is None:
            dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack: some checkpoints declare several EOS token
                # ids; stash the full set on the tokenizer so generation can
                # stop on any of them.
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            # Deliberate best effort: a missing or unparsable
            # generation_config.json must not prevent the model from loading.
            pass

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Keep ranks in lockstep before touching the weight files.
        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq", "exl2"]:
            # These quantization schemes need extra metadata fetched from
            # the checkpoint repo (NOTE(review): private helper; verify it
            # stays in sync with the Weights implementation).
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = FlashLlamaForCausalLM(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
|