Mirror of https://github.com/huggingface/text-generation-inference.git
This PR adds basic modeling for phi-2.

run

```bash
text-generation-server \
    serve \
    microsoft/phi-2 \
    --revision 834565c23f9b28b96ccbeabe614dd906b6db551a
```

test

```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq .
```

notes

- Recently (~1 day ago) the Phi weights and model were updated to accommodate adding [GQA/MQA attention to the model](https://github.com/huggingface/transformers/pull/28163). This implementation expects the original model format, so a fixed revision is required at the moment.
- This PR only includes a basic implementation of the model; it can later be extended to support Flash and Sharded versions, as well as to make use of further optimizations.
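For reference, the same smoke test as the curl call above can be sent from Python. This is a minimal sketch using `requests`, assuming the server started above is listening on localhost:3000 and that the response carries text-generation-inference's usual `generated_text` field:

```python
# Minimal Python equivalent of the curl test above (assumes the server is
# listening on localhost:3000 and that `requests` is installed).
import requests

response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    timeout=60,
)
response.raise_for_status()
# /generate returns a JSON object with a `generated_text` field.
print(response.json()["generated_text"])
```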
103 lines · 3.5 KiB · Python
```python
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
    FlashPhiForCausalLM,
    PhiConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashPhi(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashPhi is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        # Load the safetensors weights onto this rank's device.
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashPhiForCausalLM(config, weights)
        if use_medusa:
            # Optionally wrap the LM head with a Medusa speculative-decoding head.
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashPhi, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
```
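Outside the server's own model loader, the class above can also be constructed by hand for quick experimentation. The sketch below is illustrative rather than part of the PR: it assumes a CUDA device is available and that the module lives at `text_generation_server.models.flash_phi`, matching the layout of the other flash models.

```python
# Illustrative only: instantiate FlashPhi directly with the pinned phi-2 revision
# from the PR description. Requires a CUDA device (the constructor raises otherwise).
import torch
from text_generation_server.models.flash_phi import FlashPhi  # assumed module path

model = FlashPhi(
    model_id="microsoft/phi-2",
    revision="834565c23f9b28b96ccbeabe614dd906b6db551a",
    quantize=None,    # "gptq" or "awq" would route through the quantized-weights path
    dtype=torch.float16,
    use_medusa=None,  # optionally a repo/dir containing config.json and medusa_lm_head.pt
)
```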