Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-09 02:42:05 +00:00)

Commit d8d09f9c7b ("initial changes"), parent e497bc09f6
server/text_generation_server/models/__init__.py

@@ -206,9 +206,15 @@ try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
     )
-except ImportError:
+    from text_generation_server.models.transformers_flash_vlm import (
+        TransformersFlashVlmCausalLM,
+    )
+except ImportError as e:
+    log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
     FLASH_TRANSFORMERS_BACKEND = False
+
+    FLASH_ATTENTION = False
 
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -1173,12 +1179,13 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         elif FLASH_TRANSFORMERS_BACKEND:
-            return TransformersFlashCausalLM.fallback(
+            return TransformersFlashVlmCausalLM.fallback(
                 model_id,
+                # AutoModelForConditionalGeneration,
                 revision,
                 quantize=quantize,
                 speculator=speculator,
-                dtype=dtype,
+                dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
@@ -1483,6 +1490,15 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        return TransformersFlashVlmCausalLM.fallback(
+            model_id,
+            # AutoModelForConditionalGeneration,
+            revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+        )
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
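Once a server picks up a Qwen2-VL checkpoint through the new TransformersFlashVlmCausalLM.fallback path added above, a quick smoke test could look like the sketch below. This is not part of the commit: the endpoint, image URL, and generation parameters are placeholders, and it assumes TGI's markdown-style image prompt format.

# Hypothetical smoke test, not part of this commit.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")  # assumed local TGI endpoint
prompt = "![](https://example.com/cat.png)Describe this image."  # placeholder image URL
print(client.text_generation(prompt, max_new_tokens=64))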
@@ -1563,23 +1579,33 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION:
-            return VlmCausalLM(
-                model_id=model_id,
-                model_class=PaliGemmaForConditionalGeneration,
-                revision=revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=dtype,
-                kv_cache_dtype=kv_cache_dtype,
-                # Works better for these models
-                default_dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                lora_adapter_ids=lora_adapter_ids,
-                batch_class=PaliGemmaBatch,
-            )
-        else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+        # return TransformersFlashVlmCausalLM.fallback(
+        #     model_id,
+        #     # AutoModelForConditionalGeneration,
+        #     revision,
+        #     quantize=quantize,
+        #     speculator=speculator,
+        #     dtype=torch.bfloat16,
+        #     trust_remote_code=trust_remote_code,
+        #     batch_class=PaliGemmaBatch,
+        # )
+        # if FLASH_ATTENTION:
+        return VlmCausalLM(
+            model_id=model_id,
+            model_class=PaliGemmaForConditionalGeneration,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            # Works better for these models
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+            batch_class=PaliGemmaBatch,
+        )
+        # else:
+        #     raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
server/text_generation_server/models/flash_causal_lm.py

@@ -1344,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def max_past(self) -> int:
-        return getattr(self.model, "max_past", None)
-
     def init_kv_cache(
         self,
         num_blocks: int,
@@ -1792,12 +1789,6 @@ class FlashCausalLM(Model):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
         if sorted_padded_bs:
server/text_generation_server/models/globals.py

@@ -6,6 +6,7 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master
 
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
+
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
server/text_generation_server/models/transformers_flash_causal_lm.py

@@ -36,10 +36,12 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
+    # from pdb import set_trace; set_trace()
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
+    # from pdb import set_trace; set_trace()
 
     # Take care of updating the cache in-place
     kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
@@ -47,6 +49,7 @@ def tgi_flash_attention_forward(
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
+    # from pdb import set_trace; set_trace()
 
     if cu_seqlen_prefill is not None:
         attn_output = attention(
@@ -72,6 +75,7 @@ def tgi_flash_attention_forward(
             max_s,
             kv_scales=kv_scales,
             softcap=softcap,
+            window_size_left=sliding_window,
         )
 
     attn_output = attn_output.view(-1, num_heads * head_dim)
@@ -104,6 +108,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
         tokenizer_class=AutoTokenizer,
         kv_cache_dtype: Optional[torch.dtype] = None,
     ):
+        # # from pdb import set_trace; set_trace()
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
 
@@ -157,7 +162,14 @@ class TransformersFlashCausalLM(FlashCausalLM):
         self.num_layers = model.config.num_hidden_layers
         self.num_heads = model.config.num_attention_heads
         self.num_kv_heads = model.config.num_key_value_heads
-        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+        # Some models use GQA and different sizes for o_proj
+        # and q_proj, that allows for that.
+        if hasattr(model.config, "head_dim"):
+            self.head_size = model.config.head_dim
+        else:
+            self.head_size = (
+                model.config.hidden_size // model.config.num_attention_heads
+            )
 
         # Skip it for models in the exception list
         if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
|
|||||||
prefill_cache_indices=None, # not used, but passed to match original signature
|
prefill_cache_indices=None, # not used, but passed to match original signature
|
||||||
adapter_data=None, # not supported, but passed to match original signature
|
adapter_data=None, # not supported, but passed to match original signature
|
||||||
):
|
):
|
||||||
|
# from pdb import set_trace; set_trace()
|
||||||
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
|
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
|
||||||
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
|
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
|
||||||
|
|
||||||
|
server/text_generation_server/models/transformers_flash_vlm.py (new file, 467 lines):

import math
from typing import List, Optional

import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoProcessor
import transformers.modeling_utils

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from text_generation_server.utils import initialize_torch_distributed

from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
import torch.nn.functional as F

tracer = trace.get_tracer(__name__)


def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers
    kv_cache: List[KVCache],
    kv_head_mapping: torch.Tensor,
    slots: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    max_s: int,
    kv_scales: KVScales,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    from loguru import logger

    logger.info("Using TGI Flash Attention")
    # from pdb import set_trace; set_trace()
    kv_cache = kv_cache[module.layer_idx]
    query_states = query_states.transpose(1, 2).squeeze(dim=0)
    key_states = key_states.transpose(1, 2).squeeze(dim=0)
    value_states = value_states.transpose(1, 2).squeeze(dim=0)
    # from pdb import set_trace; set_trace()

    # Take care of updating the cache in-place
    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)

    _, num_heads, head_dim = query_states.shape
    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    sliding_window = -1 if sliding_window is None else sliding_window
    # if module.layer_idx == 0:
    # from pdb import set_trace; set_trace()

    if cu_seqlen_prefill is not None:
        attention_mask = None
        if attention_mask is None:
            attn_output = attention(
                query=query_states,
                key=key_states,
                value=value_states,
                kv_cache=kv_cache,
                kv_scales=kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=softmax_scale,
                window_size_left=sliding_window,
                softcap=softcap,
            )
        else:
            from loguru import logger

            logger.info("uSING FLASH ATTENTION")
            lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
            max_length = max(lengths)
            attention_mask = attention_mask[:, :, :, :max_length]
            enable_gqa = query_states.shape[1] != key_states.shape[1]
            # Split tensors using vectorized split
            query_list = torch.split(query_states, lengths.tolist(), dim=0)
            key_list = torch.split(key_states, lengths.tolist(), dim=0)
            value_list = torch.split(value_states, lengths.tolist(), dim=0)

            padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True)
            padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
            padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True)

            padded_query = padded_query.transpose(1, 2).contiguous()
            padded_key = padded_key.transpose(1, 2).contiguous()
            padded_value = padded_value.transpose(1, 2).contiguous()

            # Compute attention
            attn_output = F.scaled_dot_product_attention(
                padded_query,
                padded_key,
                padded_value,
                attn_mask=attention_mask,
                scale=softmax_scale,
                enable_gqa=enable_gqa,
            )

            attn_output = attn_output.transpose(
                1, 2
            )  # [batch_size, seq_len, num_heads, head_dim]
            max_seq_len = padded_query.size(2)
            seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze(
                0
            )
            lengths_tensor = torch.tensor(
                lengths, device=padded_query.device
            ).unsqueeze(1)
            mask = seq_range < lengths_tensor  # [batch, max_seq_len]
            attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]

    else:
        attn_output = paged_attention(
            query_states,
            kv_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            seqlen,
            max_s,
            kv_scales=kv_scales,
            softcap=softcap,
            window_size_left=sliding_window,
        )

    attn_output = attn_output.view(-1, num_heads * head_dim)

    return attn_output, None


transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["tgi"] = (
    transformers.models.siglip.modeling_siglip.SIGLIP_ATTENTION_CLASSES["sdpa"]
)
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"]
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES["eager"]
transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_ATTENTION_CLASSES["tgi"] = (
    tgi_flash_attention_forward
)
transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
    "tgi"
] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
    "eager"
]

# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
    "olmo2",
    "phi3",
]


class TransformersFlashVlmCausalLM(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        kv_cache_dtype: Optional[torch.dtype] = None,
        batch_class=VlmCausalLMBatch,
    ):
        # # from pdb import set_trace; set_trace()
        self.batch_class = VlmCausalLMBatch
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()

        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            dtype = default_dtype if dtype is None else dtype
        else:
            raise ValueError(
                "Flash `Transformers` modeling backend is not available on cpu."
            )

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        if processor_kwargs is None:
            processor_kwargs = {}
        # processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        from transformers import Qwen2VLForConditionalGeneration

        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
            attn_implementation="tgi",
            device_map=device if world_size == 1 else None,
            tp_plan="auto" if world_size > 1 else None,
        )

        torch.distributed.barrier(group=self.process_group)
        self.config = model.config
        config = model.config

        # VLM models define the config we care about in their text_config
        text_config = getattr(model.config, "text_config", None)
        if text_config is not None:
            config = text_config

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None and isinstance(
                model.config.eos_token_id, int
            ):
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Some models use GQA and different sizes for o_proj
        # and q_proj, that allows for that.
        if hasattr(config, "head_dim"):
            self.head_size = config.head_dim
        else:
            self.head_size = config.hidden_size // config.num_attention_heads

        # Skip it for models in the exception list
        if config.model_type not in REPLICATED_ATTENTION_MODELS:
            self.num_heads = self.num_heads // self.process_group.size()
            self.num_kv_heads = (
                self.num_kv_heads // self.process_group.size()
                if self.num_kv_heads > 1
                else self.num_kv_heads
            )

        self.cuda_graphs = {}
        self.kv_cache = []
        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
                create_prefill_with_paged_kv_state,
            )

            self.prefill_state = create_prefill_state(device=device)
            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
                device=device
            )

            self.decode_state = create_decode_state(
                device=device,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )

        self.num_groups = self.num_heads // self.num_kv_heads

        # Those will never change and will be used in the forwards
        self.kv_head_mapping = torch.arange(
            0, self.num_kv_heads, dtype=torch.int32, device=device
        ).repeat_interleave(self.num_groups)
        # This means no scale
        self.kv_scales = KVScales(
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device),
        )

        # Skip FlashCausalLM init.
        super(FlashCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
        # We first copy the original model.forward because we still need it in the monkey patch
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward
        self.model.get_position_ids = self.get_position_ids

        torch.distributed.barrier(group=self.process_group)

    def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        vision_start_token_id = self.config.vision_start_token_id
        vision_end_token_id = self.config.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # import ipdb

        # ipdb.set_trace()
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        batch_class: Optional[type] = VlmCausalLMBatch,
    ):
        return cls(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=batch_class,
        )

    def _model_forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[KVCache],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor],
        prefill_cache_indices=None,  # not used, but passed to match original signature
        adapter_data=None,  # not supported, but passed to match original signature
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
    ):
        # from pdb import set_trace; set_trace()
        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
        logits = (
            self.model.original_forward(
                input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
                position_ids=position_ids.transpose(0, 1).unsqueeze(
                    1
                ),  # expand dim to fit Transformers
                past_key_values=None,  # we use self.kv_cache instead of transformers cache object
                use_cache=False,  # we use self.kv_cache instead of transformers cache object
                logits_to_keep=logits_to_keep,
                return_dict=True,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                kv_head_mapping=self.kv_head_mapping,
                kv_scales=self.kv_scales,
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                image_sizes=image_sizes,
                image_grid_thw=image_grid_thw,
            )
            .logits.squeeze(dim=0)[lm_head_indices]
            .unsqueeze(0)
        )

        # from pdb import set_trace; set_trace()

        return logits, None
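The masked prefill branch of tgi_flash_attention_forward in the new file (taken when an attention_mask is provided) splits the packed variable-length sequences, pads them, runs F.scaled_dot_product_attention, and then drops the padded positions again. Below is a minimal standalone sketch of that padding round trip with made-up shapes; it uses a causal flag in place of the explicit mask from the diff and is not part of the commit.

# Standalone sketch of the pad / attend / unpad pattern (assumed toy shapes).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 4, 8
cu_seqlens = torch.tensor([0, 3, 7])  # two packed sequences of lengths 3 and 4
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

q = torch.randn(int(cu_seqlens[-1]), num_heads, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)


def pad(t: torch.Tensor) -> torch.Tensor:
    # [total_tokens, heads, dim] -> [batch, heads, max_len, dim]
    return torch.nn.utils.rnn.pad_sequence(
        torch.split(t, lengths, dim=0), batch_first=True
    ).transpose(1, 2)


out = F.scaled_dot_product_attention(pad(q), pad(k), pad(v), is_causal=True)
out = out.transpose(1, 2)  # [batch, max_len, heads, dim]

# Drop the padded positions to return to the packed layout.
mask = torch.arange(out.shape[1])[None, :] < torch.tensor(lengths)[:, None]
out = out[mask]
print(out.shape)  # torch.Size([7, 4, 8])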
server/text_generation_server/models/vlm_causal_lm.py

@@ -372,9 +372,6 @@ class VlmCausalLM(FlashCausalLM):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class
 
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.text_model, "max_past", None)
-
    def forward(
        self,
        batch: VlmCausalLMBatch,
@@ -442,12 +439,6 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])