revert + style + minor improvements

Cyril Vallez · 2025-01-20 15:13:24 +01:00
parent a2fe842795
commit 6e0f37c0ca
5 changed files with 45 additions and 33 deletions


@@ -963,7 +963,9 @@ def quantize(
     max_shard_size = "10GB"
     state_dict_split = split_torch_state_dict_into_shards(
-        state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size,
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
     )
     index = None
     if state_dict_split.is_sharded:
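
For context on the call being re-wrapped here: split_torch_state_dict_into_shards comes from huggingface_hub and only plans the split; the caller still writes the shards and the index. A minimal, hedged sketch of how such a plan is typically consumed (a standalone example, not the surrounding quantize code; tensor names and the shard-size value are illustrative):

    # Hedged sketch: consuming a StateDictSplit from huggingface_hub.
    # Assumes huggingface_hub and safetensors are installed; values are illustrative.
    import torch
    from huggingface_hub import split_torch_state_dict_into_shards
    from safetensors.torch import save_file

    state_dict = {"w1": torch.zeros(4, 4), "w2": torch.zeros(4, 4)}
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size="10GB"
    )
    # Write each planned shard, then build the safetensors index if needed.
    for filename, tensor_names in split.filename_to_tensors.items():
        save_file({name: state_dict[name] for name in tensor_names}, filename)
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}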


@@ -21,7 +21,9 @@ import transformers
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
-from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
+from text_generation_server.models.transformers_flash_causal_lm import (
+    TransformersFlashCausalLM,
+)
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -377,11 +379,19 @@ def get_model(
     transformers_causal_lm_class = CausalLM
     # Fast transformers path
-    transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-    if transformers_model_class.is_backend_compatible():
+    transformers_model_class = getattr(
+        transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+    )
+    if transformers_model_class._supports_flex_attn:
         transformers_causal_lm_class = TransformersFlashCausalLM
-        if not FLASH_ATTENTION and lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-            raise ValueError("Transformers backend AutoModel do not support `lora_adapter_ids`.")
+        if (
+            not FLASH_ATTENTION
+            and lora_adapter_ids is not None
+            and len(lora_adapter_ids) > 0
+        ):
+            raise ValueError(
+                "Transformers backend AutoModel do not support `lora_adapter_ids`."
+            )

     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
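
Beyond the black re-wrapping, the functional change in this hunk is the capability check: it goes back to Transformers' own _supports_flex_attn class flag instead of the custom is_backend_compatible() helper. A small, hedged illustration of that check in isolation (assumes a recent transformers release that defines _supports_flex_attn on model classes; the model_type value is just an example):

    # Hedged sketch of the capability check, outside of get_model().
    import transformers
    from transformers.models.auto import modeling_auto

    model_type = "llama"  # illustrative key of MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    model_class = getattr(
        transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
    )
    # True when the class implements flex attention and can take the flash path.
    print(model_class.__name__, getattr(model_class, "_supports_flex_attn", False))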


@@ -67,4 +67,3 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]):
 def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
-


@@ -22,7 +22,7 @@ def tgi_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],  # This needs to stay as it is passed as a positional arg in transformers
+    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
     kv_cache: List[KVCache],
     kv_head_mapping: torch.Tensor,
     slots: torch.Tensor,
@@ -30,6 +30,7 @@ def tgi_flash_attention_forward(
     seqlen: Seqlen,
     block_tables: torch.Tensor,
     max_s: int,
+    kv_scales: KVScales,
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
@@ -37,20 +38,13 @@ def tgi_flash_attention_forward(
 ):
     kv_cache = kv_cache[module.layer_idx]
-    # This means no scale
-    kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))

     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)

     # Take care of updating the cache in-place
-    kv_cache.store(
-        key=key_states,
-        value=value_states,
-        slots=slots,
-        kv_scales=kv_scales
-    )
+    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)

     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
@@ -110,14 +104,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")

-        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda:0")
-            device_count = torch.cuda.device_count()
             dtype = torch.float16 if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
-            device_count = torch.xpu.device_count()
             dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -156,7 +147,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         else:
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})

-        self.num_layers = model.config.num_hidden_layers
         self.num_heads = model.config.num_attention_heads // self.process_group.size()
         self.num_kv_heads = model.config.num_key_value_heads
@@ -190,9 +180,16 @@ class TransformersFlashCausalLM(FlashCausalLM):
         )
         self.num_groups = self.num_heads // self.num_kv_heads
+
+        # Those will never change and will be used in the forwards
         self.kv_head_mapping = torch.arange(
             0, self.num_kv_heads, dtype=torch.int32, device=device
         ).repeat_interleave(self.num_groups)
+        # This means no scale
+        self.kv_scales = KVScales(
+            torch.tensor(1.0, device=device),
+            torch.tensor(1.0, device=device),
+        )

         torch.distributed.barrier(group=self.process_group)
         # Skip FlashCausalLM init.
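
Read together with the two attention-forward hunks above, the intent here is that the identity KV scales are created once at init and passed down as kv_scales, instead of allocating two fresh scalar tensors on every attention call of every layer. A hedged, self-contained sketch of that trade-off (the KVScales defined below is a stand-in for TGI's real class; its field names are an assumption):

    # Stand-in sketch, not TGI's implementation.
    from dataclasses import dataclass
    import torch

    @dataclass
    class KVScales:  # stand-in; field names are assumptions
        key_scale: torch.Tensor
        value_scale: torch.Tensor

    device = "cpu"  # illustrative

    # Before: rebuilt inside the attention forward, so every layer of every step
    # allocated two new tensors just to say "no scaling".
    def identity_scales() -> KVScales:
        return KVScales(torch.tensor(1.0, device=device), torch.tensor(1.0, device=device))

    # After: built once in __init__ and reused for every kv_cache.store() call.
    shared_scales = identity_scales()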
@@ -242,21 +239,17 @@ class TransformersFlashCausalLM(FlashCausalLM):
         seqlen: Seqlen,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor],
-        prefill_cache_indices = None,  # not used, but passed to match original signature
-        adapter_data = None,  # not supported, but passed to match original signature
+        prefill_cache_indices=None,  # not used, but passed to match original signature
+        adapter_data=None,  # not supported, but passed to match original signature
     ):
-        # Transformers does not support None as a default
-        if lm_head_indices is None:
-            lm_head_indices = 0
-        # Equivalent tp `self.model.forward`, see the monkey patch in __init__
-        logits = self.model.original_forward(
+        hidden_states = self.model.model.forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
-            position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
+            position_ids=position_ids.unsqueeze(
+                0
+            ),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             return_dict=True,
-            num_logits_to_keep=lm_head_indices,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
             block_tables=block_tables,
@@ -264,6 +257,14 @@ class TransformersFlashCausalLM(FlashCausalLM):
             seqlen=seqlen,
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits.squeeze(dim=0)
+            kv_scales=self.kv_scales,
+        )[0].squeeze(dim=0)

-        return logits, None
+        # And compute logits from the lm_head, slicing correctly the indices
+        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
+        # To update with full Transformers support asap
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.model.lm_head.forward(hidden_states)
+        return logits, None
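
The reworked forward above calls the transformer trunk (self.model.model) and then applies lm_head itself, slicing the hidden states at lm_head_indices first so the vocabulary projection only runs on the rows that actually need logits (typically the last token of each sequence during prefill). A hedged illustration of that slice-then-project pattern with made-up sizes:

    # Hedged illustration; hidden/vocab sizes and indices are made up.
    import torch

    hidden_size, vocab_size = 4096, 32000
    hidden_states = torch.randn(10, hidden_size)  # one row per token in the batch
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head_indices = torch.tensor([3, 9])        # e.g. last token of two sequences

    # Slicing first projects 2 rows instead of 10.
    logits = lm_head(hidden_states[lm_head_indices])
    print(logits.shape)  # torch.Size([2, 32000])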


@@ -5,7 +5,7 @@ import torch
 from typing import List, Optional, DefaultDict
 from loguru import logger
-from typing import Dict, Union
+from typing import Dict
 from text_generation_server.pb.generate_pb2 import GrammarType
 from outlines.fsm.guide import RegexGuide