From 85771989d6f4fbeae6d306b6fcb74e8adea08604 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 27 Sep 2024 17:17:14 +0200 Subject: [PATCH] Flashing mllama. --- router/src/infer/chat_template.rs | 2 +- router/src/validation.rs | 2 + .../text_generation_server/models/__init__.py | 19 +- .../custom_modeling/flash_llama_modeling.py | 57 +- .../models/custom_modeling/mllama.py | 736 +++++------------- .../models/idefics_causal_lm.py | 456 +++++------ .../models/mllama_causal_lm.py | 119 +++ .../models/vlm_causal_lm.py | 29 +- server/text_generation_server/server.py | 8 +- 9 files changed, 586 insertions(+), 842 deletions(-) create mode 100644 server/text_generation_server/models/mllama_causal_lm.py diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index a736fc12..1071d0ba 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -29,7 +29,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); - tracing::debug!("Loading template: {:#?}", template_str); + tracing::debug!("Loading template: {}", template_str); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) diff --git a/router/src/validation.rs b/router/src/validation.rs index 85b4220b..10aaba53 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -652,6 +652,8 @@ fn prepare_input( .encode(tokenizer_query, add_special_tokens) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; + tracing::info!("Validation said the input is {}", encoding.len()); + Ok((encoding, input_chunks)) } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3dbf6713..0fd72d03 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -112,7 +112,11 @@ try: from text_generation_server.models.custom_modeling.flash_phi_modeling import ( FlashPhiForCausalLM, ) - from text_generation_server.models.idefics import IDEFICSSharded + from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM + from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch + from text_generation_server.models.custom_modeling.mllama import ( + MllamaForConditionalGeneration, + ) from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) @@ -149,7 +153,7 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(IDEFICSSharded) + __all__.append(IdeficsCausalLM) MAMBA_AVAILABLE = True try: @@ -1122,7 +1126,7 @@ def get_model( ) if model_type == IDEFICS: if FLASH_ATTENTION: - return IDEFICSSharded( + return IdeficsCausalLM( model_id, revision, quantize=quantize, @@ -1134,13 +1138,16 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == MLLAMA: if FLASH_ATTENTION: - return IDEFICSSharded( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=MllamaForConditionalGeneration, + batch_class=MllamaCausalLMBatch, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama")) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 358fbbaa..15af17df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -450,6 +450,8 @@ class FlashLlamaLayer(nn.Module): seqlen, max_s, adapter_data, + cross_attention_states, + image_indices, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -487,6 +489,7 @@ class FlashLlamaModel(torch.nn.Module): # Skip fp8 quant for first and last layers self.layers = nn.ModuleList() + self.cross_attention_layers = getattr(config, "cross_attention_layers", []) with no_fp8(weights): self.layers.append( FlashLlamaLayer( @@ -499,22 +502,38 @@ class FlashLlamaModel(torch.nn.Module): ) ) - self.layers.extend( - [ - FlashLlamaLayer( - index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), - config=config, - weights=weights, + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1): + if layer_id in self.cross_attention_layers: + from text_generation_server.models.custom_modeling.mllama import ( + FlashLlamaCrossLayer, + ) + + self.layers.append( + FlashLlamaCrossLayer( + index=layer_id, + prefix=( + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}.model.layers.{layer_id}" + ), + config=config, + weights=weights, + ) + ) + else: + self.layers.append( + FlashLlamaLayer( + index=layer_id, + prefix=( + f"model.layers.{layer_id}" + if not prefix + else f"{prefix}.model.layers.{layer_id}" + ), + config=config, + weights=weights, + ) ) - # Skip first and last layers - for layer_id in range(1, config.num_hidden_layers - 1) - ] - ) with no_fp8(weights): last_layer_id = config.num_hidden_layers - 1 @@ -556,6 +575,8 @@ class FlashLlamaModel(torch.nn.Module): true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], adapter_data, + cross_attention_states=None, + image_indices=None, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -579,6 +600,8 @@ class FlashLlamaModel(torch.nn.Module): seqlen, max_s, adapter_data, + cross_attention_states, + image_indices, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -625,6 +648,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + cross_attention_states=None, + image_indices=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -639,6 +664,8 @@ class FlashLlamaForCausalLM(torch.nn.Module): true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, + cross_attention_states=cross_attention_states, + image_indices=image_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py index 90fcb812..64f99f21 100644 --- a/server/text_generation_server/models/custom_modeling/mllama.py +++ b/server/text_generation_server/models/custom_modeling/mllama.py @@ -19,28 +19,24 @@ from typing import Optional, Tuple, List import torch import torch.utils.checkpoint from torch import nn -import math +import flash_attn_2_cuda from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_outputs import ( - CausalLMOutputWithPast, - BaseModelOutputWithPast, -) -from transformers.cache_utils import ( - StaticCache, - DynamicCache, -) -from transformers.modeling_attn_mask_utils import AttentionMaskConverter import torch.nn.functional as F from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - SpeculativeHead, FastLinear, ) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, +) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb @@ -675,7 +671,7 @@ class MllamaTextCrossAttention(nn.Module): self.num_key_value_heads = self.config.num_key_value_heads self.dropout = config.dropout self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // self.num_heads + self.head_size = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.layer_idx = layer_idx @@ -715,77 +711,70 @@ class MllamaTextCrossAttention(nn.Module): self.k_norm = MllamaTextRMSNorm.load( prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps ) + self.softmax_scale = self.head_size**-0.5 def forward( self, hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, - past_key_value=None, - attention_mask: Optional[torch.Tensor] = None, - cache_position: Optional[torch.LongTensor] = None, + # past_key_value=None, + # attention_mask: Optional[torch.Tensor] = None, + # cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() + # hidden_states = hidden_states.unsqueeze(0) + # bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) + query_states = query_states.view(-1, self.num_heads, self.head_size) query_states = self.q_norm(query_states) - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + max_q, + max_k, + ) = cross_attention_states - key_states = self.k_norm(key_states) + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) + value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) + key_states = self.k_norm(key_states) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"cache_position": cache_position}, - ) + # key_states = key_states.repeat(1, self.num_key_value_groups, 1) + # value_states = value_states.repeat(1, self.num_key_value_groups, 1) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value[self.layer_idx][0], - past_key_value[self.layer_idx][1], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) + causal = False + # logger.info( + # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" + # ) + attn_output = flash_attn_2_cuda.varlen_fwd( + query_states, + key_states, + value_states, + None, + cu_seqlen_q, + cu_seqlen_k, + None, + None, + None, # block_tables + None, + max_q, + max_k, + 0.0, + self.softmax_scale, + False, + causal, # Causal + -1, # window_size_left, + -1, + 0.0, # softcap + False, + None, + )[0] + attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights, past_key_value + return attn_output # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText @@ -817,15 +806,16 @@ class MllamaTextMLP(nn.Module): gate_up_states = self.gate_up_proj(x) gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) result = self.down_proj( - self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] ) return result -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): +class FlashLlamaCrossLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" - def __init__(self, *, prefix, config, weights, layer_idx) -> None: + def __init__(self, *, prefix, config, weights, index) -> None: + layer_idx = index super().__init__() self.cross_attn = MllamaTextCrossAttention( prefix=f"{prefix}.cross_attn", @@ -854,86 +844,52 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): def forward( self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - attention_mask: torch.Tensor, - full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], - past_key_value=None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> torch.Tensor: - if past_key_value is not None: - is_mixed = False - if cross_attention_states is None: - out_hidden_states = hidden_states[:] - indices = [] - for i, k in enumerate(past_key_value[self.layer_idx][0]): - if isinstance(k, torch.Tensor): - indices.append(i) - from loguru import logger + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + adapter_data, + cross_attention_states, # [ IB, ...] + image_indices, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_states is None: + return hidden_states, residual + if residual is not None: + hidden_states += residual + if cu_seqlen_prefill is not None: + out_hidden_states = hidden_states[:] + hidden_states = hidden_states[:] + else: + out_hidden_states = hidden_states[:] + hidden_states = hidden_states[image_indices] + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) - logger.info(f"Indices {indices}") - if len(indices) == 0: - return hidden_states - is_mixed = True - if len(indices) == hidden_states.shape[0]: - is_mixed = False + hidden_states = self.cross_attn( + hidden_states=hidden_states, + # attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states - if is_mixed: - hidden_states = hidden_states[indices] - # Dirty hack - _past_key_value = [None] * len(past_key_value) - _past_key_value[self.layer_idx] = ( - torch.stack( - [ - k - for i, k in enumerate(past_key_value[self.layer_idx][0]) - if i in indices - ], - dim=0, - ), - torch.stack( - [ - k - for i, k in enumerate(past_key_value[self.layer_idx][1]) - if i in indices - ], - dim=0, - ), - ) - logger.info(f"Hidden states {hidden_states.shape}") - logger.info(f"k {_past_key_value[self.layer_idx][0].shape}") - logger.info(f"v {_past_key_value[self.layer_idx][1].shape}") - past_key_value = _past_key_value + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if cu_seqlen_prefill is not None: + out_hidden_states[:] = hidden_states + else: + out_hidden_states[image_indices] = hidden_states + hidden_states = out_hidden_states - hidden_states, attn_weights, past_key_value = self.cross_attn( - hidden_states=hidden_states, - attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - past_key_value=past_key_value, - cache_position=cache_position, - ) - hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if full_text_row_masked_out_mask is not None: - hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore - hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - - if is_mixed: - out_hidden_states[indices] = hidden_states - hidden_states = out_hidden_states - from loguru import logger - - logger.info(f"After Hidden states {hidden_states.shape}") - - return hidden_states + return hidden_states, None class MllamaTextSelfAttention(nn.Module): @@ -1103,6 +1059,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 + image_indices: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] @@ -1207,332 +1164,6 @@ class MllamaRotaryEmbedding(nn.Module): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class MllamaTextModel(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.config = config - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) - self.cross_attention_layers = config.cross_attention_layers - - self.layers = [] - for layer_idx in range(config.num_hidden_layers): - if layer_idx in self.cross_attention_layers: - self.layers.append( - MllamaCrossAttentionDecoderLayer( - prefix=f"{prefix}.layers.{layer_idx}", - config=config, - weights=weights, - layer_idx=layer_idx, - ) - ) - else: - self.layers.append( - MllamaSelfAttentionDecoderLayer( - prefix=f"{prefix}.layers.{layer_idx}", - config=config, - weights=weights, - layer_idx=layer_idx, - ) - ) - - self.norm = MllamaTextRMSNorm.load( - prefix=f"{prefix}.norm", - weights=weights, - eps=config.rms_norm_eps, - ) - self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.FloatTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, - past_key_values=None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - ): - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - hidden_states = inputs_embeds - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - ) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - - for idx, decoder_layer in enumerate(self.layers): - # if ( - # idx in self.cross_attention_layers - # and cross_attention_states is None - # and ( - # past_key_values is None - # or ( - # past_key_values is not None - # and any(past_key_values.get_seq_length(idx) == 0 - # ) - # ) - # ): - # continue - - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - ) - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line - # self.config._attn_implementation == "sdpa" and - # if self.config._attn_implementation == "sdpa" and not using_static_cache: - if self.config._attn_implementation == "sdpa" and not using_static_cache: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) - - return causal_mask - - -class MllamaForCausalLM(nn.Module): - def __init__(self, *, prefix, config, weights): - super().__init__() - self.vocab_size = config.vocab_size - self.model = MllamaTextModel( - prefix=f"{prefix}.model", config=config, weights=weights - ) - self.lm_head = SpeculativeHead.load( - prefix=f"{prefix}.lm_head", - config=config, - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, - past_key_values=None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ): - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - ) - - hidden_states = outputs.last_hidden_state - # if lm_head_indices is not None: - # hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.lm_head(hidden_states) - return ( - CausalLMOutputWithPast( - logits=logits, - past_key_values=outputs.past_key_values, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - num_logits_to_keep=None, - **kwargs, - ): - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - dtype = self.lm_head.weight.dtype - min_dtype = torch.finfo(dtype).min - - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=batch_size, - ) - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - } - ) - return model_inputs - - class MllamaForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -1545,99 +1176,102 @@ class MllamaForConditionalGeneration(nn.Module): self.vision_model = MllamaVisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) - self.language_model = MllamaForCausalLM( - prefix="language_model", config=config.text_config, weights=weights - ) self.multi_modal_projector = FastLinear.load( prefix="multi_modal_projector", config=config, weights=weights, bias=True ) + self.text_model = FlashLlamaForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) self.config = config self.dtype = weights.dtype self.device = weights.device + def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask): + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # logger.info(f"PIxel values {pixel_values.shape}") + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + n, m, h = cross_attention_states.shape + cross_attention_states = cross_attention_states.view(1, n * m, h) + # logger.info(f"cross {cross_attention_states.shape}") + return cross_attention_states + def forward( self, - input_ids: torch.LongTensor = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[List[List[int]]] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[List[List[List[int]]]] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - image_hidden_states=None, - image_attention_mask=None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + image_indices, ): - if past_key_values is None: - past_key_values = DynamicCache( - num_hidden_layers=self.config.text_config.num_hidden_layers - ) - elif isinstance(past_key_values, list): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # if cross_attention_mask is not None: + # cross_attention_mask, full_text_row_masked_out_mask = ( + # _prepare_cross_attention_mask( + # cross_attention_mask, + # num_vision_tokens=self.vision_model.num_patches, + # dtype=self.dtype, + # ) + # ) + # else: + # full_text_row_masked_out_mask = None - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + # if cross_attention_mask is not None and cache_position is not None: + # cross_attention_mask = cross_attention_mask[:, :, cache_position] + # full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + # :, :, cache_position + # ] + + if cross_attention_states is not None: + seqlen_q = input_ids.shape[0] + seqlen_k = cross_attention_states.shape[1] + + device = input_ids.device + cu_seqlen_q = torch.Tensor([0, seqlen_q]).to( + dtype=torch.int32, device=device + ) + cu_seqlen_k = torch.Tensor([0, seqlen_k]).to( + dtype=torch.int32, device=device + ) + max_q = seqlen_q + max_k = seqlen_k + cross_attention_states = ( + cross_attention_states, + cu_seqlen_q, + cu_seqlen_k, + max_q, + max_k, ) - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - - vision_states = self.vision_model( - pixel_values, aspect_ratio_ids, aspect_ratio_mask - ) - cross_attention_states = self.multi_modal_projector(vision_states).reshape( - -1, vision_states.shape[-2], self.hidden_size - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - - outputs = self.language_model( + outputs = self.text_model( input_ids=input_ids, - attention_mask=attention_mask, position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + # cross_attention_mask=cross_attention_mask, + image_indices=image_indices, ) return outputs diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 94e03efd..9a7a6fe1 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -6,6 +6,9 @@ import time from dataclasses import dataclass from opentelemetry import trace from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin, ) @@ -20,6 +23,18 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +import torch.distributed +from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +from text_generation_server.utils.quantization import get_loader + +from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) @@ -36,9 +51,6 @@ class IdeficsCausalLMBatch(Batch): attention_mask: torch.Tensor position_ids: torch.Tensor pixel_values: Optional[torch.Tensor] - aspect_ratio_ids: Optional[torch.Tensor] - aspect_ratio_mask: Optional[torch.Tensor] - cross_attention_mask: Optional[torch.Tensor] image_hidden_states: Optional[torch.Tensor] image_attention_mask: Optional[torch.Tensor] past_key_values: Optional[List[Tuple]] @@ -121,69 +133,33 @@ class IdeficsCausalLMBatch(Batch): ) # TODO Check impact on idefics + prompts = [] + for inp in inputs: + # Each input is encoded into a list, where each element of this input list is either a string or a URL + prompt = [] + for chunk in inp: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + prompt.append(chunk.text) + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + prompt.append(image) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + prompts.append(prompt) - if config.model_type == "idefics": - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompt = [] - for chunk in inp: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - prompt.append(chunk.text) - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - prompt.append(image) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - prompts.append(prompt) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - # TODO Check impact on idefics - # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - else: - images = [] - texts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - curr_images = [] - curr_text = "" - for chunk in inp: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - curr_text += chunk.text - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - curr_images.append(image) - # TODO unsure about BOS - curr_text += "<|image|>" - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - images.append(curr_images) - texts.append(curr_text) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - if all(len(im) == 0 for im in images): - images = None - tokenized_inputs = processor( - images=images, - text=texts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - ).to(device) + # The processor replaces the call to tokenizer, and + # a/ takes care of fetching images from the URL + # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model + tokenized_inputs = processor( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_truncation, + # TODO Check impact on idefics + # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] prefix_offsets.append( @@ -208,10 +184,7 @@ class IdeficsCausalLMBatch(Batch): # Do the same for image_attention_mask if pixel_values is None: image_attention_mask = None - aspect_ratio_ids = None - aspect_ratio_mask = None - cross_attention_mask = None - elif "image_attention_mask" in tokenized_inputs: + else: image_attention_mask = input_ids.new_zeros( ( pb.size, @@ -222,21 +195,6 @@ class IdeficsCausalLMBatch(Batch): image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ "image_attention_mask" ] - aspect_ratio_ids = None - aspect_ratio_mask = None - cross_attention_mask = None - elif "cross_attention_mask" in tokenized_inputs: - image_attention_mask = None - aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"] - aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"] - cross_attention_mask = tokenized_inputs["cross_attention_mask"] - pixel_values = pixel_values.to(dtype=dtype) - # XXX: <|image|> token is actually out of bounds and bugs out the logit processors. - tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp( - max=processor.tokenizer.vocab_size - 1 - ) - else: - raise RuntimeError("Unhandled state for idefics/mllama") position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) @@ -266,9 +224,6 @@ class IdeficsCausalLMBatch(Batch): max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - cross_attention_mask=cross_attention_mask, ) @tracer.start_as_current_span("filter") @@ -332,21 +287,15 @@ class IdeficsCausalLMBatch(Batch): + new_padding_right_offset, ] # Do the same for pixel_values and image_attention_mask - if self.pixel_values is not None: - pixel_values = self.pixel_values[keep_indices] - else: - pixel_values = None - - if self.image_attention_mask is not None: - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - + pixel_values = self.pixel_values[keep_indices] + self.image_attention_mask = self.image_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.image_attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + :, + ] if self.image_hidden_states is None: image_hidden_states = None else: @@ -360,15 +309,7 @@ class IdeficsCausalLMBatch(Batch): past_kv_length = max_input_length - 1 for layer in self.past_key_values: past_keys, past_values = layer - if not isinstance(past_keys, torch.Tensor): - past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices] - past_values = [ - k for i, k in enumerate(past_values) if i in keep_indices - ] - layer[0] = past_keys - layer[1] = past_values - continue - elif len(past_keys.shape) == 3: + if len(past_keys.shape) == 3: # Force past to be of dim [self_size, num_heads, ...] for easy indexing past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) @@ -397,9 +338,6 @@ class IdeficsCausalLMBatch(Batch): self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - self.aspect_ratio_ids = None - self.aspect_ratio_mask = None - self.cross_attention_mask = None return self @@ -417,8 +355,7 @@ class IdeficsCausalLMBatch(Batch): for batch in batches: total_batch_size += len(batch) max_input_length = max(max_input_length, batch.max_input_length) - if batch.pixel_values is not None: - max_num_images = max(max_num_images, batch.pixel_values.size(1)) + max_num_images = max(max_num_images, batch.pixel_values.size(1)) padding_right_offset = max(padding_right_offset, batch.padding_right_offset) # Batch attributes @@ -481,19 +418,16 @@ class IdeficsCausalLMBatch(Batch): (total_batch_size, max_input_length + padding_right_offset), ) - if batch.pixel_values is not None: - curr_batch_max_num_images = batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values + curr_batch_max_num_images = batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) ) - else: - pixel_values = None + pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( + batch.pixel_values + ) - if image_attention_mask is None and batch.image_attention_mask is not None: + if image_attention_mask is None: image_attention_mask = batch.image_attention_mask.new_zeros( ( total_batch_size, @@ -517,14 +451,13 @@ class IdeficsCausalLMBatch(Batch): :, batch_left_offset : -batch.padding_right_offset, ] - if batch.image_attention_mask is not None: - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] + image_attention_mask[ + start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images, + ] = batch.image_attention_mask[ + :, batch_left_offset : -batch.padding_right_offset, : + ] # Create empty tensor # position_ids is always of shape [batch_size, 1] @@ -538,14 +471,7 @@ class IdeficsCausalLMBatch(Batch): # And ensure that we can update tensors in-place if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ - [ - ( - t.view(len(batch), -1, *t.shape[-2:]) - if isinstance(t, torch.Tensor) - else t - ) - for t in layer - ] + [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values ] elif len(batch.past_key_values[0][0].shape) == 3: @@ -584,98 +510,52 @@ class IdeficsCausalLMBatch(Batch): # Iterate over attention layers # Concatenate past key values layer by layer to allow incremental garbage collection for j in range(len(first_past_kvs)): - if any( - not isinstance(batch.past_key_values[j][0], torch.Tensor) - for batch in batches - ): - # XXX: Special handling for cross attention for mllama - padded_past_keys = [ - k for batch in batches for k in batch.past_key_values[j][0] - ] - padded_past_values = [ - k for batch in batches for k in batch.past_key_values[j][1] - ] - past_key_values.append([padded_past_keys, padded_past_values]) - else: - _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape - if seqlen > max_input_length: - # XXX: This is probably a cross attention key value - # If not this is ok - _padded_past_keys_shape = ( - total_batch_size, - _num_heads, - seqlen, - _head_dim, + padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if batch.keys_head_dim_last: + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] ) else: - _padded_past_keys_shape = padded_past_keys_shape - - padded_past_keys = first_past_kvs[j][0].new_zeros( - _padded_past_keys_shape - ) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if past_keys.shape[2] > past_seq_len: - # XXX: This is a cross attention kv in mllama - past_seq_len = past_keys.shape[2] - if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, -past_seq_len:, : - ] = past_keys[:, :, -past_seq_len:, :] - else: - # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, -past_seq_len: - ] = past_keys[:, :, :, -past_seq_len:] - del past_keys - - start_index = end_index - - _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape - if seqlen > max_input_length: - # XXX: This is probably a cross attention key value - # If not this is ok - _padded_past_values_shape = ( - total_batch_size, - _num_heads, - seqlen, - _head_dim, + # BLOOM case + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] ) - else: - _padded_past_values_shape = padded_past_values_shape - padded_past_values = first_past_kvs[j][1].new_zeros( - _padded_past_values_shape + del past_keys + + start_index = end_index + + padded_past_values = first_past_kvs[j][1].new_zeros( + padded_past_values_shape + ) + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None + del past_values - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if past_values.shape[2] > past_seq_len: - # XXX: This is a cross attention kv in mllama - past_seq_len = past_values.shape[2] - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values + # Update values + start_index = end_index - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) + past_key_values.append([padded_past_keys, padded_past_values]) return cls( batch_id=batches[0].batch_id, @@ -698,10 +578,6 @@ class IdeficsCausalLMBatch(Batch): padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, max_tokens=max_tokens, - # No need to keep this around. for Mllamma - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, ) def __len__(self): @@ -709,6 +585,88 @@ class IdeficsCausalLMBatch(Batch): class IdeficsCausalLM(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.quantize = quantize + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + # 9b seems to work correctly enough in float16, but 80b seems + # to be really saturating for f16. + dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + self.device, self.dtype = device, dtype + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + config.vision_config.quantize = quantize + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) + + model = IdeficsForVisionText2Text(config, weights) + + self.config = config + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + @property def batch_type(self) -> Type[IdeficsCausalLMBatch]: return IdeficsCausalLMBatch @@ -722,9 +680,6 @@ class IdeficsCausalLM(Model): image_hidden_states, image_attention_mask, past_key_values: Optional = None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { @@ -734,23 +689,18 @@ class IdeficsCausalLM(Model): "image_hidden_states": image_hidden_states, "image_attention_mask": image_attention_mask, "past_key_values": past_key_values, + "use_cache": True, + "return_dict": True, } if self.has_position_ids: kwargs["position_ids"] = position_ids - if aspect_ratio_ids is not None: - kwargs["aspect_ratio_ids"] = aspect_ratio_ids - if aspect_ratio_mask is not None: - kwargs["aspect_ratio_mask"] = aspect_ratio_mask - if cross_attention_mask is not None: - kwargs["cross_attention_mask"] = cross_attention_mask outputs, speculative_logits = self.model.forward(**kwargs) - assert outputs.past_key_values is not None return ( outputs.logits, speculative_logits, outputs.past_key_values, - getattr(outputs, "image_hidden_states", None), + outputs.image_hidden_states, ) @tracer.start_as_current_span("generate_token") @@ -785,13 +735,9 @@ class IdeficsCausalLM(Model): image_hidden_states=batch.image_hidden_states, image_attention_mask=image_attention_mask, past_key_values=batch.past_key_values, - aspect_ratio_ids=batch.aspect_ratio_ids, - aspect_ratio_mask=batch.aspect_ratio_mask, - cross_attention_mask=batch.cross_attention_mask, ) # Hardcoded remove image tokens - if self.config.model_type == "idefics": - logits[:, 32000:32001] = torch.finfo(logits.dtype).min + logits[:, 32000:32001] = torch.finfo(logits.dtype).min start_decode = time.time_ns() @@ -935,12 +881,9 @@ class IdeficsCausalLM(Model): # Update attention_mask as we added a new token to input_ids batch.attention_mask[:, -batch.padding_right_offset] = 1 - if batch.image_attention_mask is not None: - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - if batch.cross_attention_mask is not None: - batch.cross_attention_mask = batch.cross_attention_mask[:, -1:] + batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( + batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] + ) # Decrease right offset batch.padding_right_offset -= 1 @@ -950,8 +893,7 @@ class IdeficsCausalLM(Model): # Update past key values batch.past_key_values = past batch.image_hidden_states = image_hidden_states - if self.model.config.model_type == "mllama": - batch.pixel_values = None + forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py new file mode 100644 index 00000000..9cd80441 --- /dev/null +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -0,0 +1,119 @@ +from io import BytesIO +from PIL import Image +import torch +from typing import Iterable +from text_generation_server.pb.generate_pb2 import Request + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import ( + PreTrainedTokenizerBase, +) +from typing import Optional + +from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch +from text_generation_server.pb import generate_pb2 + + +tracer = trace.get_tracer(__name__) + + +@dataclass +class MllamaCausalLMBatch(VlmCausalLMBatch): + aspect_ratio_ids: Optional[torch.Tensor] = None + aspect_ratio_mask: Optional[torch.Tensor] = None + cross_attention_states: Optional[torch.Tensor] = None + + @classmethod + def batch_tokenized_inputs( + cls, requests: Iterable[Request], tokenizer, processor, config + ): + image_inputs = [] + texts = [] + image_indices = [] + batch_tokenized_inputs = [] + for i, r in enumerate(requests): + # Each input is encoded into a list, where each element of this input list is either a string or a URL + curr_text = "" + has_image = False + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + curr_text += chunk.text + elif chunk_type == "image": + has_image = True + image = Image.open(BytesIO(chunk.image.data)) + # TODO unsure about BOS + curr_text += "<|image|>" + image_input = processor.image_processor(image, return_tensors="pt") + image_inputs.append(image_input) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") + texts.append(curr_text) + if has_image: + image_indices.append(i) + + input_ids = tokenizer( + curr_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + batch_tokenized_inputs.append(input_ids) + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "aspect_ratio_ids" in image_input: + new_image_inputs["aspect_ratio_ids"] = torch.cat( + [img["aspect_ratio_ids"] for img in image_inputs], dim=0 + ) + if "aspect_ratio_mask" in image_input: + new_image_inputs["aspect_ratio_mask"] = torch.cat( + [img["aspect_ratio_mask"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + image_inputs["image_indices"] = image_indices + else: + image_inputs = None + + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "VlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + # XXX: <|image|> token is actually out of bounds and bugs out the logit processors. + batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( + max=config.text_config.vocab_size - 1 + ) + batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) + + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to( + device=device, dtype=dtype + ) + batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device) + batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( + device=device + ) + batch.image_indices = image_inputs["image_indices"] + else: + batch.pixel_values = None + batch.aspect_ratio_ids = None + batch.aspect_ratio_mask = None + batch.image_indices = None + return batch diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7f7d2e4d..304ddd60 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLM, block_tables_to_ragged, ) +from loguru import logger from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor @@ -57,7 +58,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger log_master( logger.info, @@ -135,8 +135,8 @@ def get_number_of_features(height: int, width: int, config) -> int: class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] - pixel_attention_mask: Optional[List[torch.Tensor]] - image_sizes: Optional[List[Tuple[int, int]]] + pixel_attention_mask: Optional[List[torch.Tensor]] = None + image_sizes: Optional[List[Tuple[int, int]]] = None @classmethod @tracer.start_as_current_span("concatenate") @@ -378,6 +378,17 @@ class VlmCausalLM(FlashCausalLM): max_q=max_s, max_k=max_k, ) + + if batch.pixel_values is not None: + cross_attention_states = self.model.vision_forward( + pixel_values=batch.pixel_values, + aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, + ) + batch.cross_attention_states = cross_attention_states + + cross_attention_states = batch.cross_attention_states + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -389,18 +400,14 @@ class VlmCausalLM(FlashCausalLM): max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, + cross_attention_states=cross_attention_states, + adapter_data=adapter_data, + image_indices=batch.image_indices, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None if batch.pixel_values is not None: batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph @@ -418,7 +425,7 @@ class VlmCausalLM(FlashCausalLM): cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables - cuda_graph["slots"].fill_(-1) + cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 22871ec5..e7dfd8e4 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -22,8 +22,14 @@ try: VlmCausalLMBatch, ) from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch - VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch} + VLM_BATCH_TYPES = { + PaliGemmaBatch, + VlmCausalLMBatch, + IdeficsCausalLMBatch, + MllamaCausalLMBatch, + } except (ImportError, NotImplementedError): # These imports can fail on CPU/Non flash. VLM_BATCH_TYPES = set()