From ef4fa3ea7cd72b90d81bb18cdc94737ab0a3e311 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sat, 28 Sep 2024 00:38:23 +0200
Subject: [PATCH] Starting to get there.

---
 .../models/custom_modeling/mllama.py          | 313 +++----------------
 .../models/mllama_causal_lm.py                |  49 ++-
 .../models/vlm_causal_lm.py                   |   4 +-
 3 files changed, 101 insertions(+), 265 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 64f99f21..9d904978 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -22,7 +22,6 @@ from torch import nn
 import flash_attn_2_cuda
 from transformers.activations import ACT2FN
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 import torch.nn.functional as F

 from text_generation_server.layers import (
@@ -734,6 +733,7 @@ class MllamaTextCrossAttention(nn.Module):
             cu_seqlen_k,
             max_q,
             max_k,
+            indices,
         ) = cross_attention_states

         key_states = self.k_proj(cross_attention_states)
@@ -862,6 +862,8 @@ class FlashLlamaCrossLayer(torch.nn.Module):
             return hidden_states, residual
         if residual is not None:
             hidden_states += residual
+
+        # indices = cross_attention_states[-1]
         if cu_seqlen_prefill is not None:
             out_hidden_states = hidden_states[:]
             hidden_states = hidden_states[:]
@@ -892,115 +894,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         return hidden_states, None


-class MllamaTextSelfAttention(nn.Module):
-    def __init__(self, *, prefix, config, weights, layer_idx):
-        super().__init__()
-        self.config = config
-        self.num_heads = config.num_attention_heads
-        self.dropout = config.dropout
-        self.hidden_size = config.hidden_size
-        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // self.num_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-
-        self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = (
-            self.num_key_value_heads // weights.process_group.size()
-        )
-        self.layer_idx = layer_idx
-
-        self.qkv_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
-        )
-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        past_key_value=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        bsz, q_len, _ = hidden_states.size()
-
-        qkv = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = qkv.split(
-            [
-                self.head_dim * self.num_heads,
-                self.head_dim * self.num_key_value_heads,
-                self.head_dim * self.num_key_value_heads,
-            ],
-            dim=2,
-        )
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin
-        )
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            # TODO
-            # attn_mask=causal_mask,
-            dropout_p=self.dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-        return attn_output, None, past_key_value
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
 class MllamaTextRMSNorm(nn.Module):
     def __init__(self, weight, eps):
@@ -1026,144 +919,6 @@ class MllamaTextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


-# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
-class MllamaSelfAttentionDecoderLayer(nn.Module):
-    def __init__(self, *, prefix, config, weights, layer_idx):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-
-        self.self_attn = MllamaTextSelfAttention(
-            prefix=f"{prefix}.self_attn",
-            config=config,
-            weights=weights,
-            layer_idx=layer_idx,
-        )
-
-        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-        self.input_layernorm = MllamaTextRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
-        )
-        self.post_attention_layernorm = MllamaTextRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value=None,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[
-            Tuple[torch.Tensor, torch.Tensor]
-        ] = None,  # will become mandatory in v4.45
-        image_indices: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
-
-
-class MllamaRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        *,
-        config,
-        weights,
-    ):
-        super().__init__()
-        device = weights.device
-        self.rope_type = config.rope_scaling["rope_type"]
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        inv_freq.to(device=device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer(
-                "inv_freq", inv_freq, persistent=False
-            )  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if (
-            seq_len < self.original_max_seq_len
-            and self.max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = (
-            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        )
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = (
-            device_type
-            if isinstance(device_type, str) and device_type != "mps"
-            else "cpu"
-        )
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
-            1, 2
-        )
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos()
-        sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 class MllamaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -1192,14 +947,15 @@ class MllamaForConditionalGeneration(nn.Module):
                 "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
             )
         # logger.info(f"PIxel values {pixel_values.shape}")
+        batch_size = pixel_values.shape[0]
         vision_states = self.vision_model(
             pixel_values, aspect_ratio_ids, aspect_ratio_mask
         )
         cross_attention_states = self.multi_modal_projector(vision_states).reshape(
             -1, vision_states.shape[-2], self.hidden_size
         )
-        n, m, h = cross_attention_states.shape
-        cross_attention_states = cross_attention_states.view(1, n * m, h)
+        _, _, h = cross_attention_states.shape
+        cross_attention_states = cross_attention_states.view(batch_size, -1, h)
         # logger.info(f"cross {cross_attention_states.shape}")
         return cross_attention_states

@@ -1219,7 +975,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_states: Optional[torch.Tensor],
         image_indices,
     ):
-        # if cross_att_sention_mask is not None:
+        # if cross_attention_mask is not None:
         #     cross_attention_mask, full_text_row_masked_out_mask = (
         #         _prepare_cross_attention_mask(
         #             cross_attention_mask,
