From 7e2a7433d3584a5a68dbf3e71def4323079f2c26 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 25 Jan 2024 09:37:53 -0500 Subject: [PATCH] feat: adds phi model (#1442) This PR adds basic modeling for phi-2 run ```bash text-generation-server \ serve \ microsoft/phi-2 \ --revision 834565c23f9b28b96ccbeabe614dd906b6db551a ``` test ```bash curl -s localhost:3000/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' | jq . # { # "generated_text": "\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from data. These" # } ``` notes - recently (~1 day ago) the Phi weights and model were updated to accommodate adding [GQA/MQA attention to the model.](https://github.com/huggingface/transformers/pull/28163) This impl expects the original model format so a fixed revision is required at the moment. - this PR only includes a basic implementation of the model and can later be extended for support Flash and Sharded versions as well as make use of better optimization --- .../test_flash_phi/test_flash_phi.json | 84 ++++ .../test_flash_phi_all_params.json | 60 +++ .../test_flash_phi/test_flash_phi_load.json | 338 +++++++++++++++ integration-tests/models/test_flash_phi.py | 65 +++ server/pyproject.toml | 2 +- .../text_generation_server/models/__init__.py | 34 ++ .../custom_modeling/flash_phi_modeling.py | 400 ++++++++++++++++++ .../models/custom_modeling/phi_modeling.py | 308 ++++++++++++++ .../models/flash_phi.py | 102 +++++ server/text_generation_server/models/phi.py | 63 +++ 10 files changed, 1455 insertions(+), 1 deletion(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json create mode 100644 integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json create mode 100644 integration-tests/models/test_flash_phi.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_phi_modeling.py create mode 100644 server/text_generation_server/models/custom_modeling/phi_modeling.py create mode 100644 server/text_generation_server/models/flash_phi.py create mode 100644 server/text_generation_server/models/phi.py diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json new file mode 100644 index 00000000..51d969b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json @@ -0,0 +1,84 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.76660156, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7246094, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.41333008, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.11785889, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.97265625, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": 
-1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.0569458, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json new file mode 100644 index 00000000..221ff13d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json @@ -0,0 +1,60 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 6, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 284, + "logprob": -0.19421387, + "special": false, + "text": " to" + }, + { + "id": 3758, + "logprob": -0.62597656, + "special": false, + "text": " send" + }, + { + "id": 1366, + "logprob": -0.87060547, + "special": false, + "text": " data" + }, + { + "id": 625, + "logprob": -0.88427734, + "special": false, + "text": " over" + }, + { + "id": 257, + "logprob": -1.0830078, + "special": false, + "text": " a" + }, + { + "id": 3127, + "logprob": -1.9462891, + "special": false, + "text": " network" + } + ], + "top_tokens": null + }, + "generated_text": "Test request to send data over a network" +} diff --git a/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json new file mode 100644 index 00000000..62f7fd32 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json @@ -0,0 +1,338 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " 
{" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 14402, + "logprob": null, + "text": "Test" + }, + { + "id": 2581, + "logprob": -11.6171875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 25, + "logprob": -2.3203125, + "special": false, + "text": ":" + }, + { + "id": 1391, + "logprob": -0.98779297, + "special": false, + "text": " {" + }, + { + "id": 25927, + "logprob": -0.7729492, + "special": false, + "text": "request" + }, + { + "id": 92, + "logprob": -0.7241211, + "special": false, + "text": "}" + }, + { + "id": 4943, + "logprob": -0.4091797, + "special": false, + "text": "\")" + }, + { + "id": 198, + "logprob": -0.119018555, + "special": false, + "text": "\n" + }, + { + "id": 50280, + "logprob": -0.9707031, + "special": false, + "text": " " + }, + { + "id": 26209, + "logprob": -1.4414062, + "special": false, + "text": "response" + }, + { + "id": 796, + "logprob": -0.056854248, + "special": false, + "text": " =" + }, + { + "id": 2116, + "logprob": -1.1533203, + "special": false, + "text": " self" + } + ], + "top_tokens": null + }, + "generated_text": ": {request}\")\n response = self" + } +] diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py new file mode 100644 index 00000000..6391f2a1 --- /dev/null +++ b/integration-tests/models/test_flash_phi.py @@ -0,0 +1,65 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_phi_handle(launcher): + with 
launcher("microsoft/phi-2", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_phi(flash_phi_handle): + await flash_phi_handle.health(300) + return flash_phi_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_phi(flash_phi, response_snapshot): + response = await flash_phi.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == ": {request}\")\n response = self" + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_phi_all_params(flash_phi, response_snapshot): + response = await flash_phi.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["network"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 6 + assert response.generated_text == "Test request to send data over a network" + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): + responses = await generate_load( + flash_phi, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == ": {request}\")\n response = self" + + assert responses == response_snapshot diff --git a/server/pyproject.toml b/server/pyproject.toml index 6e9be43e..d1452678 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,7 +26,7 @@ hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "^0.15.0" huggingface-hub = "^0.19.3" -transformers = "^4.36.1" +transformers = "^4.37.1" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 39d1d58e..679e1e2f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.gpt_neox import GPTNeoxSharded +from text_generation_server.models.phi import Phi # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. 
@@ -57,6 +58,7 @@ try: from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mixtral import FlashMixtral + from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA except ImportError as e: @@ -72,6 +74,7 @@ if FLASH_ATTENTION: __all__.append(IDEFICSSharded) __all__.append(FlashMistral) __all__.append(FlashMixtral) + __all__.append(FlashPhi) def get_model( @@ -227,6 +230,37 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + + elif model_type == "phi": + if FLASH_ATTENTION: + return FlashPhi( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + use_medusa=use_medusa, + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + elif model_type == "phi-msft": + if FLASH_ATTENTION: + raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention") + else: + return Phi( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif model_type == "llama" or model_type == "baichuan": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py new file mode 100644 index 00000000..d103973f --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -0,0 +1,400 @@ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, + FastLayerNorm, +) + +class PhiConfig(PretrainedConfig): + def __init__( + self, + vocab_size=51200, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="gelu_fast", # llama uses silu + layer_norm_eps=1e-05, # rms in llama, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + resid_pdrop=0.1, # llama doesn't have this + partial_rotary_factor=0.5, # important difference between llama and phi + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.rope_theta = rope_theta + self.resid_pdrop = resid_pdrop + self.partial_rotary_factor = partial_rotary_factor + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +# this is the same as llama except for Phi uses bias=True +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", 
f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + # this is the same as llama except for Phi uses bias=True + return TensorParallelColumnLinear( + get_linear(weight, bias=True, quantize=config.quantize) + ) + +class FlashPhiAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.softmax_scale = self.head_size**-0.5 + self.rotary_dim = int(config.partial_rotary_factor * self.head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.rotary_dim, + base=config.rope_theta, + device=weights.device, + ) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + # in llama the dense layer is called "o_proj" and has bias=False + self.dense = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.dense", + weights=weights, + bias=True, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + # Compute query, key, value and split + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + + # Reshape query and key for rotary embeddings + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + # NOTE: this is the main difference between Llama and Phi + # in llama the rotary embeddings are applied to the whole query and key. 
+ # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions + # + # Apply partial positional embeddings in place + self.rotary_emb( + query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim], + cos, sin + ) + + # Reshape key and value and cache + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.dense(attn_output.view(-1, self.num_heads*self.head_size)) + +class PhiMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + + # llama weights are up_proj and down_proj and bias=False + self.up_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.fc1", + weights=weights, + bias=True, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=True, + ) + + def forward(self, hidden_states): + # NOTE: Llama requires the gate up states to an intermediate size + # Phi does not and we can avoid the `view` operation + return self.down_proj(self.act(self.up_proj(hidden_states))) + + +class FlashPhiLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashPhiAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps + ) + self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + hidden_states, res = self.input_layernorm(hidden_states, residual) + # Self Attention + attn_output = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states))) + + return hidden_states, res + +class FlashPhiModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + FlashPhiLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + self.norm = FastLayerNorm.load( + 
prefix="model.final_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + +class FlashPhiForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashPhiModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + + return self.lm_head(hidden_states) diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py new file mode 100644 index 00000000..f9999537 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -0,0 +1,308 @@ +# imlementation of the PhiModel and PhiForCausalLM classes + +import torch +import torch.distributed + +import math +from torch import nn +from typing import Optional, List, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelHead, + FastLinear, +) + + +# PhiConfig is the configuration class for the PhiModel. 
+class PhiConfig(PretrainedConfig): + def __init__( + self, + vocab_size=51200, + n_positions=2048, + n_embd=2560, + n_layer=32, + n_inner=None, + n_head=32, + rotary_dim=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_vocab_size_multiple=64, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + no_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_inner = n_inner + self.n_head = n_head + self.rotary_dim = rotary_dim + + self.layer_norm_epsilon = layer_norm_epsilon + self.tie_word_embeddings = tie_word_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.no_bias = no_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +# RotaryEmbedding is a class that implements the rotary embedding. +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + inv_freq = [ + 1.0 / 10000.0 ** (i / dim) + for i in range(0, dim, 2) + ] + inv_freq_len = len(inv_freq) + inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) + t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) + freqs = t.matmul(inv_freq) + self.sin = freqs.sin() + self.cos = freqs.cos() + + def apply_rotary_emb_qkv(self, qkv, seqlen_offset): + b_size, seqlen, three, _, _headdim = qkv.shape + if three != 3: + raise Exception("unexpected shape for qkv") + _, rotary_dim = self.cos.shape + rotary_dim = rotary_dim * 2 + q_rot = qkv[:, :, 0, :, :rotary_dim] + q_pass = qkv[:, :, 0, :, rotary_dim:] + k_rot = qkv[:, :, 1, :, :rotary_dim] + k_pass = qkv[:, :, 1, :, rotary_dim:] + q12 = torch.chunk(q_rot, 2, dim=-1) + k12 = torch.chunk(k_rot, 2, dim=-1) + q1, q2 = q12[0], q12[1] + k1, k2 = k12[0], k12[1] + c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) + s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) + q_rot = torch.cat( + [ + q1 * c - q2 * s, + q1 * s + q2 * c, + ], + dim=-1, + ) + k_rot = torch.cat( + [ + k1 * c - k2 * s, + k1 * s + k2 * c, + ], + dim=-1, + ) + q = torch.cat([q_rot, q_pass], dim=-1) + k = torch.cat([k_rot, k_pass], dim=-1) + v = qkv[:, :, 2] + return q, k, v + + +# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. +class PhiCausalLMHead(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.ln = nn.LayerNorm.load( + prefix="lm_head.ln", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.linear = TensorParallelHead.load( + config=config, prefix="lm_head.linear", weights=weights + ) + + def forward(self, hidden_states): + hidden_states = self.ln(hidden_states) + hidden_states = self.linear(hidden_states) + return hidden_states + +# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. 
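+# Wqkv is a single fused projection whose output is reshaped into (q, k, v) heads
+# inside forward; rotary embeddings are applied to only the first config.rotary_dim
+# channels of q and k (see RotaryEmbedding above) and the remaining channels pass
+# through unchanged.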
+class PhiMHA(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.Wqkv = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias + ) + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=not config.no_bias, + ) + self.op_size = config.n_embd + self.head_dim = int(config.n_embd / config.n_head) + self.num_heads = config.n_head + self.rotary_emb = RotaryEmbedding( + config.rotary_dim, + config.n_positions, + ) + self.softmax_scale = 1.0 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states, + past_kv_cache, + attention_mask=None, + ): + b_size, seq_len, _n_embd = hidden_states.shape + qkv = self.Wqkv(hidden_states) + qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) + seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] + q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) + + # if there is a kv_cache, then we need to concatenate + if past_kv_cache is not None: + prev_k, prev_v = past_kv_cache + k = torch.cat([prev_k, k], dim=1) + v = torch.cat([prev_v, v], dim=1) + + past_kv_cache = [k, v] + attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale) + + if attention_mask is not None: + seqlen_k = k.shape[1] + seqlen_q = q.shape[1] + causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1) + attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) + attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2) + return self.out_proj(attn_output), past_kv_cache + +# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. +class PhiMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.n_inner = config.n_inner + self.fc1 = FastLinear.load( + config=config, + prefix=f"{prefix}.fc1", + weights=weights, + bias=False, + ) + self.fc2 = FastLinear.load( + config=config, + prefix=f"{prefix}.fc2", + weights=weights, + bias=False, + ) + self.activation = torch.nn.functional.gelu + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + +# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. +class PhiBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.layer_id = layer_id + self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon) + self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) + + def forward( + self, + hidden_states, + kv_cache, + attention_mask, + ): + residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask) + feed_forward_hidden_states = self.mlp(hidden_states) + out = attn_outputs + feed_forward_hidden_states + residual + return out, past_kv_cache + +# PhiModel implements the embedding layer and the transformer blocks. 
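+# Each PhiBlock above applies a single pre-attention LayerNorm and a parallel
+# residual (out = attn_outputs + feed_forward_hidden_states + residual), so the
+# model only needs to thread hidden_states and the per-block KV caches through
+# the stack; FlashPhiLayer mirrors the same parallel structure.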
+class PhiModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.tp_rank = weights.process_group.rank() + self.tp_world_size = weights.process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="transformer.embd.wte", weights=weights + ) + self.blocks = nn.ModuleList( + [PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)] + ) + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + hidden_states = self.embed_tokens(input_ids) + seq_len = hidden_states.shape[1] + mask = None if seq_len <= 1 else attention_mask + + past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values + + for index, block in enumerate(self.blocks): + hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask) + past_key_values[index] = new_key_values + + return hidden_states, past_key_values + +# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. +class PhiForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.model = PhiModel(config, weights) + self.lm_head = PhiCausalLMHead(config, weights) + + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + model_output = self.model( + input_ids, past_key_values, attention_mask, return_dict, use_cache + ) + logits = self.lm_head(model_output[0]) + + loss = None + if labels is not None: + loss = nn.CrossEntropyLoss()( + logits[:, :-1].view(-1, logits.size(-1)), + labels[:, 1:].view(-1) + ) + + if not return_dict: + return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:] + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=model_output[1], + hidden_states=None, + attentions=None, + ) + + diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py new file mode 100644 index 00000000..1c49f2a9 --- /dev/null +++ b/server/text_generation_server/models/flash_phi.py @@ -0,0 +1,102 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from transformers import AutoConfig, AutoTokenizer +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, + PhiConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashPhi(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + use_medusa: Optional[str] = None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + 
dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashPhi is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = PhiConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = FlashPhiForCausalLM(config, weights) + if use_medusa: + from text_generation_server.utils.medusa import MedusaModel + from huggingface_hub import hf_hub_download + import json + import os + from pathlib import Path + + is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( + "WEIGHTS_CACHE_OVERRIDE", None + ) is not None + + if not is_local_model: + medusa_config = hf_hub_download( + use_medusa, revision=revision, filename="config.json" + ) + medusa_head = hf_hub_download( + use_medusa, revision=revision, filename="medusa_lm_head.pt" + ) + else: + medusa_config = str(Path(use_medusa) / "config.json") + medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") + + with open(medusa_config, "r") as f: + config = json.load(f) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) + lm_head = model.lm_head + model.lm_head = MedusaModel(config, weights, lm_head) + + torch.distributed.barrier(group=self.process_group) + super(FlashPhi, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py new file mode 100644 index 00000000..d477478a --- /dev/null +++ b/server/text_generation_server/models/phi.py @@ -0,0 +1,63 @@ +import torch +import torch.distributed + +from transformers import AutoConfig, AutoTokenizer +from typing import Optional, List, Tuple + +from text_generation_server.models import CausalLM +from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +class Phi(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, _rank, _world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + config = PhiConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + 
tokenizer.bos_token_id = config.bos_token_id + tokenizer.eos_token_id = config.eos_token_id + tokenizer.pad_token = tokenizer.eos_token + + config.quantize = quantize + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + model = PhiForCausalLM(config, weights) + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) +
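
For reviewers unfamiliar with the partial rotary scheme referenced in both modeling files, here is a minimal, self-contained sketch (not part of the patch) of what the sliced `self.rotary_emb(query[:, :, :self.rotary_dim], ...)` call and `RotaryEmbedding.apply_rotary_emb_qkv` do: only the first `rotary_dim` channels of each head are rotated, the rest pass through unchanged. Function and variable names below are illustrative, not TGI APIs.

```python
import torch


def apply_partial_rotary(x, cos, sin, rotary_dim):
    """Rotate only the first `rotary_dim` channels of each head (Phi-style).

    x:        (batch, seq_len, num_heads, head_dim)
    cos, sin: (seq_len, rotary_dim // 2)
    """
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    c = cos[None, :, None, :]  # broadcast over batch and heads
    s = sin[None, :, None, :]
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    # channels beyond rotary_dim are returned untouched
    return torch.cat([rotated, x_pass], dim=-1)


if __name__ == "__main__":
    batch, seq_len, heads, head_dim, rotary_dim = 1, 4, 2, 8, 4
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    q = torch.randn(batch, seq_len, heads, head_dim)
    out = apply_partial_rotary(q, freqs.cos(), freqs.sin(), rotary_dim)
    assert torch.equal(out[..., rotary_dim:], q[..., rotary_dim:])
```

The rotation itself (`x1 * cos - x2 * sin`, `x1 * sin + x2 * cos`) matches the half-split form used in `RotaryEmbedding.apply_rotary_emb_qkv`; the only Phi-specific part is the slice at `rotary_dim`.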