mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fmt
This commit is contained in:
parent
bc04a059c9
commit
d0ddc80c31
@ -21,7 +21,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == ": {request}\")\n response = self"
|
||||
assert response.generated_text == ': {request}")\n response = self'
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -52,14 +52,12 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_phi, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"{[r.generated_text for r in responses]}"
|
||||
assert responses[0].generated_text == ": {request}\")\n response = self"
|
||||
assert responses[0].generated_text == ': {request}")\n response = self'
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
@ -3,6 +3,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class ProcessGroup:
|
||||
def __init__(self, rank: int, world_size: int):
|
||||
self._rank = rank
|
||||
@ -14,12 +15,14 @@ class ProcessGroup:
|
||||
def rank(self) -> int:
|
||||
return self._rank
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
|
||||
self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim)
|
||||
self.weight = (
|
||||
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
|
||||
)
|
||||
self.process_group = ProcessGroup(rank, world_size)
|
||||
|
||||
|
||||
def get_partial_sharded(self, name: str, dim: int):
|
||||
assert dim == 0
|
||||
|
||||
@ -35,6 +38,7 @@ class Weights:
|
||||
def get_shape(self, name: str):
|
||||
return self.weight.shape
|
||||
|
||||
|
||||
def test_weight_hub_files_offline_error():
|
||||
|
||||
vocab_size = 17
|
||||
@ -52,13 +56,22 @@ def test_weight_hub_files_offline_error():
|
||||
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
|
||||
assert embeddings_0_2.min_id == 0
|
||||
assert embeddings_0_2.max_id == 9
|
||||
torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float())
|
||||
torch.testing.assert_close(
|
||||
embeddings_0_2.weight,
|
||||
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
|
||||
.view(10, 256)
|
||||
.float(),
|
||||
)
|
||||
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
|
||||
assert embeddings_1_2.min_id == 9
|
||||
assert embeddings_1_2.max_id == 17
|
||||
torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float())
|
||||
torch.testing.assert_close(
|
||||
embeddings_1_2.weight,
|
||||
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
|
||||
.view(9, 256)
|
||||
.float(),
|
||||
)
|
||||
output_tp_0 = embeddings_0_2.forward(input_ids)
|
||||
output_tp_1 = embeddings_1_2.forward(input_ids)
|
||||
|
||||
torch.testing.assert_close(output, output_tp_0 + output_tp_1)
|
||||
|
||||
|
@ -252,7 +252,9 @@ def get_model(
|
||||
|
||||
elif model_type == "phi-msft":
|
||||
if FLASH_ATTENTION:
|
||||
raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention")
|
||||
raise NotImplementedError(
|
||||
"Legacy phi-msft is not supported with Flash Attention"
|
||||
)
|
||||
else:
|
||||
return Phi(
|
||||
model_id,
|
||||
|
@ -17,6 +17,7 @@ from text_generation_server.utils.layers import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
@ -55,6 +56,7 @@ class PhiConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
@ -68,6 +70,7 @@ def load_attention(config, prefix, weights):
|
||||
bias=True,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
@ -94,6 +97,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
get_linear(weight, bias=True, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashPhiAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -173,8 +177,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
#
|
||||
# Apply partial positional embeddings in place
|
||||
self.rotary_emb(
|
||||
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
|
||||
cos, sin
|
||||
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
|
||||
)
|
||||
|
||||
# Reshape key and value and cache
|
||||
@ -212,6 +215,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
@ -256,7 +260,9 @@ class FlashPhiLayer(nn.Module):
|
||||
)
|
||||
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
|
||||
|
||||
@ -287,10 +293,13 @@ class FlashPhiLayer(nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
|
||||
hidden_states = self.resid_dropout(attn_output).add(
|
||||
self.resid_dropout(self.mlp(hidden_states))
|
||||
)
|
||||
|
||||
return hidden_states, res
|
||||
|
||||
|
||||
class FlashPhiModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
@ -361,6 +370,7 @@ class FlashPhiModel(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashPhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
@ -54,9 +54,19 @@ def load_col(config, prefix, weights, bias):
|
||||
bias_h = bias_h[0]
|
||||
bias_block_size = bias_h // bias_size
|
||||
|
||||
bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
|
||||
bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
|
||||
bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
|
||||
bias_q_part = bias_slice_[
|
||||
bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
bias_k_part = bias_slice_[
|
||||
bias_h
|
||||
+ bias_rank * bias_block_size : bias_h
|
||||
+ (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
bias_v_part = bias_slice_[
|
||||
2 * bias_h
|
||||
+ bias_rank * bias_block_size : 2 * bias_h
|
||||
+ (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
|
||||
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
|
||||
if bias.dtype != torch.int32:
|
||||
@ -352,8 +362,12 @@ class MultiheadAttention(nn.Module):
|
||||
hidden_size = config.d_model
|
||||
head_dim = hidden_size // self.n_heads
|
||||
|
||||
self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
|
||||
self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
|
||||
self.q_ln = LPLayerNorm(
|
||||
d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
|
||||
)
|
||||
self.k_ln = LPLayerNorm(
|
||||
self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
|
||||
)
|
||||
if self.attn_impl == "flash":
|
||||
self.attn_fn = flash_attn_fn
|
||||
elif self.attn_impl == "triton":
|
||||
@ -684,7 +698,6 @@ class LPLayerNorm(torch.nn.LayerNorm):
|
||||
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
|
||||
self.normalized_shape = self.weight.shape
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
module_device = x.device
|
||||
downcast_x = _cast_if_autocast_enabled(x)
|
||||
|
@ -62,14 +62,12 @@ class PhiConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# RotaryEmbedding is a class that implements the rotary embedding.
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
inv_freq = [
|
||||
1.0 / 10000.0 ** (i / dim)
|
||||
for i in range(0, dim, 2)
|
||||
]
|
||||
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
|
||||
inv_freq_len = len(inv_freq)
|
||||
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
|
||||
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
|
||||
@ -131,6 +129,7 @@ class PhiCausalLMHead(nn.Module):
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
|
||||
class PhiMHA(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
@ -172,19 +171,27 @@ class PhiMHA(nn.Module):
|
||||
v = torch.cat([prev_v, v], dim=1)
|
||||
|
||||
past_kv_cache = [k, v]
|
||||
attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
|
||||
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
|
||||
|
||||
if attention_mask is not None:
|
||||
seqlen_k = k.shape[1]
|
||||
seqlen_q = q.shape[1]
|
||||
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
|
||||
1,
|
||||
)
|
||||
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
|
||||
attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
|
||||
attn_output = (
|
||||
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
|
||||
.transpose(1, 2)
|
||||
.flatten(-2)
|
||||
)
|
||||
return self.out_proj(attn_output), past_kv_cache
|
||||
|
||||
|
||||
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
@ -211,12 +218,15 @@ class PhiMLP(nn.Module):
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
|
||||
class PhiBlock(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
|
||||
self.layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
|
||||
|
||||
@ -228,11 +238,14 @@ class PhiBlock(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
|
||||
attn_outputs, past_kv_cache = self.mixer(
|
||||
hidden_states, kv_cache, attention_mask
|
||||
)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
out = attn_outputs + feed_forward_hidden_states + residual
|
||||
return out, past_kv_cache
|
||||
|
||||
|
||||
# PhiModel implements the embedding layer and the transformer blocks.
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
@ -243,7 +256,10 @@ class PhiModel(nn.Module):
|
||||
prefix="transformer.embd.wte", weights=weights
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
|
||||
[
|
||||
PhiBlock(f"transformer.h.{layer_id}", config, weights)
|
||||
for layer_id in range(config.n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -258,14 +274,19 @@ class PhiModel(nn.Module):
|
||||
seq_len = hidden_states.shape[1]
|
||||
mask = None if seq_len <= 1 else attention_mask
|
||||
|
||||
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||
past_key_values = (
|
||||
[None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||
)
|
||||
|
||||
for index, block in enumerate(self.blocks):
|
||||
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
|
||||
hidden_states, new_key_values = block(
|
||||
hidden_states, past_key_values[index], mask
|
||||
)
|
||||
past_key_values[index] = new_key_values
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
|
||||
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
||||
class PhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
@ -290,12 +311,15 @@ class PhiForCausalLM(torch.nn.Module):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = nn.CrossEntropyLoss()(
|
||||
logits[:, :-1].view(-1, logits.size(-1)),
|
||||
labels[:, 1:].view(-1)
|
||||
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
|
||||
return (
|
||||
((loss,) + (logits,) + model_output[1:])
|
||||
if loss is not None
|
||||
else (logits,) + model_output[1:]
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
@ -304,5 +328,3 @@ class PhiForCausalLM(torch.nn.Module):
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -74,9 +74,9 @@ class FlashLlama(FlashCausalLM):
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
|
@ -64,9 +64,9 @@ class FlashPhi(FlashCausalLM):
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
|
@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import (
|
||||
PhiConfig,
|
||||
PhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class Phi(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
@ -60,4 +64,3 @@ class Phi(CausalLM):
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module):
|
||||
block_size = (num_embeddings + world_size - 1) // world_size
|
||||
self.min_id = rank * block_size
|
||||
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||
self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size.
|
||||
self.null_idx = weight.shape[
|
||||
0
|
||||
] # Usually block_size, might be less in non even vocab_size.
|
||||
self.process_group = weights.process_group
|
||||
self.reduce = reduce
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user