diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0fd72d03..babd851d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
         FlashDeepseekV2ForCausalLM,
         DeepseekV2Config,
@@ -1138,7 +1139,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
-            return VlmCausalLM(
+            return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
                 batch_class=MllamaCausalLMBatch,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 15af17df..c9ec70cc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -451,7 +451,6 @@ class FlashLlamaLayer(nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,
-        image_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -576,7 +575,6 @@ class FlashLlamaModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         adapter_data,
         cross_attention_states=None,
-        image_indices=None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
 
@@ -601,7 +599,6 @@ class FlashLlamaModel(torch.nn.Module):
                 max_s,
                 adapter_data,
                 cross_attention_states,
-                image_indices,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -649,7 +646,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
-        image_indices=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -665,7 +661,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefill_cache_indices=prefill_cache_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            image_indices=image_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 9d904978..50c4d49c 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -856,20 +856,17 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
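+        # Descriptive note (added for readability): the last element of
+        # cross_attention_states carries the indices of the tokens that attend
+        # to images; it is consumed as `indices` below.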
-        image_indices,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
         if residual is not None:
             hidden_states += residual
 
-        # indices = cross_attention_states[-1]
-        if cu_seqlen_prefill is not None:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[:]
-        else:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[image_indices]
+        indices = cross_attention_states[-1]
+        out_hidden_states = hidden_states[:]
+        if len(indices) > 0:
+            assert max(indices) < hidden_states.shape[0]
+            hidden_states = hidden_states[indices]
 
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -885,10 +882,7 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
-        if cu_seqlen_prefill is not None:
-            out_hidden_states[:] = hidden_states
-        else:
-            out_hidden_states[image_indices] = hidden_states
+        out_hidden_states[indices] = hidden_states
 
         hidden_states = out_hidden_states
 
         return hidden_states, None
@@ -971,40 +965,24 @@ class MllamaForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        adapter_data: Optional[torch.Tensor],
-        cross_attention_states: Optional[torch.Tensor],
-        image_indices,
+        adapter_data: Optional[torch.Tensor] = None,
+        # XXX: Putting these as optional so that the cuda warmup calls can go through.
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
     ):
-        # if cross_attention_mask is not None:
-        #     cross_attention_mask, full_text_row_masked_out_mask = (
-        #         _prepare_cross_attention_mask(
-        #             cross_attention_mask,
-        #             num_vision_tokens=self.vision_model.num_patches,
-        #             dtype=self.dtype,
-        #         )
-        #     )
-        # else:
-        #     full_text_row_masked_out_mask = None
-
-        # if cross_attention_mask is not None and cache_position is not None:
-        #     cross_attention_mask = cross_attention_mask[:, :, cache_position]
-        #     full_text_row_masked_out_mask = full_text_row_masked_out_mask[
-        #         :, :, cache_position
-        #     ]
-
         if cross_attention_states is not None:
             seqlen_q = len(image_indices)
             n_images = cross_attention_states.shape[0]
             seqlen_k = cross_attention_states.shape[1]
             device = cross_attention_states.device
             if cu_seqlen_prefill is not None:
-                # raise RuntimeError("TODO")
                 offset = 0
                 cu_q = []
                 indices = []
                 for index in image_indices:
                     cu_q.append(offset)
-                    length = seqlen.input_lengths[index]
+                    length = seqlen.input_lengths[index].item()
+                    assert index < seqlen.cu_seqlen_q.shape[0]
                     input_ids_offset = seqlen.cu_seqlen_q[index]
                     indices.extend(range(input_ids_offset, input_ids_offset + length))
                     offset += length
@@ -1061,8 +1039,6 @@ class MllamaForConditionalGeneration(nn.Module):
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            # cross_attention_mask=cross_attention_mask,
-            image_indices=image_indices,
         )
 
         return outputs
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index c62ccb24..ef12b621 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -1,7 +1,7 @@
 from io import BytesIO
 from PIL import Image
 import torch
-from typing import Iterable, List
+from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
 from dataclasses import dataclass
 
@@ -9,10 +9,13 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
-from typing import Optional
-
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+    block_tables_to_ragged,
+)
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from text_generation_server.layers.attention import Seqlen
 
 tracer = trace.get_tracer(__name__)
 
@@ -36,11 +39,17 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         image_indices = []
         attention_states = []
         for b in batches:
-            attention_states.append(b.cross_attention_states)
+            if b.cross_attention_states is not None:
+                attention_states.append(b.cross_attention_states)
             image_indices.extend([i + offset for i in b.image_indices])
             offset += len(b.image_indices)
-        batch.cross_attention_states = torch.cat(attention_states, dim=0)
-        batch.image_indices = image_indices
+        if len(attention_states) > 0:
+            assert len(image_indices) > 0
+            batch.cross_attention_states = torch.cat(attention_states, dim=0)
+            batch.image_indices = image_indices
+        else:
+            batch.cross_attention_states = None
+            batch.image_indices = []
         return batch
 
     @tracer.start_as_current_span("filter")
@@ -64,8 +73,14 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             prev_i = i
 
         batch.image_indices = new_image_indices
-        batch.cross_attention_states = self.cross_attention_states[indices]
-        assert offset <= batch.cross_attention_states.shape[0]
+        if len(new_image_indices) > 0:
+            assert max(new_image_indices) < self.cross_attention_states.shape[0]
+            assert offset <= self.cross_attention_states.shape[0]
+            batch.cross_attention_states = self.cross_attention_states[
+                new_image_indices
+            ]
+        else:
+            batch.cross_attention_states = None
         return batch
 
     @classmethod
@@ -79,23 +94,20 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         for i, r in enumerate(requests):
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             curr_text = ""
-            has_image = False
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     curr_text += chunk.text
                 elif chunk_type == "image":
-                    has_image = True
                     image = Image.open(BytesIO(chunk.image.data))
                     # TODO unsure about BOS
                     curr_text += "<|image|>"
                     image_input = processor.image_processor(image, return_tensors="pt")
                     image_inputs.append(image_input)
+                    image_indices.append(i)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
             texts.append(curr_text)
-            if has_image:
-                image_indices.append(i)
 
             input_ids = tokenizer(
                 curr_text,
@@ -124,6 +136,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         else:
             image_inputs = None
 
+        if image_inputs is not None:
+            assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
         return batch_tokenized_inputs, image_inputs
 
     @classmethod
@@ -162,3 +177,173 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             batch.image_indices = []
         assert batch.image_indices is not None
         return batch
+
+
+class MllamaCausalLM(VlmCausalLM):
+    def forward(
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
+
+            # Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            # Only run cuda graphs when there are no images.
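+            # (descriptive note: the graphs are recorded during warmup without
+            # cross-attention states, whose presence and shape depend on the
+            # images in the batch, so they cannot be replayed here)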
+            or batch.cross_attention_states is not None
+        ):
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
+
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits