mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Starting to get there.
This commit is contained in:
parent
85771989d6
commit
ef4fa3ea7c
@ -22,7 +22,6 @@ from torch import nn
|
|||||||
import flash_attn_2_cuda
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
@ -734,6 +733,7 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
cu_seqlen_k,
|
cu_seqlen_k,
|
||||||
max_q,
|
max_q,
|
||||||
max_k,
|
max_k,
|
||||||
|
indices,
|
||||||
) = cross_attention_states
|
) = cross_attention_states
|
||||||
|
|
||||||
key_states = self.k_proj(cross_attention_states)
|
key_states = self.k_proj(cross_attention_states)
|
||||||
@ -862,6 +862,8 @@ class FlashLlamaCrossLayer(torch.nn.Module):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
|
|
||||||
|
# indices = cross_attention_states[-1]
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
out_hidden_states = hidden_states[:]
|
out_hidden_states = hidden_states[:]
|
||||||
hidden_states = hidden_states[:]
|
hidden_states = hidden_states[:]
|
||||||
@ -892,115 +894,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
|
|||||||
return hidden_states, None
|
return hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
class MllamaTextSelfAttention(nn.Module):
|
|
||||||
def __init__(self, *, prefix, config, weights, layer_idx):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.dropout = config.dropout
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.num_key_value_heads = config.num_key_value_heads
|
|
||||||
self.head_dim = config.hidden_size // self.num_heads
|
|
||||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
|
||||||
|
|
||||||
self.num_heads = self.num_heads // weights.process_group.size()
|
|
||||||
self.num_key_value_heads = (
|
|
||||||
self.num_key_value_heads // weights.process_group.size()
|
|
||||||
)
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
|
|
||||||
self.qkv_proj = TensorParallelColumnLinear.load_multi(
|
|
||||||
config,
|
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
|
||||||
dim=0,
|
|
||||||
weights=weights,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
|
||||||
config,
|
|
||||||
prefix=f"{prefix}.o_proj",
|
|
||||||
weights=weights,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
position_embeddings: torch.Tensor,
|
|
||||||
past_key_value=None,
|
|
||||||
cache_position=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
qkv = self.qkv_proj(hidden_states)
|
|
||||||
query_states, key_states, value_states = qkv.split(
|
|
||||||
[
|
|
||||||
self.head_dim * self.num_heads,
|
|
||||||
self.head_dim * self.num_key_value_heads,
|
|
||||||
self.head_dim * self.num_key_value_heads,
|
|
||||||
],
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
key_states = key_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
value_states = value_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
|
|
||||||
cos, sin = position_embeddings
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin
|
|
||||||
)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
|
||||||
key_states, value_states = past_key_value.update(
|
|
||||||
key_states, value_states, self.layer_idx, cache_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
causal_mask = attention_mask
|
|
||||||
if attention_mask is not None:
|
|
||||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
|
||||||
|
|
||||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
|
||||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
|
||||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
|
||||||
query_states = query_states.contiguous()
|
|
||||||
key_states = key_states.contiguous()
|
|
||||||
value_states = value_states.contiguous()
|
|
||||||
|
|
||||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
|
||||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
|
||||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
# TODO
|
|
||||||
# attn_mask=causal_mask,
|
|
||||||
dropout_p=self.dropout if self.training else 0.0,
|
|
||||||
is_causal=is_causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
attn_output = attn_output.view(bsz, q_len, -1)
|
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
return attn_output, None, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
|
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
|
||||||
class MllamaTextRMSNorm(nn.Module):
|
class MllamaTextRMSNorm(nn.Module):
|
||||||
def __init__(self, weight, eps):
|
def __init__(self, weight, eps):
|
||||||
@ -1026,144 +919,6 @@ class MllamaTextRMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
|
|
||||||
class MllamaSelfAttentionDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, *, prefix, config, weights, layer_idx):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
|
|
||||||
self.self_attn = MllamaTextSelfAttention(
|
|
||||||
prefix=f"{prefix}.self_attn",
|
|
||||||
config=config,
|
|
||||||
weights=weights,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
|
||||||
self.input_layernorm = MllamaTextRMSNorm.load(
|
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
|
||||||
)
|
|
||||||
self.post_attention_layernorm = MllamaTextRMSNorm.load(
|
|
||||||
prefix=f"{prefix}.post_attention_layernorm",
|
|
||||||
weights=weights,
|
|
||||||
eps=config.rms_norm_eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value=None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
position_embeddings: Optional[
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]
|
|
||||||
] = None, # will become mandatory in v4.45
|
|
||||||
image_indices: Optional[torch.LongTensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[
|
|
||||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
|
||||||
]:
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
|
|
||||||
# Self Attention
|
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class MllamaRotaryEmbedding(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config,
|
|
||||||
weights,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
device = weights.device
|
|
||||||
self.rope_type = config.rope_scaling["rope_type"]
|
|
||||||
self.max_seq_len_cached = config.max_position_embeddings
|
|
||||||
self.original_max_seq_len = config.max_position_embeddings
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
|
||||||
inv_freq.to(device=device)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
self.original_inv_freq = self.inv_freq
|
|
||||||
|
|
||||||
def _dynamic_frequency_update(self, position_ids, device):
|
|
||||||
"""
|
|
||||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
|
||||||
1 - growing beyond the cached sequence length (allow scaling)
|
|
||||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
|
||||||
"""
|
|
||||||
seq_len = torch.max(position_ids) + 1
|
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"inv_freq", inv_freq, persistent=False
|
|
||||||
) # TODO joao: may break with compilation
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
|
|
||||||
if (
|
|
||||||
seq_len < self.original_max_seq_len
|
|
||||||
and self.max_seq_len_cached > self.original_max_seq_len
|
|
||||||
): # reset
|
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, x, position_ids):
|
|
||||||
if "dynamic" in self.rope_type:
|
|
||||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
|
||||||
|
|
||||||
# Core RoPE block
|
|
||||||
inv_freq_expanded = (
|
|
||||||
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
|
||||||
)
|
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
|
||||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
|
||||||
device_type = x.device.type
|
|
||||||
device_type = (
|
|
||||||
device_type
|
|
||||||
if isinstance(device_type, str) and device_type != "mps"
|
|
||||||
else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
|
|
||||||
1, 2
|
|
||||||
)
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
cos = emb.cos()
|
|
||||||
sin = emb.sin()
|
|
||||||
|
|
||||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
|
||||||
cos = cos * self.attention_scaling
|
|
||||||
sin = sin * self.attention_scaling
|
|
||||||
|
|
||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class MllamaForConditionalGeneration(nn.Module):
|
class MllamaForConditionalGeneration(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -1192,14 +947,15 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||||
)
|
)
|
||||||
# logger.info(f"PIxel values {pixel_values.shape}")
|
# logger.info(f"PIxel values {pixel_values.shape}")
|
||||||
|
batch_size = pixel_values.shape[0]
|
||||||
vision_states = self.vision_model(
|
vision_states = self.vision_model(
|
||||||
pixel_values, aspect_ratio_ids, aspect_ratio_mask
|
pixel_values, aspect_ratio_ids, aspect_ratio_mask
|
||||||
)
|
)
|
||||||
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
|
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
|
||||||
-1, vision_states.shape[-2], self.hidden_size
|
-1, vision_states.shape[-2], self.hidden_size
|
||||||
)
|
)
|
||||||
n, m, h = cross_attention_states.shape
|
_, _, h = cross_attention_states.shape
|
||||||
cross_attention_states = cross_attention_states.view(1, n * m, h)
|
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
|
||||||
# logger.info(f"cross {cross_attention_states.shape}")
|
# logger.info(f"cross {cross_attention_states.shape}")
|
||||||
return cross_attention_states
|
return cross_attention_states
|
||||||
|
|
||||||
@ -1219,7 +975,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
cross_attention_states: Optional[torch.Tensor],
|
cross_attention_states: Optional[torch.Tensor],
|
||||||
image_indices,
|
image_indices,
|
||||||
):
|
):
|
||||||
# if cross_attention_mask is not None:
|
# if cross_att_sention_mask is not None:
|
||||||
# cross_attention_mask, full_text_row_masked_out_mask = (
|
# cross_attention_mask, full_text_row_masked_out_mask = (
|
||||||
# _prepare_cross_attention_mask(
|
# _prepare_cross_attention_mask(
|
||||||
# cross_attention_mask,
|
# cross_attention_mask,
|
||||||
@ -1237,24 +993,59 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
# ]
|
# ]
|
||||||
|
|
||||||
if cross_attention_states is not None:
|
if cross_attention_states is not None:
|
||||||
seqlen_q = input_ids.shape[0]
|
seqlen_q = len(image_indices)
|
||||||
|
n_images = cross_attention_states.shape[0]
|
||||||
seqlen_k = cross_attention_states.shape[1]
|
seqlen_k = cross_attention_states.shape[1]
|
||||||
|
device = cross_attention_states.device
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# raise RuntimeError("TODO")
|
||||||
|
offset = 0
|
||||||
|
cu_q = []
|
||||||
|
indices = []
|
||||||
|
for index in image_indices:
|
||||||
|
cu_q.append(offset)
|
||||||
|
length = seqlen.input_lengths[index]
|
||||||
|
input_ids_offset = seqlen.cu_seqlen_q[index]
|
||||||
|
indices.extend(range(input_ids_offset, input_ids_offset + length))
|
||||||
|
offset += length
|
||||||
|
cu_q.append(offset)
|
||||||
|
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
|
||||||
|
|
||||||
device = input_ids.device
|
cu_seqlen_k = (
|
||||||
cu_seqlen_q = torch.Tensor([0, seqlen_q]).to(
|
torch.arange(
|
||||||
dtype=torch.int32, device=device
|
n_images + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
cu_seqlen_k = torch.Tensor([0, seqlen_k]).to(
|
* seqlen_k
|
||||||
dtype=torch.int32, device=device
|
)
|
||||||
|
max_q = cu_seqlen_q[-1].item()
|
||||||
|
max_k = seqlen_k
|
||||||
|
else:
|
||||||
|
cu_seqlen_q = torch.arange(
|
||||||
|
seqlen_q + 1, device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
seqlen_k = cross_attention_states.shape[1]
|
||||||
|
n_images = cross_attention_states.shape[0]
|
||||||
|
cu_seqlen_k = (
|
||||||
|
torch.arange(
|
||||||
|
n_images + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
* seqlen_k
|
||||||
)
|
)
|
||||||
max_q = seqlen_q
|
max_q = seqlen_q
|
||||||
max_k = seqlen_k
|
max_k = seqlen_k
|
||||||
|
indices = image_indices[:]
|
||||||
|
|
||||||
cross_attention_states = (
|
cross_attention_states = (
|
||||||
cross_attention_states,
|
cross_attention_states,
|
||||||
cu_seqlen_q,
|
cu_seqlen_q,
|
||||||
cu_seqlen_k,
|
cu_seqlen_k,
|
||||||
max_q,
|
max_q,
|
||||||
max_k,
|
max_k,
|
||||||
|
indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.text_model(
|
outputs = self.text_model(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from typing import Iterable
|
from typing import Iterable, List
|
||||||
from text_generation_server.pb.generate_pb2 import Request
|
from text_generation_server.pb.generate_pb2 import Request
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -20,10 +20,54 @@ tracer = trace.get_tracer(__name__)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MllamaCausalLMBatch(VlmCausalLMBatch):
|
class MllamaCausalLMBatch(VlmCausalLMBatch):
|
||||||
|
image_indices: List[int] = 42
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None
|
aspect_ratio_ids: Optional[torch.Tensor] = None
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None
|
aspect_ratio_mask: Optional[torch.Tensor] = None
|
||||||
cross_attention_states: Optional[torch.Tensor] = None
|
cross_attention_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@tracer.start_as_current_span("concatenate")
|
||||||
|
def concatenate(cls, batches):
|
||||||
|
batch = super().concatenate(batches)
|
||||||
|
batch.pixel_values = None
|
||||||
|
batch.pixel_attention_mask = None
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
image_indices = []
|
||||||
|
attention_states = []
|
||||||
|
for b in batches:
|
||||||
|
attention_states.append(b.cross_attention_states)
|
||||||
|
image_indices.extend([i + offset for i in b.image_indices])
|
||||||
|
offset += len(b.image_indices)
|
||||||
|
batch.cross_attention_states = torch.cat(attention_states, dim=0)
|
||||||
|
batch.image_indices = image_indices
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@tracer.start_as_current_span("filter")
|
||||||
|
def filter(self, request_ids: List[int]):
|
||||||
|
assert self.image_indices is not None
|
||||||
|
batch = super().filter(request_ids)
|
||||||
|
assert self.image_indices is not None
|
||||||
|
indices = []
|
||||||
|
for i, request_id in enumerate(request_ids):
|
||||||
|
idx = self.requests_idx_mapping[request_id]
|
||||||
|
indices.append(idx)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
new_image_indices = []
|
||||||
|
prev_i = None
|
||||||
|
for i in self.image_indices:
|
||||||
|
if i in indices:
|
||||||
|
new_image_indices.append(offset)
|
||||||
|
if i != prev_i:
|
||||||
|
offset += 1
|
||||||
|
prev_i = i
|
||||||
|
|
||||||
|
batch.image_indices = new_image_indices
|
||||||
|
batch.cross_attention_states = self.cross_attention_states[indices]
|
||||||
|
assert offset <= batch.cross_attention_states.shape[0]
|
||||||
|
return batch
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def batch_tokenized_inputs(
|
def batch_tokenized_inputs(
|
||||||
cls, requests: Iterable[Request], tokenizer, processor, config
|
cls, requests: Iterable[Request], tokenizer, processor, config
|
||||||
@ -115,5 +159,6 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
|
|||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.aspect_ratio_ids = None
|
batch.aspect_ratio_ids = None
|
||||||
batch.aspect_ratio_mask = None
|
batch.aspect_ratio_mask = None
|
||||||
batch.image_indices = None
|
batch.image_indices = []
|
||||||
|
assert batch.image_indices is not None
|
||||||
return batch
|
return batch
|
||||||
|
@ -141,7 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@tracer.start_as_current_span("concatenate")
|
@tracer.start_as_current_span("concatenate")
|
||||||
def concatenate(cls, batches):
|
def concatenate(cls, batches):
|
||||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
batch = super().concatenate(batches)
|
||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
@ -402,7 +402,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
cross_attention_states=cross_attention_states,
|
cross_attention_states=cross_attention_states,
|
||||||
adapter_data=adapter_data,
|
adapter_data=adapter_data,
|
||||||
image_indices=batch.image_indices,
|
image_indices=batch.image_indices[:],
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
|
Loading…
Reference in New Issue
Block a user