mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: load phi weights and produce nonsense tokens
This commit is contained in:
parent
cd8e0b221e
commit
77ee1f18fa
@ -58,6 +58,7 @@ try:
|
|||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
from text_generation_server.models.idefics import IDEFICSSharded
|
||||||
from text_generation_server.models.flash_mistral import FlashMistral
|
from text_generation_server.models.flash_mistral import FlashMistral
|
||||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||||
|
from text_generation_server.models.flash_phi import FlashPhi
|
||||||
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
|
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@ -73,6 +74,7 @@ if FLASH_ATTENTION:
|
|||||||
__all__.append(IDEFICSSharded)
|
__all__.append(IDEFICSSharded)
|
||||||
__all__.append(FlashMistral)
|
__all__.append(FlashMistral)
|
||||||
__all__.append(FlashMixtral)
|
__all__.append(FlashMixtral)
|
||||||
|
__all__.append(FlashPhi)
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
@ -229,11 +231,22 @@ def get_model(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif model_type == "phi":
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
return FlashPhi(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
use_medusa=use_medusa,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
|
||||||
|
|
||||||
elif model_type == "phi-msft":
|
elif model_type == "phi-msft":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Flash Phi"))
|
raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention")
|
||||||
elif sharded:
|
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Phi"))
|
|
||||||
else:
|
else:
|
||||||
return Phi(
|
return Phi(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -0,0 +1,435 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
TensorParallelHead,
|
||||||
|
get_linear,
|
||||||
|
FastRMSNorm,
|
||||||
|
FastLayerNorm,
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
class PhiConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=51200,
|
||||||
|
hidden_size=2560,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="gelu_fast",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_scaling=None,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
# should never get here
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
if config.model_type == "baichuan":
|
||||||
|
return TensorParallelColumnLinear.load_qkv(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.W_pack",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# should be here
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
quantize=config.quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.quantize not in ["gptq", "awq"]:
|
||||||
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(
|
||||||
|
get_linear(weight, bias=None, quantize=config.quantize)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashPhiAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
# MAYBE (if not static)
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# should be correct
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
|
||||||
|
self.dense = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.dense",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
paged_attention.reshape_and_cache(
|
||||||
|
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
|
)
|
||||||
|
|
||||||
|
# output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# flash attention
|
||||||
|
flash_attn.attention(
|
||||||
|
query,
|
||||||
|
torch.select(kv, dim=1, index=0),
|
||||||
|
torch.select(kv, dim=1, index=1),
|
||||||
|
attn_output,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
max_s,
|
||||||
|
self.softmax_scale,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
paged_attention.attention(
|
||||||
|
attn_output,
|
||||||
|
query,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
class PhiMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate="tanh"
|
||||||
|
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gate_up_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.fc1",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.fc2",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
post_act = self.act(gate_up_states)
|
||||||
|
return self.down_proj(post_act)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashPhiLayer(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"model.layers.{layer_id}"
|
||||||
|
self.self_attn = FlashPhiAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
mlp_output = self.mlp(normed_hidden_states)
|
||||||
|
|
||||||
|
result = attn_output + mlp_output + res
|
||||||
|
|
||||||
|
return result, res
|
||||||
|
|
||||||
|
|
||||||
|
class FlashPhiModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix="model.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FlashPhiLayer(
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashPhiForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = FlashPhiModel(config, weights)
|
||||||
|
# self.lm_head = TensorParallelHead.load(
|
||||||
|
# config,
|
||||||
|
# prefix="lm_head",
|
||||||
|
# weights=weights,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# TODO: prefer parallel head
|
||||||
|
self.linear = FastLinear.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: use in correct place
|
||||||
|
self.ln = FastLayerNorm.load(
|
||||||
|
prefix="model.final_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
# 1000K and 10K
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor], # indexes for the items in the batch
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
# paged attention related
|
||||||
|
block_tables: torch.Tensor, # <- indexes into blocks
|
||||||
|
slots: torch.Tensor, # <- indexes into mem
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
# both attentions
|
||||||
|
max_s: int, # <- max sequence length (make kernals chose swap)
|
||||||
|
# small opt (only care about final)
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
||||||
|
normed_hidden_states, res = self.ln(hidden_states, None)
|
||||||
|
logits = self.linear(normed_hidden_states)
|
||||||
|
|
||||||
|
return logits
|
102
server/text_generation_server/models/flash_phi.py
Normal file
102
server/text_generation_server/models/flash_phi.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from opentelemetry import trace
|
||||||
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from text_generation_server.models import FlashCausalLM
|
||||||
|
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||||
|
FlashPhiForCausalLM,
|
||||||
|
PhiConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashPhi(FlashCausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PhiConfig.from_pretrained(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
config.quantize = quantize
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
if config.quantize in ["gptq", "awq"]:
|
||||||
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
|
model = FlashPhiForCausalLM(config, weights)
|
||||||
|
if use_medusa:
|
||||||
|
from text_generation_server.utils.medusa import MedusaModel
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||||
|
"WEIGHTS_CACHE_OVERRIDE", None
|
||||||
|
) is not None
|
||||||
|
|
||||||
|
if not is_local_model:
|
||||||
|
medusa_config = hf_hub_download(
|
||||||
|
use_medusa, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
|
medusa_head = hf_hub_download(
|
||||||
|
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
medusa_config = str(Path(use_medusa) / "config.json")
|
||||||
|
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||||
|
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||||
|
weights = Weights(
|
||||||
|
[medusa_sf], device, dtype, process_group=self.process_group
|
||||||
|
)
|
||||||
|
lm_head = model.lm_head
|
||||||
|
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(FlashPhi, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_layers=len(model.model.layers),
|
||||||
|
num_kv_heads=model.model.num_key_value_heads,
|
||||||
|
head_size=model.model.head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user