From 6e0f37c0cacd8a54701ec6509fc99cc9e7505f5b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 20 Jan 2025 15:13:24 +0100 Subject: [PATCH] revert + style + minor improvements --- .../layers/gptq/quantize.py | 4 +- .../text_generation_server/models/__init__.py | 20 ++++++-- .../text_generation_server/models/globals.py | 1 - .../models/transformers_flash_causal_lm.py | 51 ++++++++++--------- .../utils/logits_process.py | 2 +- 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 41dc867d..aa664ea6 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -963,7 +963,9 @@ def quantize( max_shard_size = "10GB" state_dict_split = split_torch_state_dict_into_shards( - state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size, + state_dict, + filename_pattern="model.safetensors", + max_shard_size=max_shard_size, ) index = None if state_dict_split.is_sharded: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4b506532..5069fff6 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -21,7 +21,9 @@ import transformers from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast -from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM +from text_generation_server.models.transformers_flash_causal_lm import ( + TransformersFlashCausalLM, +) from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, @@ -377,11 +379,19 @@ def get_model( transformers_causal_lm_class = CausalLM # Fast transformers path - transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]) - if transformers_model_class.is_backend_compatible(): + transformers_model_class = getattr( + transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + ) + if transformers_model_class._supports_flex_attn: transformers_causal_lm_class = TransformersFlashCausalLM - if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0: - raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.") + if ( + not FLASH_ATTENTION + and lora_adapter_ids is not None + and len(lora_adapter_ids) > 0 + ): + raise ValueError( + "Transformers backend AutoModel do not support `lora_adapter_ids`." 
+ ) quantization_config = config_dict.get("quantization_config", None) if quantization_config is None: diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8a33fb32..8d988ad5 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -67,4 +67,3 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]): def get_adapter_to_index(): global ADAPTER_TO_INDEX return ADAPTER_TO_INDEX - diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index cfc9c861..30ea4c8f 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -22,7 +22,7 @@ def tgi_flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], # This needs to stay as it is passed as a positional arg in transformers + attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers kv_cache: List[KVCache], kv_head_mapping: torch.Tensor, slots: torch.Tensor, @@ -30,6 +30,7 @@ def tgi_flash_attention_forward( seqlen: Seqlen, block_tables: torch.Tensor, max_s: int, + kv_scales: KVScales, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, @@ -37,20 +38,13 @@ def tgi_flash_attention_forward( ): kv_cache = kv_cache[module.layer_idx] - # This means no scale - kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) query_states = query_states.transpose(1, 2).squeeze(dim=0) key_states = key_states.transpose(1, 2).squeeze(dim=0) value_states = value_states.transpose(1, 2).squeeze(dim=0) # Take care of updating the cache in-place - kv_cache.store( - key=key_states, - value=value_states, - slots=slots, - kv_scales=kv_scales - ) + kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales) _, num_heads, head_dim = query_states.shape softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale @@ -110,14 +104,11 @@ class TransformersFlashCausalLM(FlashCausalLM): if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") - device_count = 0 if torch.cuda.is_available(): device = torch.device("cuda:0") - device_count = torch.cuda.device_count() dtype = torch.float16 if dtype is None else dtype elif hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") - device_count = torch.xpu.device_count() dtype = torch.float16 if dtype is None else dtype else: if quantize: @@ -156,7 +147,6 @@ class TransformersFlashCausalLM(FlashCausalLM): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.num_layers = model.config.num_hidden_layers self.num_heads = model.config.num_attention_heads // self.process_group.size() self.num_kv_heads = model.config.num_key_value_heads @@ -190,9 +180,16 @@ class TransformersFlashCausalLM(FlashCausalLM): ) self.num_groups = self.num_heads // self.num_kv_heads + + # Those will never change and will be used in the forwards self.kv_head_mapping = torch.arange( 0, self.num_kv_heads, dtype=torch.int32, device=device ).repeat_interleave(self.num_groups) + # This means no scale + self.kv_scales = KVScales( + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device), + ) 
torch.distributed.barrier(group=self.process_group) # Skip FlashCausalLM init. @@ -242,21 +239,17 @@ class TransformersFlashCausalLM(FlashCausalLM): seqlen: Seqlen, max_s: int, lm_head_indices: Optional[torch.Tensor], - prefill_cache_indices = None, # not used, but passed to match original signature - adapter_data = None, # not supported, but passed to match original signature + prefill_cache_indices=None, # not used, but passed to match original signature + adapter_data=None, # not supported, but passed to match original signature ): - # Transformers does not support None as a default - if lm_head_indices is None: - lm_head_indices = 0 - - # Equivalent tp `self.model.forward`, see the monkey patch in __init__ - logits = self.model.original_forward( + hidden_states = self.model.model.forward( input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers - position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers + position_ids=position_ids.unsqueeze( + 0 + ), # expand dim to easily fit transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True, - num_logits_to_keep=lm_head_indices, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -264,6 +257,14 @@ class TransformersFlashCausalLM(FlashCausalLM): seqlen=seqlen, max_s=max_s, kv_head_mapping=self.kv_head_mapping, - ).logits.squeeze(dim=0) + kv_scales=self.kv_scales, + )[0].squeeze(dim=0) - return logits, None \ No newline at end of file + # And compute logits from the lm_head, slicing correctly the indices + # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules + # To update with full Transformers support asap + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.model.lm_head.forward(hidden_states) + + return logits, None diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 5066de53..64a285b9 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -5,7 +5,7 @@ import torch from typing import List, Optional, DefaultDict from loguru import logger -from typing import Dict, Union +from typing import Dict from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.guide import RegexGuide
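
Note (illustration only, not part of the patch): the reworked `forward` in `transformers_flash_causal_lm.py` now calls the Transformers model body to get hidden states, slices them with `lm_head_indices` when present, and only then applies `lm_head`. A minimal sketch of why that is safe, using a made-up `[total_tokens, hidden_size]` hidden-state tensor and a plain `nn.Linear` as a stand-in for the real head: slicing before the projection produces the same rows as projecting every token and selecting afterwards, while skipping the `[total_tokens, vocab_size]` matmul for prefill tokens whose logits are never sampled.

    import torch
    from torch import nn

    # Hypothetical shapes, for illustration only.
    total_tokens, hidden_size, vocab_size = 10, 16, 32
    hidden_states = torch.randn(total_tokens, hidden_size)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # Last-token index of each sequence in the flattened prefill batch (made up here).
    lm_head_indices = torch.tensor([3, 9])

    # What the patched forward does: slice first, then project.
    logits_sliced = lm_head(hidden_states[lm_head_indices])

    # Equivalent result, but projects every token before discarding most rows.
    logits_full = lm_head(hidden_states)[lm_head_indices]

    assert torch.allclose(logits_sliced, logits_full)

The same reasoning explains moving `self.kv_scales = KVScales(torch.tensor(1.0), torch.tensor(1.0))` into `__init__`: a scale of 1.0 means the KV cache is stored unscaled, and since that never changes per forward pass it can be built once instead of on every attention call.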