Mirror of https://github.com/huggingface/text-generation-inference.git
Commit 33a7ec57e2 ("add fix"), parent 9a8d0462e1
@@ -3,20 +3,157 @@ from typing import List, Optional
import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoProcessor
import transformers.modeling_utils

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from text_generation_server.utils import initialize_torch_distributed

from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
import torch.nn.functional as F
import numpy as np

tracer = trace.get_tracer(__name__)

# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
    "olmo2",
    "phi3",
]


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)
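    # E.g. cdiv(9, 4) == 3 and cdiv(8, 4) == 2: negating the divisor and the
    # result turns Python's floor division into ceiling division, which is how
    # the local-block counts below are computed.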


# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]
    # Handle the case where we start in the middle of a local attention block:
    # we assume q_seqlens > 0 (for all elements). For each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx and figure out the number of "virtual" requests we
    # have to make.
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    # (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)

    # compute the seqlens_k_local: basically a full local attention block for
    # all but the last block in each batch.
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    if ATTENTION == "flashdecoding":
        k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
            rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
        )

        # For the example the local attention blocks start at:
        #                          _b0_  _____b1_____  _b2_
        # k_seqstarts_absolute = [0, 4,  4, 8, 12, 16,  4, 8]
        block_starts = k_seqstarts_absolute // page_size
        assert attn_chunk_size % page_size == 0, (
            f"attn_chunk_size {attn_chunk_size} is not "
            f"divisible by page_size {page_size}"
        )
        pages_per_local_batch = attn_chunk_size // page_size

        # Create a block_table for the local attention blocks
        # For our example, if we have a block-table like (assuming page_size=2):
        #   block_table = [
        #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
        #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
        #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
        #   ]
        # then for the local batches we would want a block-table like
        #   block_table_local = [
        #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
        #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
        #     [12, 13 ],  < local-batch 2, (batch 1, starting from k[4])
        #     [14, 15 ],  < local-batch 3, (batch 1, starting from k[8])
        #     [16, 17 ],  < local-batch 4, (batch 1, starting from k[12])
        #     [18, 19 ],  < local-batch 5, (batch 1, starting from k[16])
        #     [22, 23 ],  < local-batch 6, (batch 2, starting from k[4])
        #     [24, 25 ],  < local-batch 7, (batch 2, starting from k[8])
        #   ]
        block_indices = np.broadcast_to(
            np.arange(pages_per_local_batch, dtype=np.int32),
            (virtual_batches, pages_per_local_batch),
        ) + np.expand_dims(block_starts, axis=1)
        block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
        batch_indices = np.repeat(
            np.arange(actual_batch_size, dtype=np.int32),
            local_blocks * pages_per_local_batch,
        )
        block_table_local = block_table[batch_indices, block_indices].view(
            virtual_batches, -1
        )
    else:
        block_table_local = block_table
    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
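
# A minimal sketch of the worked example from the comments above (values taken
# from those comments; `block_table` here is a stand-in and, outside the
# "flashdecoding" branch, is returned unchanged):
#
#   cu_q = np.array([0, 4, 14, 19], dtype=np.int32)      # q_seqlens = [4, 10, 5]
#   k_lens = np.array([6, 17, 9], dtype=np.int32)
#   block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)
#   q_loc, cu_q_loc, k_loc, bt_loc = make_local_attention_virtual_batches(
#       4, cu_q, k_lens, block_table, page_size=2
#   )
#   # q_loc    -> [2, 2, 1, 4, 4, 1, 4, 1]
#   # cu_q_loc -> [0, 2, 4, 5, 9, 13, 14, 18, 19]
#   # k_loc    -> [4, 2, 4, 4, 4, 1, 4, 1]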


# # Qwen2VL
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "tgi"
# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
#     "eager"
# ]
def tgi_flash_attention_forward(
    module,
    query_states: torch.Tensor,
@@ -34,8 +171,14 @@ def tgi_flash_attention_forward(
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    use_sdpa: Optional[bool] = False,
    local_seqlen: Optional[Seqlen] = None,
    local_block_tables: Optional[torch.Tensor] = None,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    if hasattr(module, "use_rope") and module.use_rope:
        seqlen = local_seqlen
        block_tables = local_block_tables

    kv_cache = kv_cache[module.layer_idx]
    query_states = query_states.transpose(1, 2).squeeze(dim=0)
@@ -50,6 +193,7 @@ def tgi_flash_attention_forward(
    sliding_window = -1 if sliding_window is None else sliding_window

    if cu_seqlen_prefill is not None:
        if not use_sdpa:
            attn_output = attention(
                query=query_states,
                key=key_states,
@@ -62,6 +206,47 @@ def tgi_flash_attention_forward(
                window_size_left=sliding_window,
                softcap=softcap,
            )
        else:
            lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
            max_length = max(lengths)
            attention_mask = attention_mask[:, :, :, :max_length]
            enable_gqa = query_states.shape[1] != key_states.shape[1]
            # Split tensors using vectorized split
            query_list = torch.split(query_states, lengths.tolist(), dim=0)
            key_list = torch.split(key_states, lengths.tolist(), dim=0)
            value_list = torch.split(value_states, lengths.tolist(), dim=0)

            padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True)
            padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
            padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True)

            padded_query = padded_query.transpose(1, 2).contiguous()
            padded_key = padded_key.transpose(1, 2).contiguous()
            padded_value = padded_value.transpose(1, 2).contiguous()

            # Compute attention
            attn_output = F.scaled_dot_product_attention(
                padded_query,
                padded_key,
                padded_value,
                attn_mask=attention_mask,
                scale=softmax_scale,
                enable_gqa=enable_gqa,
            )

            attn_output = attn_output.transpose(
                1, 2
            )  # [batch_size, seq_len, num_heads, head_dim]
            max_seq_len = padded_query.size(2)
            seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze(
                0
            )
            lengths_tensor = torch.tensor(
                lengths, device=padded_query.device
            ).unsqueeze(1)
            mask = seq_range < lengths_tensor  # [batch, max_seq_len]
            attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]

    else:
        attn_output = paged_attention(
            query_states,
@@ -83,20 +268,16 @@ def tgi_flash_attention_forward(

transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
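
# The SDPA fallback above packs the ragged prefill sequences into a padded
# batch and strips the padding again afterwards. A rough sketch of that round
# trip (shapes only, names are illustrative):
#
#   x = torch.randn(9, 8, 64)          # 9 packed tokens, 8 heads, head_dim 64
#   lengths = [4, 3, 2]
#   padded = torch.nn.utils.rnn.pad_sequence(torch.split(x, lengths), batch_first=True)
#   # padded: [3, 4, 8, 64] (batch, max_len, heads, head_dim)
#   mask = torch.arange(4).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)
#   assert torch.equal(padded[mask], x)  # padding dropped, original order preserved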

# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
# to internal constraints it was not (yet?) possible to circumvent
REPLICATED_ATTENTION_MODELS = [
    "olmo2",
    "phi3",
]

# TODO: implement
# tgi_cross_attention_forward


class TransformersFlashCausalLM(FlashCausalLM):
class TransformersFlashVlmCausalLM(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
@@ -104,10 +285,15 @@ class TransformersFlashCausalLM(FlashCausalLM):
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        kv_cache_dtype: Optional[torch.dtype] = None,
        batch_class=VlmCausalLMBatch,
    ):
        self.batch_class = batch_class
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        self.dtype = dtype

        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
@@ -131,18 +317,40 @@ class TransformersFlashCausalLM(FlashCausalLM):
            trust_remote_code=trust_remote_code,
        )

        model = AutoModelForCausalLM.from_pretrained(
        if processor_kwargs is None:
            processor_kwargs = {}

        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )

        attn_implementation = {
            "text_config": "tgi",
            "vision_config": "sdpa",
        }

        model = model_class.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
            attn_implementation="tgi",
            attn_implementation=attn_implementation,
            device_map=device if world_size == 1 else None,
            tp_plan="auto" if world_size > 1 else None,
        )

        torch.distributed.barrier(group=self.process_group)
        self.config = model.config
        config = model.config

        # VLM models define the config we care about in their text_config
        text_config = getattr(model.config, "text_config", None)
        if text_config is not None:
            config = text_config

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
@@ -156,20 +364,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = model.config.num_hidden_layers
        self.num_heads = model.config.num_attention_heads
        self.num_kv_heads = model.config.num_key_value_heads
        self.num_layers = config.num_hidden_layers
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Some models use GQA and different sizes for o_proj
        # and q_proj; this allows for that.
        if hasattr(model.config, "head_dim"):
            self.head_size = model.config.head_dim
        if hasattr(config, "head_dim"):
            self.head_size = config.head_dim
        else:
            self.head_size = (
                model.config.hidden_size // model.config.num_attention_heads
            )
            self.head_size = config.hidden_size // config.num_attention_heads

        # Skip it for models in the exception list
        if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
        if config.model_type not in REPLICATED_ATTENTION_MODELS:
            self.num_heads = self.num_heads // self.process_group.size()
            self.num_kv_heads = (
                self.num_kv_heads // self.process_group.size()
@@ -227,26 +433,47 @@ class TransformersFlashCausalLM(FlashCausalLM):
        # We first copy the original model.forward because we still need it in the monkey patch
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward
        self.model.get_position_ids = self.get_position_ids

        torch.distributed.barrier(group=self.process_group)

    def get_position_ids(self, input_ids, image_grid_thw, position_ids):
        return position_ids

    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]
        return {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

    def post_process_outputs(self, logits, lm_head_indices):
        return logits.squeeze(dim=0)

    @classmethod
    def fallback(
        cls,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        batch_class: Optional[type] = VlmCausalLMBatch,
        processor_kwargs: Optional[dict] = None,
    ):
        return cls(
            model_id=model_id,
            model_class=model_class,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=batch_class,
            processor_kwargs=processor_kwargs,
        )

    def _model_forward(
@@ -262,14 +489,33 @@ class TransformersFlashCausalLM(FlashCausalLM):
        lm_head_indices: Optional[torch.Tensor],
        prefill_cache_indices=None,  # not used, but passed to match original signature
        adapter_data=None,  # not supported, but passed to match original signature
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
    ):
        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

        inputs = self.pre_process_inputs(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            seqlen=seqlen,
            block_tables=block_tables,
        )

        if cu_seqlen_prefill is not None:
            from loguru import logger

            logger.info(
                f"input_ids: {input_ids.shape}, local_seqlen: {inputs.get('local_seqlen', None)}"
            )

        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
        logits = self.model.original_forward(
            input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
            position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
            input_ids=inputs["input_ids"],
            position_ids=inputs["position_ids"],
            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
            use_cache=False,  # we use self.kv_cache instead of transformers cache object
            logits_to_keep=logits_to_keep,
@@ -282,6 +528,227 @@ class TransformersFlashCausalLM(FlashCausalLM):
            max_s=max_s,
            kv_head_mapping=self.kv_head_mapping,
            kv_scales=self.kv_scales,
        ).logits.squeeze(dim=0)
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_sizes=image_sizes,
            image_grid_thw=image_grid_thw,
            attention_mask=inputs.get("attention_mask", None),
            use_sdpa=inputs.get("use_sdpa", False),
            cache_position=inputs.get("cache_position", None),
            local_seqlen=inputs.get("local_seqlen", None),
            local_block_tables=inputs.get("local_block_tables", None),
        ).logits

        logits = self.post_process_outputs(logits, lm_head_indices)

        return logits, None
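
    # Shape contract handled by the default pre/post-processing hooks above
    # (a sketch; the VLM subclasses below override these methods):
    #   input_ids     [total_tokens]           -> unsqueeze(0) -> [1, total_tokens]
    #   position_ids  [total_tokens]           -> unsqueeze(0) -> [1, total_tokens]
    #   logits        [1, kept_tokens, vocab]  -> squeeze(0)   -> [kept_tokens, vocab]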


class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
    def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        vision_start_token_id = self.config.vision_start_token_id
        vision_end_token_id = self.config.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # import ipdb

        # ipdb.set_trace()
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def post_process_outputs(self, logits, lm_head_indices):
        return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)

    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]

        input_ids = input_ids.unsqueeze(0)
        position_ids = position_ids.transpose(0, 1).unsqueeze(1)
        return {"input_ids": input_ids, "position_ids": position_ids}


class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
    def get_attention_mask(self, input_ids, cu_seqlen_prefill):
        device = input_ids.device
        dtype = self.dtype
        min_dtype = torch.finfo(dtype).min

        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
        batch_size = len(lengths)

        sequence_length = max(lengths)
        target_length = sequence_length
        # Create the padding mask from the computed lengths.
        # pad_mask: [batch, sequence_length] where True indicates valid tokens.
        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]

        # Build the base causal mask (for non-image tokens):
        causal_mask = torch.tril(
            torch.ones(
                (sequence_length, sequence_length), dtype=torch.bool, device=device
            )
        )
        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
            1
        )  # [batch, sequence_length, sequence_length]
        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint

        image_token_mask = (input_ids == self.config.image_token_index).to(
            input_ids.device
        )

        image_token_mask = torch.nn.utils.rnn.pad_sequence(
            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
        )
        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
            1
        )

        # Combine the causal base mask and the bidirectional mask.
        combined_mask = torch.logical_or(
            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
        ).to(device)
        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]

        full_attention_mask = torch.zeros(
            (batch_size, 1, sequence_length, target_length),
            device=device,
            dtype=torch.bool,
        )
        full_attention_mask[:, :, :, :sequence_length] = combined_mask

        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)

        return final_attention_mask
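
    # Toy illustration of the combination above for one 5-token sequence where
    # tokens 1-3 are image tokens (1 = may attend, rows = queries):
    #
    #   causal          image-bidirectional   logical_or
    #   1 0 0 0 0       0 0 0 0 0             1 0 0 0 0
    #   1 1 0 0 0       0 1 1 1 0             1 1 1 1 0
    #   1 1 1 0 0       0 1 1 1 0             1 1 1 1 0
    #   1 1 1 1 0       0 1 1 1 0             1 1 1 1 0
    #   1 1 1 1 1       0 0 0 0 0             1 1 1 1 1
    #
    # i.e. image tokens within a sequence attend to each other bidirectionally
    # while text keeps its causal mask; True is then mapped to 0 and False to
    # the dtype minimum before the mask is handed to the SDPA path.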

    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]
        cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]

        inputs = {
            "input_ids": input_ids.unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

        if cu_seqlen_prefill is not None:
            attention_mask = self.get_attention_mask(
                input_ids.squeeze(0), cu_seqlen_prefill
            )
            inputs["attention_mask"] = attention_mask
            inputs["use_sdpa"] = True

        return inputs


class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]
        seqlen = kwargs["seqlen"]
        block_tables = kwargs["block_tables"]

        inputs = super().pre_process_inputs(**kwargs)
        inputs["cache_position"] = position_ids
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        # from loguru import logger

        # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}")
        cu_seqlen_k = seqlen.cu_seqlen_k
        cu_seqlen_q = seqlen.cu_seqlen_q
        seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
        (
            seqlens_q_local_np,
            virt_q_cu_seqlens_np,
            virt_k_seqlens_np,
            virt_block_table,
        ) = make_local_attention_virtual_batches(
            self.model.config.text_config.attention_chunk_size,
            cu_seqlen_q.cpu().numpy(),
            seq_lens_np.cpu().numpy(),
            block_tables,
            BLOCK_SIZE,
        )
        local_seqlen = Seqlen(
            input_lengths=torch.from_numpy(virt_k_seqlens_np).to(
                input_ids.device, non_blocking=True
            ),
            cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to(
                input_ids.device, non_blocking=True
            ),
            cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to(
                input_ids.device, non_blocking=True
            ),
            max_q=int(seqlens_q_local_np.max()),
            max_k=int(virt_k_seqlens_np.max()),
        )

        inputs["local_seqlen"] = local_seqlen
        inputs["local_block_tables"] = virt_block_table

        return inputs
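
    # How this fits together for Llama 4's chunked attention (a sketch): in the
    # Llama 4 text model, the layers that apply RoPE attend only within windows
    # of `attention_chunk_size` tokens, so the prefill batch is re-sliced here
    # into "virtual" local batches, and tgi_flash_attention_forward swaps in
    # `local_seqlen` / `local_block_tables` for those layers via its
    # `module.use_rope` branch above.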