Mirror of https://github.com/huggingface/text-generation-inference.git

Commit ac62bd1572: improve type hints + required args
Parent: 32488c1a11

```diff
@@ -1,6 +1,6 @@
 import math
 import sys
-from typing import Optional, Tuple, Dict, Any
+from typing import List, Optional, Tuple, Dict, Any
 
 import torch
 from opentelemetry import trace
@@ -24,7 +24,7 @@ from text_generation_server.utils.import_utils import (
 )
 from text_generation_server.adapters import AdapterBatchData
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
-from text_generation_server.layers.attention.kv_cache import KVScales
+from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
 
@@ -37,14 +37,20 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: torch.Tensor,
+    kv_cache: List[KVCache],
+    kv_head_mapping: torch.Tensor,
+    slots: torch.Tensor,
+    cu_seqlen_prefill: Optional[torch.Tensor],
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    max_s: int,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **kwargs,
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
 
-    kv_cache = kwargs["kv_cache"][module.layer_idx]
+    kv_cache = kv_cache[module.layer_idx]
     # This means no scale
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
@@ -56,7 +62,7 @@ def tgi_flash_attention_forward(
     kv_cache.store(
         key=key_states,
         value=value_states,
-        slots=kwargs["slots"],
+        slots=slots,
         kv_scales=kv_scales
     )
 
@@ -64,15 +70,15 @@ def tgi_flash_attention_forward(
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
 
-    if kwargs["cu_seqlen_prefill"] is not None:
+    if cu_seqlen_prefill is not None:
         attn_output = attention(
             query=query_states,
             key=key_states,
             value=value_states,
             kv_cache=kv_cache,
             kv_scales=kv_scales,
-            seqlen=kwargs["seqlen"],
-            block_tables=kwargs["block_tables"],
+            seqlen=seqlen,
+            block_tables=block_tables,
             softmax_scale=softmax_scale,
             window_size_left=sliding_window,
             softcap=softcap,
@@ -81,11 +87,11 @@ def tgi_flash_attention_forward(
         attn_output = paged_attention(
             query_states,
             kv_cache,
-            kwargs["kv_head_mapping"],
+            kv_head_mapping,
             softmax_scale,
-            kwargs["block_tables"],
-            kwargs["seqlen"],
-            kwargs["max_s"],
+            block_tables,
+            seqlen,
+            max_s,
             kv_scales=kv_scales,
             softcap=softcap,
         )
```
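
The heart of these hunks is replacing `kwargs["..."]` string lookups with named, required parameters, while keeping `**kwargs` to absorb whatever else the Transformers attention dispatch passes along. The toy sketch below is not TGI code; `old_style_forward`, `new_style_forward`, and the argument values are illustrative only. It shows that both styles accept the same call, but the new one fails fast and is visible to type checkers.

```python
def old_style_forward(query, key, value, **kwargs):
    # Every access is a string lookup; a typo or missing key only surfaces at
    # runtime, deep inside the function body, as a KeyError.
    slots = kwargs["slots"]
    cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]
    return query, key, value, slots, cu_seqlen_prefill


def new_style_forward(
    query,
    key,
    value,
    slots,              # now a required, named parameter
    cu_seqlen_prefill,  # may still be None; the body only checks `is not None`
    **kwargs,           # absorbs extra keyword args the caller passes that we do not use
):
    # Arguments are bound by name; a missing required one raises TypeError at the call.
    return query, key, value, slots, cu_seqlen_prefill


if __name__ == "__main__":
    call_kwargs = dict(slots=[0, 1], cu_seqlen_prefill=None, unused_extra="ignored")
    assert old_style_forward("q", "k", "v", **call_kwargs) == new_style_forward(
        "q", "k", "v", **call_kwargs
    )
    print("both styles accept the same call")
```

A call site that forgets one of the required tensors now raises a `TypeError` at the call instead of a `KeyError` somewhere inside the attention function.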

```diff
@@ -145,16 +151,13 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=("auto" if device_count > 1 else None),
+            device_map="auto",
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
             attn_implementation="tgi",
             tp_plan="auto" if world_size > 1 else None,
         )
 
-        if device_count == 1 and quantize != "bitsandbytes":
-            model = model.to(device)
-
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
```
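
The substantive change in this hunk is dropping the `device_count == 1` special case: with `device_map="auto"` passed unconditionally, the weights are already placed on a device at load time, so the manual `model.to(device)` branch becomes dead code. Below is a minimal standalone check of that behaviour; it is not TGI code, the tiny model id is just a convenient placeholder, and it assumes `transformers` plus `accelerate` are installed (which `device_map="auto"` requires).

```python
import torch
from transformers import AutoModelForCausalLM

# device_map="auto" lets accelerate dispatch the weights while loading,
# so no follow-up model.to(device) call is needed afterwards.
model = AutoModelForCausalLM.from_pretrained(
    "sshleifer/tiny-gpt2",   # placeholder: any small causal LM works for the check
    torch_dtype=torch.float32,
    device_map="auto",
)

# Every parameter already lives on a concrete device chosen by the dispatcher
# (cuda:0 if a GPU is visible, otherwise cpu).
print({str(p.device) for p in model.parameters()})
```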

```diff
@@ -237,23 +240,21 @@ class TransformersFlashCausalLM(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-
     def _model_forward(
         self,
-        input_ids,
-        position_ids,
-        cu_seqlen_prefill,
-        kv_cache,
-        block_tables,
-        slots,
-        seqlen,
-        max_s,
-        prefill_cache_indices,
-        lm_head_indices,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[KVCache],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        lm_head_indices: torch.Tensor,
     ):
         hidden_states = self.model.model.forward(
-            input_ids=input_ids[None, ...],  # expand dim to easily fit transformers
-            position_ids=position_ids[None, ...],  # expand dim to easily fit transformers
+            input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
+            position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
@@ -263,7 +264,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=prefill_cache_indices,
             kv_head_mapping=self.kv_head_mapping,
         )[0].squeeze(dim=0)
         # And compute logits from the lm_head, slicing correctly the indices
```
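
The `_model_forward` hunks are almost entirely annotations and argument pruning; the only behavioural-looking change in the body swaps `[None, ...]` indexing for `unsqueeze(0)`, which is equivalent. A quick standalone check (not TGI code):

```python
import torch

x = torch.arange(6)                                  # stand-in for the flattened input_ids
assert torch.equal(x[None, ...], x.unsqueeze(0))     # identical (1, 6) tensor either way
assert x.unsqueeze(0).shape == (1, 6)                # the leading batch dim transformers expects
assert x.unsqueeze(0).squeeze(dim=0).shape == (6,)   # squeezed back, as after the forward call
```

`unsqueeze(0)` simply states the intent, adding a leading batch dimension of size 1, more directly than ellipsis indexing.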

```diff
@@ -335,7 +335,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
@@ -496,7 +495,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         del seqlen
@@ -520,7 +518,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             seqlen=seqlen,
             max_s=max_s,
-            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         self.cuda_graphs[bs]["logits"] = logits
@@ -561,5 +558,4 @@ class TransformersFlashCausalLM(FlashCausalLM):
             slots=slots,
             max_s=max_s,
             lm_head_indices=None,
-            prefill_cache_indices=None,
         )
```
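
The remaining hunks are mechanical: once `prefill_cache_indices` is gone from the `_model_forward` signature, every call site has to stop passing it, because Python rejects unexpected keyword arguments outright. A toy illustration (not TGI code; `_model_forward_new` is a made-up stand-in):

```python
def _model_forward_new(input_ids, lm_head_indices=None):
    # stand-in for the trimmed signature
    return input_ids, lm_head_indices


# A call site that still passes the removed keyword fails immediately.
try:
    _model_forward_new([1, 2, 3], lm_head_indices=None, prefill_cache_indices=None)
except TypeError as err:
    print("stale keyword rejected:", err)

# The updated call sites simply drop the argument.
print(_model_forward_new([1, 2, 3], lm_head_indices=None))
```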