Flashing mllama.

This commit is contained in:
Nicolas Patry 2024-09-27 17:17:14 +02:00
parent 2441142c8b
commit 85771989d6
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
9 changed files with 586 additions and 842 deletions

View File

@ -29,7 +29,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)

View File

@ -652,6 +652,8 @@ fn prepare_input(
.encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
tracing::info!("Validation said the input is {}", encoding.len());
Ok((encoding, input_chunks))
}

View File

@ -112,7 +112,11 @@ try:
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
@ -149,7 +153,7 @@ except ImportError as e:
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(IDEFICSSharded)
__all__.append(IdeficsCausalLM)
MAMBA_AVAILABLE = True
try:
@ -1122,7 +1126,7 @@ def get_model(
)
if model_type == IDEFICS:
if FLASH_ATTENTION:
return IDEFICSSharded(
return IdeficsCausalLM(
model_id,
revision,
quantize=quantize,
@ -1134,13 +1138,16 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
revision,
return VlmCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))

View File

@ -450,6 +450,8 @@ class FlashLlamaLayer(nn.Module):
seqlen,
max_s,
adapter_data,
cross_attention_states,
image_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -487,6 +489,7 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(
@ -499,22 +502,38 @@ class FlashLlamaModel(torch.nn.Module):
)
)
self.layers.extend(
[
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.mllama import (
FlashLlamaCrossLayer,
)
self.layers.append(
FlashLlamaCrossLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
)
else:
self.layers.append(
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
)
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1)
]
)
with no_fp8(weights):
last_layer_id = config.num_hidden_layers - 1
@ -556,6 +575,8 @@ class FlashLlamaModel(torch.nn.Module):
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
cross_attention_states=None,
image_indices=None,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -579,6 +600,8 @@ class FlashLlamaModel(torch.nn.Module):
seqlen,
max_s,
adapter_data,
cross_attention_states,
image_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -625,6 +648,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
image_indices=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -639,6 +664,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
image_indices=image_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -19,28 +19,24 @@ from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
import math
import flash_attn_2_cuda
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_outputs import (
CausalLMOutputWithPast,
BaseModelOutputWithPast,
)
from transformers.cache_utils import (
StaticCache,
DynamicCache,
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@ -675,7 +671,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_key_value_heads = self.config.num_key_value_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.layer_idx = layer_idx
@ -715,77 +711,70 @@ class MllamaTextCrossAttention(nn.Module):
self.k_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.softmax_scale = self.head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value=None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = query_states.view(-1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
) = cross_attention_states
key_states = self.k_norm(key_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
key_states = self.k_norm(key_states)
if past_key_value is not None:
# if we have a new image + new tokens, we only computed key_states on that new image
# we still update the cross key states, past_image, new_image. And use it!
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value[self.layer_idx][0],
past_key_value[self.layer_idx][1],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
attn_output = flash_attn_2_cuda.varlen_fwd(
query_states,
key_states,
value_states,
None,
cu_seqlen_q,
cu_seqlen_k,
None,
None,
None, # block_tables
None,
max_q,
max_k,
0.0,
self.softmax_scale,
False,
causal, # Causal
-1, # window_size_left,
-1,
0.0, # softcap
False,
None,
)[0]
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
@ -817,15 +806,16 @@ class MllamaTextMLP(nn.Module):
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(
self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return result
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
class FlashLlamaCrossLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
def __init__(self, *, prefix, config, weights, layer_idx) -> None:
def __init__(self, *, prefix, config, weights, index) -> None:
layer_idx = index
super().__init__()
self.cross_attn = MllamaTextCrossAttention(
prefix=f"{prefix}.cross_attn",
@ -854,86 +844,52 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
attention_mask: torch.Tensor,
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.Tensor:
if past_key_value is not None:
is_mixed = False
if cross_attention_states is None:
out_hidden_states = hidden_states[:]
indices = []
for i, k in enumerate(past_key_value[self.layer_idx][0]):
if isinstance(k, torch.Tensor):
indices.append(i)
from loguru import logger
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states, # [ IB, ...]
image_indices,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
if residual is not None:
hidden_states += residual
if cu_seqlen_prefill is not None:
out_hidden_states = hidden_states[:]
hidden_states = hidden_states[:]
else:
out_hidden_states = hidden_states[:]
hidden_states = hidden_states[image_indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
logger.info(f"Indices {indices}")
if len(indices) == 0:
return hidden_states
is_mixed = True
if len(indices) == hidden_states.shape[0]:
is_mixed = False
hidden_states = self.cross_attn(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
if is_mixed:
hidden_states = hidden_states[indices]
# Dirty hack
_past_key_value = [None] * len(past_key_value)
_past_key_value[self.layer_idx] = (
torch.stack(
[
k
for i, k in enumerate(past_key_value[self.layer_idx][0])
if i in indices
],
dim=0,
),
torch.stack(
[
k
for i, k in enumerate(past_key_value[self.layer_idx][1])
if i in indices
],
dim=0,
),
)
logger.info(f"Hidden states {hidden_states.shape}")
logger.info(f"k {_past_key_value[self.layer_idx][0].shape}")
logger.info(f"v {_past_key_value[self.layer_idx][1].shape}")
past_key_value = _past_key_value
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if cu_seqlen_prefill is not None:
out_hidden_states[:] = hidden_states
else:
out_hidden_states[image_indices] = hidden_states
hidden_states = out_hidden_states
hidden_states, attn_weights, past_key_value = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
past_key_value=past_key_value,
cache_position=cache_position,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if full_text_row_masked_out_mask is not None:
hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
if is_mixed:
out_hidden_states[indices] = hidden_states
hidden_states = out_hidden_states
from loguru import logger
logger.info(f"After Hidden states {hidden_states.shape}")
return hidden_states
return hidden_states, None
class MllamaTextSelfAttention(nn.Module):
@ -1103,6 +1059,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.45
image_indices: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
@ -1207,332 +1164,6 @@ class MllamaRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class MllamaTextModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.cross_attention_layers = config.cross_attention_layers
self.layers = []
for layer_idx in range(config.num_hidden_layers):
if layer_idx in self.cross_attention_layers:
self.layers.append(
MllamaCrossAttentionDecoderLayer(
prefix=f"{prefix}.layers.{layer_idx}",
config=config,
weights=weights,
layer_idx=layer_idx,
)
)
else:
self.layers.append(
MllamaSelfAttentionDecoderLayer(
prefix=f"{prefix}.layers.{layer_idx}",
config=config,
weights=weights,
layer_idx=layer_idx,
)
)
self.norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.FloatTensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
full_text_row_masked_out_mask: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
# if (
# idx in self.cross_attention_layers
# and cross_attention_states is None
# and (
# past_key_values is None
# or (
# past_key_values is not None
# and any(past_key_values.get_seq_length(idx) == 0
# )
# )
# ):
# continue
layer_outputs = decoder_layer(
hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
attention_mask=causal_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
position_ids=position_ids,
past_key_value=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
# TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line
# self.config._attn_implementation == "sdpa" and
# if self.config._attn_implementation == "sdpa" and not using_static_cache:
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
class MllamaForCausalLM(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(
prefix=f"{prefix}.model", config=config, weights=weights
)
self.lm_head = SpeculativeHead.load(
prefix=f"{prefix}.lm_head",
config=config,
weights=weights,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.LongTensor] = None,
full_text_row_masked_out_mask: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
# if lm_head_indices is not None:
# hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return (
CausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
),
speculative_logits,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {
"input_ids": input_ids.clone(memory_format=torch.contiguous_format),
"inputs_embeds": None,
}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"attention_mask": attention_mask,
}
)
return model_inputs
class MllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@ -1545,99 +1176,102 @@ class MllamaForConditionalGeneration(nn.Module):
self.vision_model = MllamaVisionModel(
prefix="vision_model", config=config.vision_config, weights=weights
)
self.language_model = MllamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.multi_modal_projector = FastLinear.load(
prefix="multi_modal_projector", config=config, weights=weights, bias=True
)
self.text_model = FlashLlamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
n, m, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(1, n * m, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[List[List[int]]] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[List[List[List[int]]]] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
image_hidden_states=None,
image_attention_mask=None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor],
cross_attention_states: Optional[torch.Tensor],
image_indices,
):
if past_key_values is None:
past_key_values = DynamicCache(
num_hidden_layers=self.config.text_config.num_hidden_layers
)
elif isinstance(past_key_values, list):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# if cross_attention_mask is not None:
# cross_attention_mask, full_text_row_masked_out_mask = (
# _prepare_cross_attention_mask(
# cross_attention_mask,
# num_vision_tokens=self.vision_model.num_patches,
# dtype=self.dtype,
# )
# )
# else:
# full_text_row_masked_out_mask = None
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
# if cross_attention_mask is not None and cache_position is not None:
# cross_attention_mask = cross_attention_mask[:, :, cache_position]
# full_text_row_masked_out_mask = full_text_row_masked_out_mask[
# :, :, cache_position
# ]
if cross_attention_states is not None:
seqlen_q = input_ids.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = input_ids.device
cu_seqlen_q = torch.Tensor([0, seqlen_q]).to(
dtype=torch.int32, device=device
)
cu_seqlen_k = torch.Tensor([0, seqlen_k]).to(
dtype=torch.int32, device=device
)
max_q = seqlen_q
max_k = seqlen_k
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError(
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
)
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# get vision tokens from vision model
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = (
_prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, cache_position
]
outputs = self.language_model(
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
# cross_attention_mask=cross_attention_mask,
image_indices=image_indices,
)
return outputs

View File

@ -6,6 +6,9 @@ import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
ProcessorMixin,
)
@ -20,6 +23,18 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
import torch.distributed
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
@ -36,9 +51,6 @@ class IdeficsCausalLMBatch(Batch):
attention_mask: torch.Tensor
position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor]
aspect_ratio_ids: Optional[torch.Tensor]
aspect_ratio_mask: Optional[torch.Tensor]
cross_attention_mask: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
@ -121,69 +133,33 @@ class IdeficsCausalLMBatch(Batch):
)
# TODO Check impact on idefics
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)
if config.model_type == "idefics":
prompts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
tokenized_inputs = processor(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
# TODO Check impact on idefics
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
else:
images = []
texts = []
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
curr_images = []
curr_text = ""
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
curr_text += chunk.text
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
curr_images.append(image)
# TODO unsure about BOS
curr_text += "<|image|>"
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
images.append(curr_images)
texts.append(curr_text)
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
if all(len(im) == 0 for im in images):
images = None
tokenized_inputs = processor(
images=images,
text=texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
).to(device)
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
tokenized_inputs = processor(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
# TODO Check impact on idefics
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(
@ -208,10 +184,7 @@ class IdeficsCausalLMBatch(Batch):
# Do the same for image_attention_mask
if pixel_values is None:
image_attention_mask = None
aspect_ratio_ids = None
aspect_ratio_mask = None
cross_attention_mask = None
elif "image_attention_mask" in tokenized_inputs:
else:
image_attention_mask = input_ids.new_zeros(
(
pb.size,
@ -222,21 +195,6 @@ class IdeficsCausalLMBatch(Batch):
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask"
]
aspect_ratio_ids = None
aspect_ratio_mask = None
cross_attention_mask = None
elif "cross_attention_mask" in tokenized_inputs:
image_attention_mask = None
aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"]
aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"]
cross_attention_mask = tokenized_inputs["cross_attention_mask"]
pixel_values = pixel_values.to(dtype=dtype)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp(
max=processor.tokenizer.vocab_size - 1
)
else:
raise RuntimeError("Unhandled state for idefics/mllama")
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@ -266,9 +224,6 @@ class IdeficsCausalLMBatch(Batch):
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
cross_attention_mask=cross_attention_mask,
)
@tracer.start_as_current_span("filter")
@ -332,21 +287,15 @@ class IdeficsCausalLMBatch(Batch):
+ new_padding_right_offset,
]
# Do the same for pixel_values and image_attention_mask
if self.pixel_values is not None:
pixel_values = self.pixel_values[keep_indices]
else:
pixel_values = None
if self.image_attention_mask is not None:
self.image_attention_mask = self.image_attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
pixel_values = self.pixel_values[keep_indices]
self.image_attention_mask = self.image_attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
if self.image_hidden_states is None:
image_hidden_states = None
else:
@ -360,15 +309,7 @@ class IdeficsCausalLMBatch(Batch):
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if not isinstance(past_keys, torch.Tensor):
past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices]
past_values = [
k for i, k in enumerate(past_values) if i in keep_indices
]
layer[0] = past_keys
layer[1] = past_values
continue
elif len(past_keys.shape) == 3:
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
@ -397,9 +338,6 @@ class IdeficsCausalLMBatch(Batch):
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
self.aspect_ratio_ids = None
self.aspect_ratio_mask = None
self.cross_attention_mask = None
return self
@ -417,8 +355,7 @@ class IdeficsCausalLMBatch(Batch):
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
if batch.pixel_values is not None:
max_num_images = max(max_num_images, batch.pixel_values.size(1))
max_num_images = max(max_num_images, batch.pixel_values.size(1))
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
@ -481,19 +418,16 @@ class IdeficsCausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset),
)
if batch.pixel_values is not None:
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
else:
pixel_values = None
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None and batch.image_attention_mask is not None:
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
(
total_batch_size,
@ -517,14 +451,13 @@ class IdeficsCausalLMBatch(Batch):
:,
batch_left_offset : -batch.padding_right_offset,
]
if batch.image_attention_mask is not None:
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:, batch_left_offset : -batch.padding_right_offset, :
]
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
@ -538,14 +471,7 @@ class IdeficsCausalLMBatch(Batch):
# And ensure that we can update tensors in-place
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[
(
t.view(len(batch), -1, *t.shape[-2:])
if isinstance(t, torch.Tensor)
else t
)
for t in layer
]
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
@ -584,98 +510,52 @@ class IdeficsCausalLMBatch(Batch):
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
if any(
not isinstance(batch.past_key_values[j][0], torch.Tensor)
for batch in batches
):
# XXX: Special handling for cross attention for mllama
padded_past_keys = [
k for batch in batches for k in batch.past_key_values[j][0]
]
padded_past_values = [
k for batch in batches for k in batch.past_key_values[j][1]
]
past_key_values.append([padded_past_keys, padded_past_values])
else:
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
if seqlen > max_input_length:
# XXX: This is probably a cross attention key value
# If not this is ok
_padded_past_keys_shape = (
total_batch_size,
_num_heads,
seqlen,
_head_dim,
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
_padded_past_keys_shape = padded_past_keys_shape
padded_past_keys = first_past_kvs[j][0].new_zeros(
_padded_past_keys_shape
)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if past_keys.shape[2] > past_seq_len:
# XXX: This is a cross attention kv in mllama
past_seq_len = past_keys.shape[2]
if batch.keys_head_dim_last:
padded_past_keys[
start_index:end_index, :, -past_seq_len:, :
] = past_keys[:, :, -past_seq_len:, :]
else:
# BLOOM case
padded_past_keys[
start_index:end_index, :, :, -past_seq_len:
] = past_keys[:, :, :, -past_seq_len:]
del past_keys
start_index = end_index
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
if seqlen > max_input_length:
# XXX: This is probably a cross attention key value
# If not this is ok
_padded_past_values_shape = (
total_batch_size,
_num_heads,
seqlen,
_head_dim,
# BLOOM case
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
else:
_padded_past_values_shape = padded_past_values_shape
padded_past_values = first_past_kvs[j][1].new_zeros(
_padded_past_values_shape
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
del past_values
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if past_values.shape[2] > past_seq_len:
# XXX: This is a cross attention kv in mllama
past_seq_len = past_values.shape[2]
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
start_index = end_index
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
@ -698,10 +578,6 @@ class IdeficsCausalLMBatch(Batch):
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
# No need to keep this around. for Mllamma
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
)
def __len__(self):
@ -709,6 +585,88 @@ class IdeficsCausalLMBatch(Batch):
class IdeficsCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
# 9b seems to work correctly enough in float16, but 80b seems
# to be really saturating for f16.
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
self.device, self.dtype = device, dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
config.vision_config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)
self.config = config
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
return IdeficsCausalLMBatch
@ -722,9 +680,6 @@ class IdeficsCausalLM(Model):
image_hidden_states,
image_attention_mask,
past_key_values: Optional = None,
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {
@ -734,23 +689,18 @@ class IdeficsCausalLM(Model):
"image_hidden_states": image_hidden_states,
"image_attention_mask": image_attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
if aspect_ratio_ids is not None:
kwargs["aspect_ratio_ids"] = aspect_ratio_ids
if aspect_ratio_mask is not None:
kwargs["aspect_ratio_mask"] = aspect_ratio_mask
if cross_attention_mask is not None:
kwargs["cross_attention_mask"] = cross_attention_mask
outputs, speculative_logits = self.model.forward(**kwargs)
assert outputs.past_key_values is not None
return (
outputs.logits,
speculative_logits,
outputs.past_key_values,
getattr(outputs, "image_hidden_states", None),
outputs.image_hidden_states,
)
@tracer.start_as_current_span("generate_token")
@ -785,13 +735,9 @@ class IdeficsCausalLM(Model):
image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
cross_attention_mask=batch.cross_attention_mask,
)
# Hardcoded remove image tokens
if self.config.model_type == "idefics":
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
@ -935,12 +881,9 @@ class IdeficsCausalLM(Model):
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
if batch.image_attention_mask is not None:
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
if batch.cross_attention_mask is not None:
batch.cross_attention_mask = batch.cross_attention_mask[:, -1:]
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1
@ -950,8 +893,7 @@ class IdeficsCausalLM(Model):
# Update past key values
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
if self.model.config.model_type == "mllama":
batch.pixel_values = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -0,0 +1,119 @@
from io import BytesIO
from PIL import Image
import torch
from typing import Iterable
from text_generation_server.pb.generate_pb2 import Request
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from typing import Optional
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
from text_generation_server.pb import generate_pb2
tracer = trace.get_tracer(__name__)
@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
image_inputs = []
texts = []
image_indices = []
batch_tokenized_inputs = []
for i, r in enumerate(requests):
# Each input is encoded into a list, where each element of this input list is either a string or a URL
curr_text = ""
has_image = False
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
curr_text += chunk.text
elif chunk_type == "image":
has_image = True
image = Image.open(BytesIO(chunk.image.data))
# TODO unsure about BOS
curr_text += "<|image|>"
image_input = processor.image_processor(image, return_tensors="pt")
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
texts.append(curr_text)
if has_image:
image_indices.append(i)
input_ids = tokenizer(
curr_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
batch_tokenized_inputs.append(input_ids)
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "aspect_ratio_ids" in image_input:
new_image_inputs["aspect_ratio_ids"] = torch.cat(
[img["aspect_ratio_ids"] for img in image_inputs], dim=0
)
if "aspect_ratio_mask" in image_input:
new_image_inputs["aspect_ratio_mask"] = torch.cat(
[img["aspect_ratio_mask"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
image_inputs["image_indices"] = image_indices
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(
device=device, dtype=dtype
)
batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
device=device
)
batch.image_indices = image_inputs["image_indices"]
else:
batch.pixel_values = None
batch.aspect_ratio_ids = None
batch.aspect_ratio_mask = None
batch.image_indices = None
return batch

View File

@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLM,
block_tables_to_ragged,
)
from loguru import logger
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
@ -57,7 +58,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
log_master(
logger.info,
@ -135,8 +135,8 @@ def get_number_of_features(height: int, width: int, config) -> int:
class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
pixel_attention_mask: Optional[List[torch.Tensor]] = None
image_sizes: Optional[List[Tuple[int, int]]] = None
@classmethod
@tracer.start_as_current_span("concatenate")
@ -378,6 +378,17 @@ class VlmCausalLM(FlashCausalLM):
max_q=max_s,
max_k=max_k,
)
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -389,18 +400,14 @@ class VlmCausalLM(FlashCausalLM):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
@ -418,7 +425,7 @@ class VlmCausalLM(FlashCausalLM):
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (

View File

@ -22,8 +22,14 @@ try:
VlmCausalLMBatch,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
VLM_BATCH_TYPES = {
PaliGemmaBatch,
VlmCausalLMBatch,
IdeficsCausalLMBatch,
MllamaCausalLMBatch,
}
except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash.
VLM_BATCH_TYPES = set()