mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
revert + style + minor improvements
This commit is contained in:
parent
a2fe842795
commit
6e0f37c0ca
@ -963,7 +963,9 @@ def quantize(
|
||||
|
||||
max_shard_size = "10GB"
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size,
|
||||
state_dict,
|
||||
filename_pattern="model.safetensors",
|
||||
max_shard_size=max_shard_size,
|
||||
)
|
||||
index = None
|
||||
if state_dict_split.is_sharded:
|
||||
|
@ -21,7 +21,9 @@ import transformers
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
|
||||
from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
|
||||
from text_generation_server.models.transformers_flash_causal_lm import (
|
||||
TransformersFlashCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||
MPTForCausalLM,
|
||||
@ -377,11 +379,19 @@ def get_model(
|
||||
transformers_causal_lm_class = CausalLM
|
||||
|
||||
# Fast transformers path
|
||||
transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
|
||||
if transformers_model_class.is_backend_compatible():
|
||||
transformers_model_class = getattr(
|
||||
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
||||
)
|
||||
if transformers_model_class._supports_flex_attn:
|
||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||
if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
|
||||
raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")
|
||||
if (
|
||||
not FLASH_ATTENTION
|
||||
and lora_adapter_ids is not None
|
||||
and len(lora_adapter_ids) > 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Transformers backend AutoModel do not support `lora_adapter_ids`."
|
||||
)
|
||||
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
if quantization_config is None:
|
||||
|
@ -67,4 +67,3 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]):
|
||||
def get_adapter_to_index():
|
||||
global ADAPTER_TO_INDEX
|
||||
return ADAPTER_TO_INDEX
|
||||
|
||||
|
@ -22,7 +22,7 @@ def tgi_flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], # This needs to stay as it is passed as a positional arg in transformers
|
||||
attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers
|
||||
kv_cache: List[KVCache],
|
||||
kv_head_mapping: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
@ -30,6 +30,7 @@ def tgi_flash_attention_forward(
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
max_s: int,
|
||||
kv_scales: KVScales,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
softcap: Optional[float] = None,
|
||||
@ -37,20 +38,13 @@ def tgi_flash_attention_forward(
|
||||
):
|
||||
|
||||
kv_cache = kv_cache[module.layer_idx]
|
||||
# This means no scale
|
||||
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
||||
|
||||
query_states = query_states.transpose(1, 2).squeeze(dim=0)
|
||||
key_states = key_states.transpose(1, 2).squeeze(dim=0)
|
||||
value_states = value_states.transpose(1, 2).squeeze(dim=0)
|
||||
|
||||
# Take care of updating the cache in-place
|
||||
kv_cache.store(
|
||||
key=key_states,
|
||||
value=value_states,
|
||||
slots=slots,
|
||||
kv_scales=kv_scales
|
||||
)
|
||||
kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
|
||||
|
||||
_, num_heads, head_dim = query_states.shape
|
||||
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
|
||||
@ -110,14 +104,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
if speculator:
|
||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||
|
||||
device_count = 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
device_count = torch.cuda.device_count()
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device("xpu")
|
||||
device_count = torch.xpu.device_count()
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
@ -156,7 +147,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
|
||||
self.num_layers = model.config.num_hidden_layers
|
||||
self.num_heads = model.config.num_attention_heads // self.process_group.size()
|
||||
self.num_kv_heads = model.config.num_key_value_heads
|
||||
@ -190,9 +180,16 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
# Those will never change and will be used in the forwards
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_kv_heads, dtype=torch.int32, device=device
|
||||
).repeat_interleave(self.num_groups)
|
||||
# This means no scale
|
||||
self.kv_scales = KVScales(
|
||||
torch.tensor(1.0, device=device),
|
||||
torch.tensor(1.0, device=device),
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
# Skip FlashCausalLM init.
|
||||
@ -242,21 +239,17 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
prefill_cache_indices = None, # not used, but passed to match original signature
|
||||
adapter_data = None, # not supported, but passed to match original signature
|
||||
prefill_cache_indices=None, # not used, but passed to match original signature
|
||||
adapter_data=None, # not supported, but passed to match original signature
|
||||
):
|
||||
# Transformers does not support None as a default
|
||||
if lm_head_indices is None:
|
||||
lm_head_indices = 0
|
||||
|
||||
# Equivalent tp `self.model.forward`, see the monkey patch in __init__
|
||||
logits = self.model.original_forward(
|
||||
hidden_states = self.model.model.forward(
|
||||
input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers
|
||||
position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers
|
||||
position_ids=position_ids.unsqueeze(
|
||||
0
|
||||
), # expand dim to easily fit transformers
|
||||
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||
return_dict=True,
|
||||
num_logits_to_keep=lm_head_indices,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
@ -264,6 +257,14 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
).logits.squeeze(dim=0)
|
||||
kv_scales=self.kv_scales,
|
||||
)[0].squeeze(dim=0)
|
||||
|
||||
return logits, None
|
||||
# And compute logits from the lm_head, slicing correctly the indices
|
||||
# NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
|
||||
# To update with full Transformers support asap
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.model.lm_head.forward(hidden_states)
|
||||
|
||||
return logits, None
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
from typing import List, Optional, DefaultDict
|
||||
|
||||
from loguru import logger
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
|
Loading…
Reference in New Issue
Block a user