From 77ee1f18fa50959725d635612966532b1bf48321 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 17 Jan 2024 22:30:57 +0000
Subject: [PATCH] feat: load phi weights and produce nonsense tokens

---
 .../text_generation_server/models/__init__.py |  19 +-
 .../custom_modeling/flash_phi_modeling.py     | 435 ++++++++++++++++++
 .../models/flash_phi.py                       | 102 ++++
 3 files changed, 553 insertions(+), 3 deletions(-)
 create mode 100644 server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
 create mode 100644 server/text_generation_server/models/flash_phi.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index aabdd75f..7e7ddef1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -58,6 +58,7 @@ try:
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.models.flash_phi import FlashPhi

     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
@@ -73,6 +74,7 @@ if FLASH_ATTENTION:
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
+    __all__.append(FlashPhi)


 def get_model(
@@ -229,11 +231,22 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    elif model_type == "phi":
+        if FLASH_ATTENTION:
+            return FlashPhi(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
+
     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Flash Phi"))
-        elif sharded:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Phi"))
+            raise NotImplementedError(
+                "Legacy phi-msft is not supported with Flash Attention"
+            )
         else:
             return Phi(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
new file mode 100644
index 00000000..dd466145
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -0,0 +1,435 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.layers import (
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    PositionRotaryEmbedding,
+    get_linear,
+    FastLayerNorm,
+    FastLinear,
+)
+
+
+class PhiConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=51200,
+        hidden_size=2560,
+        intermediate_size=10240,  # phi-2 uses 4 * hidden_size, not the llama value
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="gelu_fast",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,  # phi uses standard LayerNorm, not RMSNorm
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        rope_theta=10000.0,
+        partial_rotary_factor=0.5,  # phi only rotates a fraction of each head
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
+        self.partial_rotary_factor = partial_rotary_factor
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
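+
+# A hedged sketch (never executed by the server) of how the fields above
+# combine; the concrete numbers assume phi-2-like geometry and are purely
+# illustrative:
+#
+#     config = PhiConfig(hidden_size=2560, num_attention_heads=32,
+#                        partial_rotary_factor=0.4)
+#     head_size = config.hidden_size // config.num_attention_heads  # 80
+#     rotary_dim = int(config.partial_rotary_factor * head_size)    # 32
+#
+# Only the first `rotary_dim` channels of each attention head receive rotary
+# position embeddings; the remaining channels pass through unrotated.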
+
+
+def load_attention(config, prefix, weights):
+    if config.num_attention_heads != config.num_key_value_heads:
+        # phi checkpoints ship with num_key_value_heads == num_attention_heads,
+        # so this GQA path is not expected to run for this model
+        return _load_gqa(config, prefix, weights)
+    else:
+        # fuse the q/k/v projections into one column-parallel matmul;
+        # phi checkpoints include biases on all three projections
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+
+
+def _load_gqa(config, prefix: str, weights):
+    assert config.hidden_size % config.num_attention_heads == 0
+    assert config.num_attention_heads % weights.process_group.size() == 0
+
+    weight = weights.get_multi_weights_col(
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        quantize=config.quantize,
+        dim=0,
+    )
+
+    if config.quantize not in ["gptq", "awq"]:
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+    head_size = config.hidden_size // config.num_attention_heads
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=None, quantize=config.quantize)
+    )
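+
+
+# Hedged illustration (unused by the model) of the memory layout the loaders
+# above produce: q, k and v are concatenated along dim 0, so a single matmul
+# yields [q | k | v] per token, which attention then splits apart. All
+# arguments are placeholders supplied by the caller.
+def _split_fused_qkv_example(
+    qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_size: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # qkv: [num_tokens, (num_heads + 2 * num_kv_heads) * head_size]
+    query, kv = qkv.split(
+        [num_heads * head_size, 2 * num_kv_heads * head_size], dim=1
+    )
+    # query: [num_tokens, num_heads, head_size]
+    # kv:    [num_tokens, 2, num_kv_heads, head_size]; dim 1 holds key/value
+    return query.view(-1, num_heads, head_size), kv.view(
+        -1, 2, num_kv_heads, head_size
+    )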
+
+
+class FlashPhiAttention(torch.nn.Module):
+    def __init__(
+        self,
+        prefix: str,
+        config,
+        weights,
+    ):
+        super().__init__()
+        self.num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.num_heads
+
+        # phi applies rotary embeddings to only the first `rotary_dim`
+        # channels of each head (partial rotary)
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.rotary_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
+        self.softmax_scale = self.head_size**-0.5
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+            )
+
+        # shard the heads across tensor-parallel ranks
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+
+        self.query_key_value = load_attention(config, prefix, weights)
+
+        # phi checkpoints include a bias on the output projection
+        self.dense = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.dense",
+            weights=weights,
+            bias=True,
+        )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
+
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+    ):
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+        # rotate only the first `rotary_dim` channels of each head, in place
+        self.rotary_emb(
+            query[..., : self.rotary_dim],
+            torch.select(kv, dim=1, index=0)[..., : self.rotary_dim],
+            cos,
+            sin,
+        )
+
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+
+        # output tensor
+        attn_output = torch.empty_like(query)
+
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
+
+        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class PhiMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        act = config.hidden_act
+        self.act = (
+            ACT2FN[act]
+            if "gelu" not in act
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        )
+
+        # fc1 expands hidden_size -> intermediate_size, so it must be
+        # column-parallel; fc2 reduces back and is row-parallel. phi's MLP is
+        # not gated, so there is no gate projection to fuse, and both layers
+        # carry biases.
+        self.up_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.fc1",
+            weights=weights,
+            bias=True,
+        )
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.fc2",
+            weights=weights,
+            bias=True,
+        )
+
+    def forward(self, hidden_states):
+        return self.down_proj(self.act(self.up_proj(hidden_states)))
+
+
+class FlashPhiLayer(nn.Module):
+    def __init__(self, layer_id, config, weights):
+        super().__init__()
+        prefix = f"model.layers.{layer_id}"
+        self.self_attn = FlashPhiAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+        self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+        # phi uses a standard LayerNorm (with bias), not RMSNorm
+        self.input_layernorm = FastLayerNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+    ):
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        attn_output = self.self_attn(
+            normed_hidden_states,
+            cos,
+            sin,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+        )
+
+        # Parallel residual: the MLP consumes the same normed input as
+        # attention, and both branches are summed with the residual
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return attn_output + mlp_output + res, res
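+
+
+# Hedged reference (eager PyTorch, no tensor parallelism or paging) for the
+# block above: phi uses a *parallel* residual in which attention and MLP both
+# read the same pre-normed activations, unlike the sequential residuals of
+# llama-style blocks. `attn_fn` and `mlp_fn` are stand-ins for the real
+# modules; `norm` follows the (normed, residual) tuple convention used here.
+def _parallel_residual_reference(hidden_states, residual, norm, attn_fn, mlp_fn):
+    normed, res = norm(hidden_states, residual)
+    return attn_fn(normed) + mlp_fn(normed) + res, res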
+
+
+class FlashPhiModel(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        process_group = weights.process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                FlashPhiLayer(
+                    layer_id,
+                    config,
+                    weights,
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.gradient_checkpointing = False
+
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+
+        # Compute rotary cos and sin once for this forward pass instead of
+        # re-indexing them inside every layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
+
+        residual = None
+        for i, layer in enumerate(self.layers):
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                cos,
+                sin,
+                cu_seqlen_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
+                max_s,
+            )
+
+        return hidden_states
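+
+
+# Hedged helper (unused by the server) showing how paged attention resolves a
+# logical token position to a physical cache slot for one sequence; the block
+# size of 16 is illustrative, not a value read from this codebase.
+def _example_slot_for_position(
+    block_table_row: torch.Tensor, position: int, block_size: int = 16
+) -> int:
+    # block_table_row: [max_blocks_per_seq] physical block ids for a sequence
+    physical_block = int(block_table_row[position // block_size])
+    return physical_block * block_size + position % block_size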
+
+
+class FlashPhiForCausalLM(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        self.model = FlashPhiModel(config, weights)
+
+        # TODO: prefer a tensor-parallel head; FastLinear replicates the full
+        # lm_head on every rank. phi's lm_head carries a bias.
+        self.lm_head = FastLinear.load(
+            config,
+            prefix="lm_head",
+            weights=weights,
+            bias=True,
+        )
+
+        # final layernorm, applied after the last decoder layer and before
+        # the lm_head
+        self.ln = FastLayerNorm.load(
+            prefix="model.final_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],  # set only during prefill
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        # paged attention inputs
+        block_tables: torch.Tensor,  # indexes into the physical cache blocks
+        slots: torch.Tensor,  # flat cache slots for the tokens in this batch
+        input_lengths: torch.Tensor,
+        max_s: int,  # max sequence length in the batch, used to size kernels
+        # small optimization: only compute logits for the rows we care about
+        lm_head_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+
+        normed_hidden_states, _ = self.ln(hidden_states, None)
+        return self.lm_head(normed_hidden_states)
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
new file mode 100644
index 00000000..1c49f2a9
--- /dev/null
+++ b/server/text_generation_server/models/flash_phi.py
@@ -0,0 +1,102 @@
+import torch
+import torch.distributed
+
+from opentelemetry import trace
+from transformers import AutoTokenizer
+from typing import Optional
+
+from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.custom_modeling.flash_phi_modeling import (
+    FlashPhiForCausalLM,
+    PhiConfig,
+)
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
+
+tracer = trace.get_tracer(__name__)
+
+
+class FlashPhi(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            raise NotImplementedError("FlashPhi is only available on GPU")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = PhiConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize in ["gptq", "awq"]:
+            weights._set_gptq_params(model_id, revision)
+
+        model = FlashPhiForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            import os
+            from pathlib import Path
+
+            is_local_model = (
+                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
+
+            if not is_local_model:
+                medusa_config = hf_hub_download(
+                    use_medusa, revision=revision, filename="config.json"
+                )
+                medusa_head = hf_hub_download(
+                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                )
+            else:
+                medusa_config = str(Path(use_medusa) / "config.json")
+                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+
+            with open(medusa_config, "r") as f:
+                medusa_config = json.load(f)
+            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
+            weights = Weights(
+                [medusa_sf], device, dtype, process_group=self.process_group
+            )
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(medusa_config, weights, lm_head)
+
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashPhi, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_key_value_heads,
+            head_size=model.model.head_size,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+        )
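+
+# Hedged usage sketch (comments only, never executed): `get_model` in
+# models/__init__.py dispatches `model_type == "phi"` to this class, but it
+# can also be constructed directly. The model id below is a placeholder for
+# any phi-architecture checkpoint:
+#
+#     model = FlashPhi(
+#         "microsoft/phi-2",
+#         quantize=None,
+#         dtype=torch.float16,
+#     )
+#     # generation is then driven by the FlashCausalLM generate_token loop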