From 649cb1f5f1a99049c45c03ecbca39efb66fe4133 Mon Sep 17 00:00:00 2001
From: System administrator
Date: Thu, 12 Dec 2024 14:27:07 +0000
Subject: [PATCH] runnable version

---
 .../text_generation_server/models/globals.py |  2 +-
 .../models/transformers_flash_causal_lm.py   | 87 ++++++++++++++++---
 2 files changed, 74 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 7d6639f2..7c7e026e 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -8,7 +8,7 @@ from text_generation_server.utils.log import log_master
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
-PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+PREFIX_CACHING = os.environ["USE_PREFIX_CACHING"].lower() in {
     "1",
     "true",
 }
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index de2570b0..f4f24749 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -5,12 +5,17 @@ from typing import Optional, Tuple, Dict, Any
 import torch
 from opentelemetry import trace
 from loguru import logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
 )
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
 from text_generation_server.utils.import_utils import (
     empty_cache,
     synchronize,
@@ -57,7 +62,7 @@ def _flash_attention_forward_patched(
     query_length: int,
     is_causal: bool,
     softmax_scale: Optional[float] = None,
-    sliding_window: int = -1,
+    sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
 ):
@@ -67,11 +72,11 @@
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
     # Correctly reshape the states
-    _, _, num_heads, head_dim = query_states.size()
-    _, _, num_kv_heads, _ = key_states.size()
-    query_states = query_states.view(-1, num_heads, head_dim)
-    key_states = key_states.view(-1, num_kv_heads, head_dim)
-    value_states = value_states.view(-1, num_kv_heads, head_dim)
+    _, num_heads, head_dim = query_states.size()
+    # _, num_kv_heads, _ = key_states.size()
+    # query_states = query_states.view(-1, num_heads, head_dim)
+    # key_states = key_states.view(-1, num_kv_heads, head_dim)
+    # value_states = value_states.view(-1, num_kv_heads, head_dim)
 
     # Take care of updating the cache in-place
     kv_cache.store(
@@ -82,6 +87,7 @@
     )
 
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+    sliding_window = -1 if sliding_window is None else sliding_window
 
     if kwargs["cu_seqlen_prefill"] is not None:
         attn_output = attention(
@@ -109,7 +115,8 @@
             softcap=softcap,
         )
 
-    attn_output = attn_output.view(attn_output.shape[0], -1)
+    # attn_output = attn_output.view(attn_output.shape[0], -1)
+    attn_output = attn_output.view(-1, num_heads * head_dim)
 
     return attn_output
 
@@ -122,14 +129,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
+        config_class=AutoConfig,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
+        self.quantize = quantize
+        self.process_group, rank, world_size = initialize_torch_distributed()
+
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
         device_count = 0
         if torch.cuda.is_available():
-            device = torch.device("cuda")
+            device = torch.device("cuda:0")
             device_count = torch.cuda.device_count()
             dtype = torch.float16 if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
@@ -157,6 +171,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
+            attn_implementation="flash_attention_2"
         )
         if device_count == 1 and quantize != "bitsandbytes":
             model = model.to(device)
@@ -174,10 +189,44 @@ class TransformersFlashCausalLM(FlashCausalLM):
 
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        self.num_layers = len(model.model.layers)
+        self.num_layers = model.config.num_hidden_layers
+        self.num_heads = model.config.num_attention_heads // self.process_group.size()
         self.num_kv_heads = model.config.num_key_value_heads
+        self.num_kv_heads = (
+            self.num_kv_heads // self.process_group.size()
+            if self.num_kv_heads > 1
+            else self.num_kv_heads
+        )
         self.head_size = model.config.hidden_size // model.config.num_attention_heads
 
+        self.cuda_graphs = {}
+        self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_prefill_state,
+                create_decode_state,
+                create_prefill_with_paged_kv_state,
+            )
+
+            self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )
+
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+
+        self.num_groups = self.num_heads // self.num_kv_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_kv_heads, dtype=torch.int32, device=device
+        ).repeat_interleave(self.num_groups)
+
+        torch.distributed.barrier(group=self.process_group)
         # Skip FlashCausalLM init.
         super(FlashCausalLM, self).__init__(
             model_id=model_id,
@@ -186,6 +235,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             requires_padding=False,
             dtype=dtype,
             device=device,
+            rank=rank,
+            world_size=world_size,
         )
 
     @classmethod
@@ -198,11 +249,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
+        return cls(
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
 
-    def warmup(self, batch: FlashCausalLMBatch):
+    def warmup(self, batch: FlashCausalLMBatch, max_input_tokens: Optional[int], max_total_tokens: Optional[int],):
         patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
-        super().warmup(batch)
+        return super().warmup(batch, max_input_tokens, max_total_tokens)
 
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
@@ -270,7 +328,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
-        )
+            kv_head_mapping=self.kv_head_mapping,
+        ).logits
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
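
Note on the warmup override above: it relies on patch_everywhere() (already imported in this module) to swap transformers' _flash_attention_forward for _flash_attention_forward_patched at runtime. For readers unfamiliar with that helper, the following is a minimal sketch of the kind of monkey-patching it performs; the real implementation shipped with text_generation_server may differ, so treat this purely as an illustration.

    import sys

    def patch_everywhere(attribute_name: str, new_value) -> None:
        # Overwrite `attribute_name` in every module that has already been
        # imported and defines it, so transformers' attention layers end up
        # calling the patched paged-KV kernel instead of the stock flash path.
        for module in list(sys.modules.values()):
            if module is not None and hasattr(module, attribute_name):
                setattr(module, attribute_name, new_value)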