diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 49dcac62..18ab27c2 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -1,6 +1,6 @@
 import math
 import sys
-from typing import Optional, Tuple, Dict, Any
+from typing import List, Optional, Tuple, Dict, Any
 
 import torch
 from opentelemetry import trace
@@ -24,7 +24,7 @@ from text_generation_server.utils.import_utils import (
 )
 from text_generation_server.adapters import AdapterBatchData
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
-from text_generation_server.layers.attention.kv_cache import KVScales
+from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
 
@@ -37,14 +37,20 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: torch.Tensor,
+    kv_cache: List[KVCache],
+    kv_head_mapping: torch.Tensor,
+    slots: torch.Tensor,
+    cu_seqlen_prefill: Optional[torch.Tensor],
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    max_s: int,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **kwargs,
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
-    kv_cache = kwargs["kv_cache"][module.layer_idx]
+    kv_cache = kv_cache[module.layer_idx]
 
     # This means no scale
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
@@ -56,7 +62,7 @@ def tgi_flash_attention_forward(
     kv_cache.store(
         key=key_states,
         value=value_states,
-        slots=kwargs["slots"],
+        slots=slots,
         kv_scales=kv_scales
     )
 
@@ -64,15 +70,15 @@ def tgi_flash_attention_forward(
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
 
-    if kwargs["cu_seqlen_prefill"] is not None:
+    if cu_seqlen_prefill is not None:
         attn_output = attention(
             query=query_states,
             key=key_states,
             value=value_states,
             kv_cache=kv_cache,
             kv_scales=kv_scales,
-            seqlen=kwargs["seqlen"],
-            block_tables=kwargs["block_tables"],
+            seqlen=seqlen,
+            block_tables=block_tables,
             softmax_scale=softmax_scale,
             window_size_left=sliding_window,
             softcap=softcap,
@@ -81,11 +87,11 @@ def tgi_flash_attention_forward(
         attn_output = paged_attention(
             query_states,
             kv_cache,
-            kwargs["kv_head_mapping"],
+            kv_head_mapping,
             softmax_scale,
-            kwargs["block_tables"],
-            kwargs["seqlen"],
-            kwargs["max_s"],
+            block_tables,
+            seqlen,
+            max_s,
             kv_scales=kv_scales,
             softcap=softcap,
         )
@@ -145,16 +151,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=("auto" if device_count > 1 else None),
+            device_map="auto",
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
             attn_implementation="tgi",
             tp_plan="auto" if world_size > 1 else None,
         )
 
-        if device_count == 1 and quantize != "bitsandbytes":
-            model = model.to(device)
-
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
@@ -237,23 +240,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-
     def _model_forward(
         self,
-        input_ids,
-        position_ids,
-        cu_seqlen_prefill,
-        kv_cache,
-        block_tables,
-        slots,
-        seqlen,
-        max_s,
-        prefill_cache_indices,
-        lm_head_indices,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: torch.Tensor,
     ):
         hidden_states = self.model.model.forward(
-            input_ids=input_ids[None, ...],  # expand dim to easily fit transformers
-            position_ids=position_ids[None, ...],  # expand dim to easily fit transformers
+            input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
+            position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
@@ -263,7 +264,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=prefill_cache_indices,
             kv_head_mapping=self.kv_head_mapping,
         )[0].squeeze(dim=0)
         # And compute logits from the lm_head, slicing correctly the indices
@@ -335,7 +335,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
                 slots=slots,
                 seqlen=seqlen,
                 max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
             )
             if batch.prefill_cache_indices is not None:
@@ -496,7 +495,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
-           prefill_cache_indices=None,
            lm_head_indices=None,
        )
        del seqlen
@@ -520,7 +518,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
-               prefill_cache_indices=None,
                lm_head_indices=None,
            )
            self.cuda_graphs[bs]["logits"] = logits
@@ -561,5 +558,4 @@ class TransformersFlashCausalLM(FlashCausalLM):
            slots=slots,
            max_s=max_s,
            lm_head_indices=None,
-           prefill_cache_indices=None,
        )
\ No newline at end of file