This PR adds basic modeling for phi-2.

Run:

```bash
text-generation-server \
    serve \
    microsoft/phi-2 \
    --revision 834565c23f9b28b96ccbeabe614dd906b6db551a
```

Test:

```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq .
# {
#   "generated_text": "\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from data. These"
# }
```

Notes:

- Recently (~1 day ago) the Phi weights and model were updated to accommodate adding [GQA/MQA attention to the model](https://github.com/huggingface/transformers/pull/28163). This implementation expects the original model format, so a fixed revision is required at the moment.
- This PR only includes a basic implementation of the model; it can later be extended to support Flash and sharded versions, as well as further optimizations.
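For a quick check from Python instead of curl, a minimal sketch using the `text_generation` client package could look like the following (this assumes `pip install text-generation` and the server from the commands above listening on the same `localhost:3000` endpoint; the exact output text may differ):

```python
# Query the running server with the text_generation Python client instead of curl.
from text_generation import Client

client = Client("http://localhost:3000")
response = client.generate("What is Deep Learning?", max_new_tokens=20)
print(response.generated_text)
```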
309 lines
11 KiB
Python
# Implementation of the PhiModel and PhiForCausalLM classes

import torch
import torch.distributed

import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelHead,
    FastLinear,
)


# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=51200,
        n_positions=2048,
        n_embd=2560,
        n_layer=32,
        n_inner=None,
        n_head=32,
        rotary_dim=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_vocab_size_multiple=64,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        no_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.rotary_dim = rotary_dim

        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_word_embeddings = tie_word_embeddings
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.no_bias = no_bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

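# With these phi-2 defaults, each attention head has head_dim = n_embd / n_head
# = 2560 / 32 = 80 channels, and rotary_dim = 32 means only the first 32 channels
# of each head receive rotary position embeddings (see RotaryEmbedding below).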

# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        inv_freq = [
            1.0 / 10000.0 ** (i / dim)
            for i in range(0, dim, 2)
        ]
        inv_freq_len = len(inv_freq)
        inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
        t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
        freqs = t.matmul(inv_freq)
        self.sin = freqs.sin()
        self.cos = freqs.cos()

    def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
        b_size, seqlen, three, _, _headdim = qkv.shape
        if three != 3:
            raise Exception("unexpected shape for qkv")
        _, rotary_dim = self.cos.shape
        rotary_dim = rotary_dim * 2
        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]
        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]
        q12 = torch.chunk(q_rot, 2, dim=-1)
        k12 = torch.chunk(k_rot, 2, dim=-1)
        q1, q2 = q12[0], q12[1]
        k1, k2 = k12[0], k12[1]
        c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
        s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
        q_rot = torch.cat(
            [
                q1 * c - q2 * s,
                q1 * s + q2 * c,
            ],
            dim=-1,
        )
        k_rot = torch.cat(
            [
                k1 * c - k2 * s,
                k1 * s + k2 * c,
            ],
            dim=-1,
        )
        q = torch.cat([q_rot, q_pass], dim=-1)
        k = torch.cat([k_rot, k_pass], dim=-1)
        v = qkv[:, :, 2]
        return q, k, v

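# For each position t and frequency index i, the pair (x1, x2) taken from the first
# and second halves of the rotary slice is rotated by the angle t * theta_i, i.e.
# (x1 * cos - x2 * sin, x1 * sin + x2 * cos), which is what the torch.cat expressions
# above compute. Only the first rotary_dim channels of each head are rotated; the
# remaining channels (q_pass / k_pass) and the values are passed through unchanged.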

# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.ln = nn.LayerNorm.load(
            prefix="lm_head.ln",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.linear = TensorParallelHead.load(
            config=config, prefix="lm_head.linear", weights=weights
        )

    def forward(self, hidden_states):
        hidden_states = self.ln(hidden_states)
        hidden_states = self.linear(hidden_states)
        return hidden_states


# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        self.op_size = config.n_embd
        self.head_dim = int(config.n_embd / config.n_head)
        self.num_heads = config.n_head
        self.rotary_emb = RotaryEmbedding(
            config.rotary_dim,
            config.n_positions,
        )
        self.softmax_scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        hidden_states,
        past_kv_cache,
        attention_mask=None,
    ):
        b_size, seq_len, _n_embd = hidden_states.shape
        qkv = self.Wqkv(hidden_states)
        qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
        seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
        q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)

        # if there is a kv_cache, then we need to concatenate
        if past_kv_cache is not None:
            prev_k, prev_v = past_kv_cache
            k = torch.cat([prev_k, k], dim=1)
            v = torch.cat([prev_v, v], dim=1)

        past_kv_cache = [k, v]
        attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)

        if attention_mask is not None:
            seqlen_k = k.shape[1]
            seqlen_q = q.shape[1]
            causal_mask = torch.triu(
                torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
                1,
            )
            attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
        attn_output = (
            attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
            .transpose(1, 2)
            .flatten(-2)
        )
        return self.out_proj(attn_output), past_kv_cache

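# Shapes in the attention above: qkv is (batch, seq, 3, n_head, head_dim), and after the
# einsum attn_weights is (batch, n_head, seq_q, seq_kv). The additive causal mask (-10000
# above the diagonal) is only built when an attention_mask is passed, i.e. during prefill;
# during incremental decoding seq_q == 1 and every cached position may be attended to, so
# no mask is applied.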

# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.n_inner = config.n_inner
        self.fc1 = FastLinear.load(
            config=config,
            prefix=f"{prefix}.fc1",
            weights=weights,
            bias=False,
        )
        self.fc2 = FastLinear.load(
            config=config,
            prefix=f"{prefix}.fc2",
            weights=weights,
            bias=False,
        )
        self.activation = torch.nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

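# config.n_inner is stored but not used to size the layers here: fc1 and fc2 take their
# shapes from the checkpoint tensors loaded by FastLinear.load, so the inner width is
# whatever the phi-2 weights define.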

# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and a multi-layer perceptron.
class PhiBlock(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        self.layer_id = layer_id
        self.layer_norm = nn.LayerNorm.load(
            prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
        )
        self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
        self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)

    def forward(
        self,
        hidden_states,
        kv_cache,
        attention_mask,
    ):
        residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        attn_outputs, past_kv_cache = self.mixer(
            hidden_states, kv_cache, attention_mask
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        out = attn_outputs + feed_forward_hidden_states + residual
        return out, past_kv_cache

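# Phi uses a "parallel" block layout: the attention output and the MLP output are both
# computed from the same layer-normed hidden states and then summed together with the
# residual, rather than the usual sequential attention -> norm -> MLP arrangement.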

# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.tp_rank = weights.process_group.rank()
        self.tp_world_size = weights.process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="transformer.embd.wte", weights=weights
        )
        self.blocks = nn.ModuleList(
            [
                PhiBlock(f"transformer.h.{layer_id}", config, weights)
                for layer_id in range(config.n_layer)
            ]
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        hidden_states = self.embed_tokens(input_ids)
        seq_len = hidden_states.shape[1]
        mask = None if seq_len <= 1 else attention_mask

        past_key_values = (
            [None] * len(self.blocks) if past_key_values is None else past_key_values
        )

        for index, block in enumerate(self.blocks):
            hidden_states, new_key_values = block(
                hidden_states, past_key_values[index], mask
            )
            past_key_values[index] = new_key_values

        return hidden_states, past_key_values

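# Each entry of past_key_values is a [k, v] pair of shape (batch, seq_so_far, n_head,
# head_dim). On the first call the list is all Nones and every block returns a freshly
# built cache; the returned list can be passed back in on later single-token calls, where
# the mask is skipped because seq_len <= 1.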

# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.model = PhiModel(config, weights)
        self.lm_head = PhiCausalLMHead(config, weights)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        model_output = self.model(
            input_ids, past_key_values, attention_mask, return_dict, use_cache
        )
        logits = self.lm_head(model_output[0])

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits[:, :-1].view(-1, logits.size(-1)),
                labels[:, 1:].view(-1),
            )

        if not return_dict:
            return (
                ((loss,) + (logits,) + model_output[1:])
                if loss is not None
                else (logits,) + model_output[1:]
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output[1],
            hidden_states=None,
            attentions=None,
        )

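# When labels are provided, the loss is the standard next-token cross entropy: logits for
# positions 0..n-2 are compared against labels 1..n-1, i.e. targets shifted left by one.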