Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-17 06:42:08 +00:00)
cleanup comment

parent ac6fc70c75
commit b41faae318
@@ -6,7 +6,6 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}

 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
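For readers skimming the hunk above: these flags follow a simple truthy-env-var convention. A minimal sketch of that parsing pattern, using a hypothetical FOO_FLAG rather than the real variables:

import os

# Illustrative only: FOO_FLAG is a made-up variable mirroring the
# REQUEST_LOGPROBS / PREFIX_CACHING parsing pattern shown in the hunk above.
FOO_FLAG = os.getenv("FOO_FLAG", "0").lower() in {"1", "true"}

print(FOO_FLAG)  # False unless FOO_FLAG is set to "1" or "true" (case-insensitive)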
@@ -36,12 +36,11 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs, # This is needed to "absorb" other args passed by Transformers modeling
 ):
-    # from pdb import set_trace; set_trace()
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
-    # from pdb import set_trace; set_trace()

     # Take care of updating the cache in-place
     kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
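The transpose/squeeze dance in this hunk converts from the [batch, num_heads, seq_len, head_dim] layout that Transformers hands to attention hooks into the flattened [tokens, num_heads, head_dim] layout that the cache store and kernels above consume. A shape-only sketch with made-up sizes:

import torch

# Assumed toy sizes, not from the commit. batch == 1 because TGI flattens all
# sequences into a single ragged batch before calling the model.
batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# transpose(1, 2) -> [1, seq_len, num_heads, head_dim]; squeeze(dim=0) drops the
# unit batch dimension, giving the [tokens, num_heads, head_dim] layout.
query_states = query_states.transpose(1, 2).squeeze(dim=0)
assert query_states.shape == (seq_len, num_heads, head_dim)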
@@ -49,7 +48,6 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
-    # from pdb import set_trace; set_trace()

     if cu_seqlen_prefill is not None:
         attn_output = attention(
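The two defaulting lines above follow common conventions: scaled dot-product attention falls back to 1/sqrt(head_dim) when no scale is provided, and -1 acts as the "no sliding window" sentinel. A toy check with an assumed head_dim of 64:

import math

# Assumed inputs for illustration only.
head_dim = 64
softmax_scale = None
sliding_window = None

softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
sliding_window = -1 if sliding_window is None else sliding_window
print(softmax_scale, sliding_window)  # 0.125 -1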
@@ -108,7 +106,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         tokenizer_class=AutoTokenizer,
         kv_cache_dtype: Optional[torch.dtype] = None,
     ):
-        # # from pdb import set_trace; set_trace()
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()

@@ -266,7 +263,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         prefill_cache_indices=None, # not used, but passed to match original signature
         adapter_data=None, # not supported, but passed to match original signature
     ):
-        # from pdb import set_trace; set_trace()
         # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
         logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

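On the logits_to_keep translation in the last hunk: in Transformers, an integer 0 means "compute logits for every position", while an index tensor keeps only those rows, which is what lm_head_indices provides. A rough standalone illustration with toy tensors (not TGI's actual model code):

import torch

# Toy stand-ins: hidden states for 10 tokens and a small lm_head.
hidden_states = torch.randn(10, 32)   # [tokens, hidden_size]
lm_head = torch.nn.Linear(32, 100)    # hidden_size -> vocab_size

lm_head_indices = torch.tensor([3, 9])  # e.g. last token of each of two sequences
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

if isinstance(logits_to_keep, torch.Tensor):
    logits = lm_head(hidden_states[logits_to_keep])  # [2, vocab_size]
else:
    logits = lm_head(hidden_states)                  # [10, vocab_size]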