improve type hints + required args

Cyril Vallez 2025-01-17 13:09:52 +00:00
parent 32488c1a11
commit ac62bd1572

@@ -1,6 +1,6 @@
 import math
 import sys
-from typing import Optional, Tuple, Dict, Any
+from typing import List, Optional, Tuple, Dict, Any
 import torch
 from opentelemetry import trace
@@ -24,7 +24,7 @@ from text_generation_server.utils.import_utils import (
 )
 from text_generation_server.adapters import AdapterBatchData
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
-from text_generation_server.layers.attention.kv_cache import KVScales
+from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
@@ -37,14 +37,20 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: torch.Tensor,
+    kv_cache: List[KVCache],
+    kv_head_mapping: torch.Tensor,
+    slots: torch.Tensor,
+    cu_seqlen_prefill: Optional[torch.Tensor],
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    max_s: int,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **kwargs,
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
-    kv_cache = kwargs["kv_cache"][module.layer_idx]
+    kv_cache = kv_cache[module.layer_idx]

     # This means no scale
     kv_scales = KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
@@ -56,7 +62,7 @@ def tgi_flash_attention_forward(
     kv_cache.store(
         key=key_states,
         value=value_states,
-        slots=kwargs["slots"],
+        slots=slots,
         kv_scales=kv_scales
     )
@@ -64,15 +70,15 @@ def tgi_flash_attention_forward(
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window

-    if kwargs["cu_seqlen_prefill"] is not None:
+    if cu_seqlen_prefill is not None:
         attn_output = attention(
             query=query_states,
             key=key_states,
             value=value_states,
             kv_cache=kv_cache,
             kv_scales=kv_scales,
-            seqlen=kwargs["seqlen"],
-            block_tables=kwargs["block_tables"],
+            seqlen=seqlen,
+            block_tables=block_tables,
             softmax_scale=softmax_scale,
             window_size_left=sliding_window,
             softcap=softcap,
@@ -81,11 +87,11 @@ def tgi_flash_attention_forward(
         attn_output = paged_attention(
             query_states,
             kv_cache,
-            kwargs["kv_head_mapping"],
+            kv_head_mapping,
             softmax_scale,
-            kwargs["block_tables"],
-            kwargs["seqlen"],
-            kwargs["max_s"],
+            block_tables,
+            seqlen,
+            max_s,
             kv_scales=kv_scales,
             softcap=softcap,
         )
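In the hunks above, everything that was previously fetched from the kwargs dict (kv_cache, slots, cu_seqlen_prefill, seqlen, block_tables, max_s, kv_head_mapping) becomes an explicit, typed parameter; attention_mask drops out of the signature and, like any other argument the Transformers modeling code forwards, is now absorbed by the trailing **kwargs. A minimal, self-contained sketch of that calling convention (stand-in names and toy values, not the real TGI kernels):

from typing import List, Optional

class _Layer:
    """Stand-in for the attention module; only layer_idx matters here."""
    layer_idx = 0

def sketch_attention_forward(
    module,                                   # attention layer, provides layer_idx
    query_states,
    key_states,
    value_states,
    kv_cache: List,                           # required, explicitly named (was kwargs["kv_cache"])
    kv_head_mapping=None,
    slots=None,
    cu_seqlen_prefill: Optional[object] = None,
    seqlen=None,
    block_tables=None,
    max_s: int = 0,
    softmax_scale: Optional[float] = None,
    **kwargs,                                 # absorbs anything else, e.g. attention_mask
):
    layer_cache = kv_cache[module.layer_idx]  # direct indexing instead of a dict lookup
    mode = "prefill" if cu_seqlen_prefill is not None else "decode"
    return mode, layer_cache, sorted(kwargs)

# The caller mirrors _model_forward: everything is passed by keyword, and arguments
# the signature does not name (here attention_mask) land in **kwargs.
mode, cache, extras = sketch_attention_forward(
    module=_Layer(),
    query_states=None, key_states=None, value_states=None,
    kv_cache=["layer-0-cache"], kv_head_mapping=None, slots=None,
    cu_seqlen_prefill=None, seqlen=None, block_tables=None, max_s=128,
    attention_mask=None,
)
print(mode, cache, extras)  # decode layer-0-cache ['attention_mask']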
@@ -145,16 +151,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=("auto" if device_count > 1 else None),
+            device_map="auto",
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
             attn_implementation="tgi",
             tp_plan="auto" if world_size > 1 else None,
         )

-        if device_count == 1 and quantize != "bitsandbytes":
-            model = model.to(device)
-
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
@@ -237,23 +240,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

     def _model_forward(
         self,
-        input_ids,
-        position_ids,
-        cu_seqlen_prefill,
-        kv_cache,
-        block_tables,
-        slots,
-        seqlen,
-        max_s,
-        prefill_cache_indices,
-        lm_head_indices,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: torch.Tensor,
     ):
         hidden_states = self.model.model.forward(
-            input_ids=input_ids[None, ...],  # expand dim to easily fit transformers
-            position_ids=position_ids[None, ...],  # expand dim to easily fit transformers
+            input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
+            position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
@@ -263,7 +264,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=prefill_cache_indices,
             kv_head_mapping=self.kv_head_mapping,
         )[0].squeeze(dim=0)
         # And compute logits from the lm_head, slicing correctly the indices
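The comments in this hunk point at the shape contract: TGI works with flattened, batch-free token tensors, while the Transformers forward expects a leading batch dimension, so it is added with unsqueeze(0) on the way in and stripped with squeeze(dim=0) on the way out. A small shape sketch with toy sizes (illustrative only, not TGI code):

import torch

num_tokens, hidden_size = 7, 16
input_ids = torch.arange(num_tokens)                      # [num_tokens], no batch dim
batched_ids = input_ids.unsqueeze(0)                      # [1, num_tokens] for model.forward
hidden_states = torch.randn(1, num_tokens, hidden_size)   # what the Transformers model returns
flat_states = hidden_states.squeeze(dim=0)                # back to [num_tokens, hidden_size]
print(batched_ids.shape, flat_states.shape)               # torch.Size([1, 7]) torch.Size([7, 16])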
@@ -335,7 +335,6 @@
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
@@ -496,7 +495,6 @@
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         del seqlen
@@ -520,7 +518,6 @@
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         self.cuda_graphs[bs]["logits"] = logits
@@ -561,5 +558,4 @@
             slots=slots,
             max_s=max_s,
             lm_head_indices=None,
-            prefill_cache_indices=None,
         )