mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
* feat: first draft load multiple lora * feat: load weights within layer and refactor lora pass * fix: refactor and reduce lora math * feat: baseline impl single request multi lora support * feat: prefer lorax implementation and port loading logic * fix: prefer adapter_data and refactors * feat: prefer lorax's custom punica kernels and add mlp loras * fix: adjust batch for bgmv * fix: adjust adapter_segments logic when in batch * fix: refactor and move changes to v3 proto * fix: pass model_id for all flash causal lms * fix: pass model_id for all causal and seq2seq lms * fix: add model_id to model test * feat: add lora support to mistral and refactors * feat: prefer model id in request * fix: include rust code for adapter id * feat: bump launcher and add new lora docs * feat: support base model generation and refactors * fix: rename doc to retry ci build * feat: support if vlm models * fix: add adapter_data param and avoid missing layers * fix: add adapter_data param to phi and neox * fix: update all models forwards to include adapter_data * fix: add model_id to IdeficsCausalLM * Update lora.md Fixed a typo * Update lora.md Fixing spam image * fix: add lora kernel to dockerfile, support running without kernels and refactors * fix: avoid dockerfile conflict * fix: refactors and adjust flash llama lora logic * fix: skip llama test due to CI issue (temp) * fix: skip llama test CI (temp) 2 * fix: revert skips and prefer updated ci token for tests * fix: refactors and helpful comments * fix: add noop in TensorParallelAdapterRowLinear too * fix: refactor and move shard_lora_weights logic * fix: exit early if no adapter_data --------- Co-authored-by: Derek <datavistics@gmail.com>
92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
import torch
|
|
import torch.distributed
|
|
|
|
from opentelemetry import trace
|
|
from transformers import AutoTokenizer
|
|
from typing import Optional
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
|
RWConfig,
|
|
FlashRWForCausalLM,
|
|
)
|
|
from text_generation_server.utils import (
|
|
initialize_torch_distributed,
|
|
weight_files,
|
|
Weights,
|
|
)
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
# Module-level OpenTelemetry tracer, named after this module (`__name__`).
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
class FlashRWSharded(FlashCausalLM):
    """Flash-attention causal LM for checkpoints described by ``RWConfig``.

    Builds a ``FlashRWForCausalLM`` from the checkpoint's safetensors weights.
    The weights are loaded with the process group returned by
    ``initialize_torch_distributed()``. The model is then handed to
    ``FlashCausalLM`` together with the tokenizer and the layer/head geometry
    read from ``model.transformer``.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights for ``model_id`` and build the model.

        Args:
            model_id: Model identifier (hub id or local path) passed to the
                tokenizer/config/weight loaders and forwarded to
                ``FlashCausalLM``.
            revision: Optional revision of ``model_id`` to load.
            quantize: Quantization method stored as ``config.quantize``;
                ``"gptq"`` and ``"marlin"`` additionally call
                ``weights._set_gptq_params``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Parameter dtype. When ``None``, a per-device default is
                chosen by ``_select_device_and_dtype``.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``
                and ``RWConfig.from_pretrained``.

        Raises:
            NotImplementedError: If CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        device, dtype = self._select_device_and_dtype(rank, dtype)

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = RWConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        # Keep all ranks in step before any of them starts reading weight files.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            # Each tensor name lists the other as an alternative. Presumably
            # this lets a checkpoint that stores only one of the tied
            # embedding / LM-head tensors still load; confirm in ``Weights``.
            aliases={
                "lm_head.weight": ["transformer.word_embeddings.weight"],
                "transformer.word_embeddings.weight": ["lm_head.weight"],
            },
        )

        config.quantize = quantize
        config.speculator = speculator
        if config.quantize in ("gptq", "marlin"):
            # Presumably reads the checkpoint's GPTQ quantization settings
            # into ``weights``; see ``Weights._set_gptq_params``.
            weights._set_gptq_params(model_id, revision)

        model = FlashRWForCausalLM(config, weights)

        # Synchronize again so every rank has built its shard before init.
        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model.to(device),
            tokenizer=tokenizer,
            num_layers=len(model.transformer.h),
            # The transformer's ``cache_size`` is what this class reports as
            # the number of KV heads.
            num_kv_heads=model.transformer.cache_size,
            head_size=model.transformer.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def _select_device_and_dtype(rank: int, dtype: Optional[torch.dtype]):
        """Pick the torch device for ``rank`` and fill in a default dtype.

        Selection order:
          * CUDA available: ``cuda:<rank>``, default dtype ``torch.float16``.
          * ``SYSTEM == "ipex"`` with an available XPU: ``xpu:<rank>``,
            default dtype ``torch.float16``.
          * ``SYSTEM == "ipex"`` without an XPU: ``cpu``, default dtype
            ``torch.bfloat16``.

        An explicitly passed ``dtype`` is always kept as-is.

        Returns:
            A ``(device, dtype)`` pair.

        Raises:
            NotImplementedError: If none of the above backends apply.
        """
        if torch.cuda.is_available():
            return torch.device(f"cuda:{rank}"), (
                torch.float16 if dtype is None else dtype
            )
        if SYSTEM == "ipex":
            # ``torch.xpu`` may not exist in every torch build, hence hasattr.
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return torch.device(f"xpu:{rank}"), (
                    torch.float16 if dtype is None else dtype
                )
            return torch.device("cpu"), (
                torch.bfloat16 if dtype is None else dtype
            )
        raise NotImplementedError("FlashRW is only available on GPU")
|