@@ -1237,24 +993,59 @@ class MllamaForConditionalGeneration(nn.Module):
         #     full_text_row_masked_out_mask = full_text_row_masked_out_mask[
         #         :, :, cache_position
         #     ]
         if cross_attention_states is not None:
-            seqlen_q = input_ids.shape[0]
+            seqlen_q = len(image_indices)
+            n_images = cross_attention_states.shape[0]
             seqlen_k = cross_attention_states.shape[1]
+            device = cross_attention_states.device
+            if cu_seqlen_prefill is not None:
+                # raise RuntimeError("TODO")
+                offset = 0
+                cu_q = []
+                indices = []
+                for index in image_indices:
+                    cu_q.append(offset)
+                    length = seqlen.input_lengths[index]
+                    input_ids_offset = seqlen.cu_seqlen_q[index]
+                    indices.extend(range(input_ids_offset, input_ids_offset + length))
+                    offset += length
+                cu_q.append(offset)
+                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
+
+                cu_seqlen_k = (
+                    torch.arange(
+                        n_images + 1,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    * seqlen_k
+                )
+                max_q = cu_seqlen_q[-1].item()
+                max_k = seqlen_k
+            else:
+                cu_seqlen_q = torch.arange(
+                    seqlen_q + 1, device=device, dtype=torch.int32
+                )
+                seqlen_k = cross_attention_states.shape[1]
+                n_images = cross_attention_states.shape[0]
+                cu_seqlen_k = (
+                    torch.arange(
+                        n_images + 1,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    * seqlen_k
+                )
+                max_q = seqlen_q
+                max_k = seqlen_k
+            indices = image_indices[:]
-            device = input_ids.device
-            cu_seqlen_q = torch.Tensor([0, seqlen_q]).to(
-                dtype=torch.int32, device=device
-            )
-            cu_seqlen_k = torch.Tensor([0, seqlen_k]).to(
-                dtype=torch.int32, device=device
-            )
-            max_q = seqlen_q
-            max_k = seqlen_k
             cross_attention_states = (
                 cross_attention_states,
                 cu_seqlen_q,
                 cu_seqlen_k,
                 max_q,
                 max_k,
+                indices,
             )

         outputs = self.text_model(
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index 9cd80441..c62ccb24 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -1,7 +1,7 @@
 from io import BytesIO
 from PIL import Image
 import torch
-from typing import Iterable
+from typing import Iterable, List
 from text_generation_server.pb.generate_pb2 import Request
 from dataclasses import dataclass
@@ -20,10 +20,54 @@ tracer = trace.get_tracer(__name__)

 @dataclass
 class MllamaCausalLMBatch(VlmCausalLMBatch):
+    image_indices: List[int] = 42
     aspect_ratio_ids: Optional[torch.Tensor] = None
     aspect_ratio_mask: Optional[torch.Tensor] = None
     cross_attention_states: Optional[torch.Tensor] = None

+    @classmethod
+    @tracer.start_as_current_span("concatenate")
+    def concatenate(cls, batches):
+        batch = super().concatenate(batches)
+        batch.pixel_values = None
+        batch.pixel_attention_mask = None
+
+        offset = 0
+        image_indices = []
+        attention_states = []
+        for b in batches:
+            attention_states.append(b.cross_attention_states)
+            image_indices.extend([i + offset for i in b.image_indices])
+            offset += len(b.image_indices)
+        batch.cross_attention_states = torch.cat(attention_states, dim=0)
+        batch.image_indices = image_indices
+        return batch
+
+    @tracer.start_as_current_span("filter")
+    def filter(self, request_ids: List[int]):
+        assert self.image_indices is not None
+        batch = super().filter(request_ids)
+        assert self.image_indices is not None
+        indices = []
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            indices.append(idx)
+
+        offset = 0
+        new_image_indices = []
+        prev_i = None
+        for i in self.image_indices:
+            if i in indices:
+                new_image_indices.append(offset)
+                if i != prev_i:
+                    offset += 1
+            prev_i = i
+
+        batch.image_indices = new_image_indices
+        batch.cross_attention_states = self.cross_attention_states[indices]
+        assert offset <= batch.cross_attention_states.shape[0]
+        return batch
+
     @classmethod
     def batch_tokenized_inputs(
         cls, requests: Iterable[Request], tokenizer, processor, config
@@ -115,5 +159,6 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         batch.pixel_values = None
         batch.aspect_ratio_ids = None
         batch.aspect_ratio_mask = None
-        batch.image_indices = None
+        batch.image_indices = []
+        assert batch.image_indices is not None
         return batch
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 304ddd60..bcb33c35 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -141,7 +141,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
-        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
+        batch = super().concatenate(batches)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
@@ -402,7 +402,7 @@ class VlmCausalLM(FlashCausalLM):
             lm_head_indices=lm_head_indices,
             cross_attention_states=cross_attention_states,
             adapter_data=adapter_data,
-            image_indices=batch.image_indices,
+            image_indices=batch.image_indices[:],
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None