From d0ddc80c31852d89aff753a3e99f8535a0d98d82 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:31:48 +0100 Subject: [PATCH] fmt --- integration-tests/models/test_flash_phi.py | 8 +-- server/tests/utils/test_layers.py | 37 +++++++---- server/text_generation_server/cli.py | 2 +- .../text_generation_server/models/__init__.py | 6 +- .../custom_modeling/flash_phi_modeling.py | 30 ++++++--- .../models/custom_modeling/mpt_modeling.py | 27 +++++--- .../models/custom_modeling/phi_modeling.py | 62 +++++++++++++------ .../models/flash_llama.py | 12 ++-- .../models/flash_phi.py | 12 ++-- server/text_generation_server/models/phi.py | 7 ++- server/text_generation_server/utils/layers.py | 4 +- 11 files changed, 135 insertions(+), 72 deletions(-) diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py index 6391f2a1..0987b3a1 100644 --- a/integration-tests/models/test_flash_phi.py +++ b/integration-tests/models/test_flash_phi.py @@ -21,7 +21,7 @@ async def test_flash_phi(flash_phi, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response.generated_text == ": {request}\")\n response = self" + assert response.generated_text == ': {request}")\n response = self' assert response == response_snapshot @@ -52,14 +52,12 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): @pytest.mark.asyncio @pytest.mark.private async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): - responses = await generate_load( - flash_phi, "Test request", max_new_tokens=10, n=4 - ) + responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) assert len(responses) == 4 assert all( [r.generated_text == responses[0].generated_text for r in responses] ), f"{[r.generated_text for r in responses]}" - assert responses[0].generated_text == ": {request}\")\n response = self" + assert responses[0].generated_text == ': {request}")\n response = self' assert responses == response_snapshot diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 0a9fecd1..93a0e982 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -3,24 +3,27 @@ from text_generation_server.utils.layers import ( TensorParallelEmbedding, ) + class ProcessGroup: def __init__(self, rank: int, world_size: int): self._rank = rank self.world_size = world_size - def size(self)->int: + def size(self) -> int: return self.world_size - def rank(self)->int: + def rank(self) -> int: return self._rank + class Weights: def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int): - self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim) + self.weight = ( + torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim) + ) self.process_group = ProcessGroup(rank, world_size) - - def get_partial_sharded(self, name:str, dim: int): + def get_partial_sharded(self, name: str, dim: int): assert dim == 0 rank = self.process_group.rank() @@ -35,10 +38,11 @@ class Weights: def get_shape(self, name: str): return self.weight.shape + def test_weight_hub_files_offline_error(): - vocab_size= 17 - weights = Weights(rank=0, world_size=1, vocab_size = vocab_size,hidden_dim = 256) + vocab_size = 17 + weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) embeddings = TensorParallelEmbedding("", weights) input_ids = torch.arange(vocab_size) @@ -47,18 +51,27 @@ 
def test_weight_hub_files_offline_error(): assert embeddings.max_id == 17 torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256)) - weights_0_2 = Weights(rank=0, world_size=2, vocab_size = vocab_size,hidden_dim = 256) - weights_1_2 = Weights(rank=1, world_size=2, vocab_size = vocab_size,hidden_dim = 256) + weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256) + weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256) embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False) assert embeddings_0_2.min_id == 0 assert embeddings_0_2.max_id == 9 - torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float()) + torch.testing.assert_close( + embeddings_0_2.weight, + torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0) + .view(10, 256) + .float(), + ) embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False) assert embeddings_1_2.min_id == 9 assert embeddings_1_2.max_id == 17 - torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float()) + torch.testing.assert_close( + embeddings_1_2.weight, + torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0) + .view(9, 256) + .float(), + ) output_tp_0 = embeddings_0_2.forward(input_ids) output_tp_1 = embeddings_1_2.forward(input_ids) torch.testing.assert_close(output, output_tp_0 + output_tp_1) - diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 99be6c7e..b74fbe36 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -226,7 +226,7 @@ def download_weights( pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass - + elif (Path(model_id) / "adapter_config.json").exists(): # Try to load as a local PEFT model try: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 679e1e2f..68096709 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -230,7 +230,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - + elif model_type == "phi": if FLASH_ATTENTION: return FlashPhi( @@ -252,7 +252,9 @@ def get_model( elif model_type == "phi-msft": if FLASH_ATTENTION: - raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention") + raise NotImplementedError( + "Legacy phi-msft is not supported with Flash Attention" + ) else: return Phi( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index d103973f..96701794 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.utils.layers import ( FastLayerNorm, ) + class PhiConfig(PretrainedConfig): def __init__( self, @@ -25,15 +26,15 @@ class PhiConfig(PretrainedConfig): num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, - hidden_act="gelu_fast", # llama uses silu - layer_norm_eps=1e-05, # rms in llama, + hidden_act="gelu_fast", # llama uses silu + layer_norm_eps=1e-05, # rms in llama, pad_token_id=0, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=10000.0, - 
resid_pdrop=0.1, # llama doesn't have this - partial_rotary_factor=0.5, # important difference between llama and phi + resid_pdrop=0.1, # llama doesn't have this + partial_rotary_factor=0.5, # important difference between llama and phi **kwargs, ): self.vocab_size = vocab_size @@ -55,6 +56,7 @@ class PhiConfig(PretrainedConfig): **kwargs, ) + # this is the same as llama except for Phi uses bias=True def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: @@ -68,6 +70,7 @@ def load_attention(config, prefix, weights): bias=True, ) + def _load_gqa(config, prefix: str, weights): assert config.hidden_size % config.num_attention_heads == 0 assert config.num_attention_heads % weights.process_group.size() == 0 @@ -94,6 +97,7 @@ def _load_gqa(config, prefix: str, weights): get_linear(weight, bias=True, quantize=config.quantize) ) + class FlashPhiAttention(torch.nn.Module): def __init__( self, @@ -173,8 +177,7 @@ class FlashPhiAttention(torch.nn.Module): # # Apply partial positional embeddings in place self.rotary_emb( - query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim], - cos, sin + query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin ) # Reshape key and value and cache @@ -210,7 +213,8 @@ class FlashPhiAttention(torch.nn.Module): max_s, ) - return self.dense(attn_output.view(-1, self.num_heads*self.head_size)) + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + class PhiMLP(nn.Module): def __init__(self, prefix, config, weights): @@ -256,7 +260,9 @@ class FlashPhiLayer(nn.Module): ) self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastLayerNorm.load( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, ) self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) @@ -287,10 +293,13 @@ class FlashPhiLayer(nn.Module): max_s, ) - hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states))) + hidden_states = self.resid_dropout(attn_output).add( + self.resid_dropout(self.mlp(hidden_states)) + ) return hidden_states, res + class FlashPhiModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() @@ -361,6 +370,7 @@ class FlashPhiModel(torch.nn.Module): return hidden_states + class FlashPhiForCausalLM(torch.nn.Module): def __init__(self, config, weights): super().__init__() @@ -380,7 +390,7 @@ class FlashPhiForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 1a9aef74..2c2fec48 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -54,9 +54,19 @@ def load_col(config, prefix, weights, bias): bias_h = bias_h[0] bias_block_size = bias_h // bias_size - bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size] - bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size] - bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size 
: 2 * bias_h + (bias_rank + 1) * bias_block_size] + bias_q_part = bias_slice_[ + bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size + ] + bias_k_part = bias_slice_[ + bias_h + + bias_rank * bias_block_size : bias_h + + (bias_rank + 1) * bias_block_size + ] + bias_v_part = bias_slice_[ + 2 * bias_h + + bias_rank * bias_block_size : 2 * bias_h + + (bias_rank + 1) * bias_block_size + ] bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) if bias.dtype != torch.int32: @@ -352,8 +362,12 @@ class MultiheadAttention(nn.Module): hidden_size = config.d_model head_dim = hidden_size // self.n_heads - self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights) - self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights) + self.q_ln = LPLayerNorm( + d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights + ) + self.k_ln = LPLayerNorm( + self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights + ) if self.attn_impl == "flash": self.attn_fn = flash_attn_fn elif self.attn_impl == "triton": @@ -684,7 +698,6 @@ class LPLayerNorm(torch.nn.LayerNorm): self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) self.normalized_shape = self.weight.shape - def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) @@ -798,7 +811,7 @@ class MPTModel(MPTPreTrainedModel): self.wte = TensorParallelEmbedding("transformer.wte", weights) if not self.alibi: - self.wpe = TensorParallelEmbedding("transformer.wpe", weights) + self.wpe = TensorParallelEmbedding("transformer.wpe", weights) self.blocks = nn.ModuleList( [ MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index f9999537..e5c09728 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -62,14 +62,12 @@ class PhiConfig(PretrainedConfig): **kwargs, ) + # RotaryEmbedding is a class that implements the rotary embedding. class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() - inv_freq = [ - 1.0 / 10000.0 ** (i / dim) - for i in range(0, dim, 2) - ] + inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] inv_freq_len = len(inv_freq) inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) @@ -131,6 +129,7 @@ class PhiCausalLMHead(nn.Module): hidden_states = self.linear(hidden_states) return hidden_states + # PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens. 
class PhiMHA(nn.Module): def __init__(self, prefix, config, weights): @@ -172,19 +171,27 @@ class PhiMHA(nn.Module): v = torch.cat([prev_v, v], dim=1) past_kv_cache = [k, v] - attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale) + attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) if attention_mask is not None: seqlen_k = k.shape[1] seqlen_q = q.shape[1] - causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1) + causal_mask = torch.triu( + torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), + 1, + ) attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) - + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) - attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2) + attn_output = ( + attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) + .transpose(1, 2) + .flatten(-2) + ) return self.out_proj(attn_output), past_kv_cache + # PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function. class PhiMLP(nn.Module): def __init__(self, prefix, config, weights): @@ -204,19 +211,22 @@ class PhiMLP(nn.Module): bias=False, ) self.activation = torch.nn.functional.gelu - + def forward(self, hidden_states): hidden_states = self.fc1(hidden_states) hidden_states = self.activation(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states + # PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. class PhiBlock(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() self.layer_id = layer_id - self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon) + self.layer_norm = nn.LayerNorm.load( + prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon + ) self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) @@ -228,11 +238,14 @@ class PhiBlock(nn.Module): ): residual = hidden_states hidden_states = self.layer_norm(hidden_states) - attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask) + attn_outputs, past_kv_cache = self.mixer( + hidden_states, kv_cache, attention_mask + ) feed_forward_hidden_states = self.mlp(hidden_states) out = attn_outputs + feed_forward_hidden_states + residual return out, past_kv_cache + # PhiModel implements the embedding layer and the transformer blocks. 
class PhiModel(nn.Module): def __init__(self, config, weights): @@ -241,9 +254,12 @@ class PhiModel(nn.Module): self.tp_world_size = weights.process_group.size() self.embed_tokens = TensorParallelEmbedding( prefix="transformer.embd.wte", weights=weights - ) + ) self.blocks = nn.ModuleList( - [PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)] + [ + PhiBlock(f"transformer.h.{layer_id}", config, weights) + for layer_id in range(config.n_layer) + ] ) def forward( @@ -258,14 +274,19 @@ class PhiModel(nn.Module): seq_len = hidden_states.shape[1] mask = None if seq_len <= 1 else attention_mask - past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values + past_key_values = ( + [None] * len(self.blocks) if past_key_values is None else past_key_values + ) for index, block in enumerate(self.blocks): - hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask) + hidden_states, new_key_values = block( + hidden_states, past_key_values[index], mask + ) past_key_values[index] = new_key_values return hidden_states, past_key_values + # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. class PhiForCausalLM(torch.nn.Module): def __init__(self, config, weights): @@ -290,12 +311,15 @@ class PhiForCausalLM(torch.nn.Module): loss = None if labels is not None: loss = nn.CrossEntropyLoss()( - logits[:, :-1].view(-1, logits.size(-1)), - labels[:, 1:].view(-1) + logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) ) if not return_dict: - return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:] + return ( + ((loss,) + (logits,) + model_output[1:]) + if loss is not None + else (logits,) + model_output[1:] + ) return CausalLMOutputWithPast( loss=loss, @@ -304,5 +328,3 @@ class PhiForCausalLM(torch.nn.Module): hidden_states=None, attentions=None, ) - - diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 7be61906..94bd58f4 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -73,11 +73,11 @@ class FlashLlama(FlashCausalLM): import json import os from pathlib import Path - - is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( - "WEIGHTS_CACHE_OVERRIDE", None - ) is not None - + + is_local_model = ( + Path(use_medusa).exists() and Path(use_medusa).is_dir() + ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None + if not is_local_model: medusa_config = hf_hub_download( use_medusa, revision=revision, filename="config.json" @@ -88,7 +88,7 @@ class FlashLlama(FlashCausalLM): else: medusa_config = str(Path(use_medusa) / "config.json") medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - + with open(medusa_config, "r") as f: config = json.load(f) medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 1c49f2a9..061b9740 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -63,11 +63,11 @@ class FlashPhi(FlashCausalLM): import json import os from pathlib import Path - - is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( - "WEIGHTS_CACHE_OVERRIDE", None - ) is not None - + + is_local_model = ( + 
Path(use_medusa).exists() and Path(use_medusa).is_dir() + ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None + if not is_local_model: medusa_config = hf_hub_download( use_medusa, revision=revision, filename="config.json" @@ -78,7 +78,7 @@ class FlashPhi(FlashCausalLM): else: medusa_config = str(Path(use_medusa) / "config.json") medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - + with open(medusa_config, "r") as f: config = json.load(f) medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index d477478a..79aa3fb9 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer from typing import Optional, List, Tuple from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM +from text_generation_server.models.custom_modeling.phi_modeling import ( + PhiConfig, + PhiForCausalLM, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, ) + class Phi(CausalLM): def __init__( self, @@ -60,4 +64,3 @@ class Phi(CausalLM): dtype=dtype, device=device, ) - diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 6ddfd6f4..010d6143 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module): block_size = (num_embeddings + world_size - 1) // world_size self.min_id = rank * block_size self.max_id = min(num_embeddings, (rank + 1) * block_size) - self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size. + self.null_idx = weight.shape[ + 0 + ] # Usually block_size, might be less in non even vocab_size. self.process_group = weights.process_group self.reduce = reduce
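
For readers tracing the final `TensorParallelEmbedding` hunk and the new `server/tests/utils/test_layers.py` assertions, the sketch below illustrates the block-sharding scheme those tests exercise: each rank keeps a contiguous slice of the vocabulary plus one zero "null" row, routes out-of-range ids to that row, and the per-rank partial lookups sum back to the full embedding. This is a minimal, in-process sketch only; the helper names (`shard_embedding`, `sharded_lookup`) are illustrative and do not exist in the repository, and the real layer performs the final sum with an all-reduce over `weights.process_group` rather than a Python loop.

```python
import torch


def shard_embedding(weight: torch.Tensor, rank: int, world_size: int):
    """Return this rank's shard of the embedding table and its id range.

    A trailing zero row serves as the "null" slot for ids owned by other
    ranks, mirroring the null_idx bookkeeping in the hunk above.
    """
    num_embeddings = weight.shape[0]
    block_size = (num_embeddings + world_size - 1) // world_size
    min_id = rank * block_size
    max_id = min(num_embeddings, (rank + 1) * block_size)
    shard = weight[min_id:max_id]
    null_idx = shard.shape[0]  # index of the appended zero row
    shard = torch.cat([shard, torch.zeros(1, weight.shape[1])], dim=0)
    return shard, min_id, max_id, null_idx


def sharded_lookup(input_ids, shard, min_id, max_id, null_idx):
    # Ids outside [min_id, max_id) hit the zero row, so summing the partial
    # outputs across ranks recovers the full, unsharded lookup.
    mask = (input_ids >= min_id) & (input_ids < max_id)
    local_ids = torch.where(
        mask, input_ids - min_id, torch.full_like(input_ids, null_idx)
    )
    return torch.nn.functional.embedding(local_ids, shard)


if __name__ == "__main__":
    # Same shapes as the new test: vocab 17, hidden 256, world size 2.
    vocab_size, hidden_dim, world_size = 17, 256, 2
    weight = torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
    input_ids = torch.arange(vocab_size)

    partial = []
    for rank in range(world_size):
        shard, lo, hi, null_idx = shard_embedding(weight, rank, world_size)
        partial.append(sharded_lookup(input_ids, shard, lo, hi, null_idx))

    # Stand-in for the all-reduce done when TensorParallelEmbedding(reduce=True).
    torch.testing.assert_close(sum(partial), weight[input_ids])
```

Routing foreign ids to a zero row keeps the lookup purely local (no id exchange between ranks); the only communication needed is the single all-reduce of the activations, which is why the test can emulate it with an element-wise sum of the two partial outputs.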