Working state.

commit 2ac607a215 (parent ef4fa3ea7c)
Nicolas Patry, 2024-09-28 22:10:10 +02:00
4 changed files with 212 additions and 55 deletions

--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
         FlashDeepseekV2ForCausalLM,
         DeepseekV2Config,
@@ -1138,7 +1139,7 @@ def get_model(
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
-            return VlmCausalLM(
+            return MllamaCausalLM(
                model_id=model_id,
                model_class=MllamaForConditionalGeneration,
                batch_class=MllamaCausalLMBatch,

--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -451,7 +451,6 @@ class FlashLlamaLayer(nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,
-        image_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -576,7 +575,6 @@ class FlashLlamaModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         adapter_data,
         cross_attention_states=None,
-        image_indices=None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -601,7 +599,6 @@ class FlashLlamaModel(torch.nn.Module):
                 max_s,
                 adapter_data,
                 cross_attention_states,
-                image_indices,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -649,7 +646,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
-        image_indices=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -665,7 +661,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefill_cache_indices=prefill_cache_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            image_indices=image_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
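
Note: dropping image_indices from these generic Llama layers works because the cross-attention layers now recover their row indices from cross_attention_states itself (unpacked as indices = cross_attention_states[-1] in the Mllama modeling file below). A minimal runnable sketch of that convention, with hypothetical names for everything not shown in the diff:

    # Sketch (hypothetical names): metadata for the cross layers rides along
    # as the last element of one tuple, so self-attention layers forward it
    # untouched instead of growing an extra image_indices argument.
    def self_attn_layer(hidden_states, cross_attention_states):
        return hidden_states  # ignores the tuple entirely

    def cross_attn_layer(hidden_states, cross_attention_states):
        if cross_attention_states is None:
            return hidden_states
        states, cu_seqlen_q, indices = cross_attention_states
        return hidden_states  # real layer would gather/attend/scatter here

    layers = [self_attn_layer, cross_attn_layer]
    hidden_states, states = [1.0, 2.0], object()
    cross_attention_states = (states, [0, 2], [0, 1])  # indices last
    for layer in layers:
        hidden_states = layer(hidden_states, cross_attention_states)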

--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -856,20 +856,17 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
-        image_indices,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
         if residual is not None:
             hidden_states += residual

-        # indices = cross_attention_states[-1]
-        if cu_seqlen_prefill is not None:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[:]
-        else:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[image_indices]
+        indices = cross_attention_states[-1]
+        out_hidden_states = hidden_states[:]
+        if len(indices) > 0:
+            assert max(indices) < hidden_states.shape[0]
+            hidden_states = hidden_states[indices]

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -885,10 +882,7 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

-        if cu_seqlen_prefill is not None:
-            out_hidden_states[:] = hidden_states
-        else:
-            out_hidden_states[image_indices] = hidden_states
+        out_hidden_states[indices] = hidden_states
         hidden_states = out_hidden_states

         return hidden_states, None
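
Both hunks replace the prefill/decode branch with a single gather/compute/scatter path driven by indices. A minimal runnable sketch of the pattern, with a lambda standing in for the real cross-attention block:

    import torch

    def cross_layer(hidden_states, indices, cross_attn):
        out_hidden_states = hidden_states[:]  # view over the full token batch
        if len(indices) > 0:
            assert max(indices) < hidden_states.shape[0]
            gathered = hidden_states[indices]  # only rows that attend to images
            out_hidden_states[indices] = cross_attn(gathered)  # scatter back
        return out_hidden_states

    h = torch.zeros(4, 8)  # 4 flattened tokens, hidden size 8
    out = cross_layer(h, indices=[1, 3], cross_attn=lambda x: x + 1.0)
    # rows 0 and 2 pass through unchanged; rows 1 and 3 were updated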
@@ -971,40 +965,24 @@ class MllamaForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        adapter_data: Optional[torch.Tensor],
-        cross_attention_states: Optional[torch.Tensor],
-        image_indices,
+        adapter_data: Optional[torch.Tensor] = None,
+        # XXX: Putting these as optional so that the cuda warmup calls can go through.
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
     ):
-        # if cross_attention_mask is not None:
-        #     cross_attention_mask, full_text_row_masked_out_mask = (
-        #         _prepare_cross_attention_mask(
-        #             cross_attention_mask,
-        #             num_vision_tokens=self.vision_model.num_patches,
-        #             dtype=self.dtype,
-        #         )
-        #     )
-        # else:
-        #     full_text_row_masked_out_mask = None
-
-        # if cross_attention_mask is not None and cache_position is not None:
-        #     cross_attention_mask = cross_attention_mask[:, :, cache_position]
-        #     full_text_row_masked_out_mask = full_text_row_masked_out_mask[
-        #         :, :, cache_position
-        #     ]
         if cross_attention_states is not None:
             seqlen_q = len(image_indices)
             n_images = cross_attention_states.shape[0]
             seqlen_k = cross_attention_states.shape[1]
             device = cross_attention_states.device
             if cu_seqlen_prefill is not None:
-                # raise RuntimeError("TODO")
                 offset = 0
                 cu_q = []
                 indices = []
                 for index in image_indices:
                     cu_q.append(offset)
-                    length = seqlen.input_lengths[index]
+                    length = seqlen.input_lengths[index].item()
+                    assert index < seqlen.cu_seqlen_q.shape[0]
                     input_ids_offset = seqlen.cu_seqlen_q[index]
                     indices.extend(range(input_ids_offset, input_ids_offset + length))
                     offset += length
@@ -1061,8 +1039,6 @@ class MllamaForConditionalGeneration(nn.Module):
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            # cross_attention_mask=cross_attention_mask,
-            image_indices=image_indices,
         )
         return outputs
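
In the prefill branch, the loop expands each image into the flat query-token positions of its owning sequence plus ragged cu_q boundaries; the new .item() call and assert harden that indexing. A worked sketch with made-up lengths:

    # Two sequences of lengths 3 and 2, one image each; image_indices maps
    # each image to the index of the sequence that owns it.
    input_lengths = [3, 2]
    cu_seqlen_q = [0, 3, 5]  # prefix sums of input_lengths
    image_indices = [0, 1]

    offset, cu_q, indices = 0, [], []
    for index in image_indices:
        cu_q.append(offset)
        length = input_lengths[index]
        input_ids_offset = cu_seqlen_q[index]
        indices.extend(range(input_ids_offset, input_ids_offset + length))
        offset += length

    assert indices == [0, 1, 2, 3, 4]  # every prefill token, grouped per image
    assert cu_q == [0, 3]              # ragged query boundaries, one per image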

--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -1,7 +1,7 @@
 from io import BytesIO
 from PIL import Image
 import torch
-from typing import Iterable, List
+from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
 from dataclasses import dataclass
@@ -9,10 +9,13 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
-from typing import Optional
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+    block_tables_to_ragged,
+)
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from text_generation_server.layers.attention import Seqlen

 tracer = trace.get_tracer(__name__)
@@ -36,11 +39,17 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         image_indices = []
         attention_states = []
         for b in batches:
             if b.cross_attention_states is not None:
                 attention_states.append(b.cross_attention_states)
             image_indices.extend([i + offset for i in b.image_indices])
             offset += len(b.image_indices)
-        batch.cross_attention_states = torch.cat(attention_states, dim=0)
-        batch.image_indices = image_indices
+        if len(attention_states) > 0:
+            assert len(image_indices) > 0
+            batch.cross_attention_states = torch.cat(attention_states, dim=0)
+            batch.image_indices = image_indices
+        else:
+            batch.cross_attention_states = None
+            batch.image_indices = []
         return batch

     @tracer.start_as_current_span("filter")
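
When batches merge, each batch's image indices are shifted by a running offset so they still address the right rows once the per-batch cross_attention_states tensors are concatenated. A small sketch, assuming dense per-batch indices:

    # Two hypothetical batches holding 2 images and 1 image respectively.
    per_batch_image_indices = [[0, 1], [0]]
    offset, image_indices = 0, []
    for b_indices in per_batch_image_indices:
        image_indices.extend(i + offset for i in b_indices)
        offset += len(b_indices)
    assert image_indices == [0, 1, 2]  # distinct rows into the stacked tensor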
@@ -64,8 +73,14 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             prev_i = i
         batch.image_indices = new_image_indices
-        batch.cross_attention_states = self.cross_attention_states[indices]
-        assert offset <= batch.cross_attention_states.shape[0]
+        if len(new_image_indices) > 0:
+            assert max(new_image_indices) < self.cross_attention_states.shape[0]
+            assert offset <= self.cross_attention_states.shape[0]
+            batch.cross_attention_states = self.cross_attention_states[
+                new_image_indices
+            ]
+        else:
+            batch.cross_attention_states = None
         return batch

     @classmethod
@@ -79,23 +94,20 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         for i, r in enumerate(requests):
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             curr_text = ""
-            has_image = False
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     curr_text += chunk.text
                 elif chunk_type == "image":
-                    has_image = True
                     image = Image.open(BytesIO(chunk.image.data))
                     # TODO unsure about BOS
                     curr_text += "<|image|>"
                     image_input = processor.image_processor(image, return_tensors="pt")
                     image_inputs.append(image_input)
+                    image_indices.append(i)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
             texts.append(curr_text)
-            if has_image:
-                image_indices.append(i)

             input_ids = tokenizer(
                 curr_text,
@@ -124,6 +136,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         else:
             image_inputs = None

+        if image_inputs is not None:
+            assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
         return batch_tokenized_inputs, image_inputs

     @classmethod
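
Moving image_indices.append(i) into the image branch makes the bookkeeping per image rather than per request: every image chunk now adds one "<|image|>" marker to the text, one tensor to image_inputs, and one entry to image_indices, which is exactly what the new assert checks. A condensed sketch with hypothetical chunk data:

    # Hypothetical request: "describe " + <image> + " briefly"
    chunks = [("text", "describe "), ("image", b"..."), ("text", " briefly")]
    i = 0  # request index
    curr_text, image_indices, n_image_inputs = "", [], 0
    for chunk_type, payload in chunks:
        if chunk_type == "text":
            curr_text += payload
        elif chunk_type == "image":
            curr_text += "<|image|>"
            n_image_inputs += 1      # stands in for image_inputs.append(...)
            image_indices.append(i)  # one entry per image, not per request
    assert curr_text == "describe <|image|> briefly"
    assert len(image_indices) == n_image_inputs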
@@ -162,3 +177,173 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             batch.image_indices = []
         assert batch.image_indices is not None
         return batch
+
+
+class MllamaCausalLM(VlmCausalLM):
+    def forward(
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+            speculative_ids = batch.speculative_ids
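+            # Each sequence expands to new_length = speculative_length + 1 rows
+            # (the verified token plus its speculative candidates); positions,
+            # slots and lengths are expanded the same way so one flattened
+            # forward pass scores every candidate.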
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
+
+            # Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
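+        # Graphs are captured at fixed batch sizes; use the smallest captured
+        # size that still fits this batch and pad the static inputs up to it.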
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            # Only run cuda graphs when there are no images.
+            or batch.cross_attention_states is not None
+        ):
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
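+                # pixel_values are only present at prefill: run the vision
+                # tower once and cache its output on the batch so later decode
+                # steps reuse it (pixel_values are cleared after this forward).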
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits