Working state.

Author: Nicolas Patry
Date:   2024-09-28 22:10:10 +02:00
parent ef4fa3ea7c
commit 2ac607a215
4 changed files with 212 additions and 55 deletions

View File

@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
         FlashDeepseekV2ForCausalLM,
         DeepseekV2Config,
@@ -1138,7 +1139,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == MLLAMA:
         if FLASH_ATTENTION:
-            return VlmCausalLM(
+            return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
                 batch_class=MllamaCausalLMBatch,
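
For context, the first hunk extends the import guard that controls the FLASH_ATTENTION fast path. A minimal sketch of that pattern, assuming the usual try/except guard around these imports (the except branch is not shown in the diff):

FLASH_ATTENTION = True
try:
    # If any flash-attention based model class fails to import,
    # the fast path (including MLLAMA -> MllamaCausalLM) is disabled.
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM  # noqa: F401
except ImportError:
    FLASH_ATTENTION = False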

View File

@@ -451,7 +451,6 @@ class FlashLlamaLayer(nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,
-        image_indices,
     ):

         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -576,7 +575,6 @@ class FlashLlamaModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         adapter_data,
         cross_attention_states=None,
-        image_indices=None,
     ) -> torch.Tensor:

         hidden_states = inputs_embeds
@@ -601,7 +599,6 @@ class FlashLlamaModel(torch.nn.Module):
                 max_s,
                 adapter_data,
                 cross_attention_states,
-                image_indices,
             )

         hidden_states, _ = self.norm(hidden_states, residual)
@@ -649,7 +646,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
-        image_indices=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -665,7 +661,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefill_cache_indices=prefill_cache_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            image_indices=image_indices,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
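
The image_indices argument disappears from the generic Llama layers because the cross-attention layers now read the token indices directly out of cross_attention_states (its last element, see the next file). A hedged sketch of that design choice; the tuple layout is inferred from the code below:

def plain_layer_forward(hidden_states, cross_attention_states):
    # Self-attention layers never inspect cross_attention_states,
    # so they no longer need a separate image_indices argument.
    return hidden_states


def cross_layer_forward(hidden_states, cross_attention_states):
    if cross_attention_states is None:
        return hidden_states
    # The token indices ride along as the last element of the tuple
    # and are unpacked only where they are actually needed.
    indices = cross_attention_states[-1]
    # ... run cross-attention on hidden_states[indices] ...
    return hidden_states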

View File

@@ -856,20 +856,17 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         max_s,
         adapter_data,
         cross_attention_states,  # [ IB, ...]
-        image_indices,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if cross_attention_states is None:
             return hidden_states, residual
         if residual is not None:
             hidden_states += residual

-        # indices = cross_attention_states[-1]
-        if cu_seqlen_prefill is not None:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[:]
-        else:
-            out_hidden_states = hidden_states[:]
-            hidden_states = hidden_states[image_indices]
+        indices = cross_attention_states[-1]
+        out_hidden_states = hidden_states[:]
+        if len(indices) > 0:
+            assert max(indices) < hidden_states.shape[0]
+            hidden_states = hidden_states[indices]

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -885,10 +882,7 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

-        if cu_seqlen_prefill is not None:
-            out_hidden_states[:] = hidden_states
-        else:
-            out_hidden_states[image_indices] = hidden_states
+        out_hidden_states[indices] = hidden_states
         hidden_states = out_hidden_states

         return hidden_states, None
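
The rewritten block above is a gather/scatter: only the rows of the flattened batch that belong to image-carrying sequences go through cross-attention, and the results are written back into a full-size tensor so image-free rows pass through untouched. A minimal, illustrative sketch of the pattern (shapes and values are made up):

import torch

hidden_states = torch.randn(6, 8)       # 6 token rows in the flattened batch
indices = torch.tensor([1, 2, 5])       # rows belonging to sequences with images

out_hidden_states = hidden_states[:]    # full-size view to scatter results into
gathered = hidden_states[indices]       # advanced indexing copies only these rows
processed = gathered * 2.0              # stand-in for cross-attention + gated MLP
out_hidden_states[indices] = processed  # scatter back; other rows stay unchanged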
@@ -971,40 +965,24 @@ class MllamaForConditionalGeneration(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
-        adapter_data: Optional[torch.Tensor],
-        cross_attention_states: Optional[torch.Tensor],
-        image_indices,
+        adapter_data: Optional[torch.Tensor] = None,
+        # XXX: Putting these as optional so that the cuda warmup calls can go through.
+        cross_attention_states: Optional[torch.Tensor] = None,
+        image_indices=None,
     ):
-        # if cross_attention_mask is not None:
-        #     cross_attention_mask, full_text_row_masked_out_mask = (
-        #         _prepare_cross_attention_mask(
-        #             cross_attention_mask,
-        #             num_vision_tokens=self.vision_model.num_patches,
-        #             dtype=self.dtype,
-        #         )
-        #     )
-        # else:
-        #     full_text_row_masked_out_mask = None
-
-        # if cross_attention_mask is not None and cache_position is not None:
-        #     cross_attention_mask = cross_attention_mask[:, :, cache_position]
-        #     full_text_row_masked_out_mask = full_text_row_masked_out_mask[
-        #         :, :, cache_position
-        #     ]
-
         if cross_attention_states is not None:
             seqlen_q = len(image_indices)
             n_images = cross_attention_states.shape[0]
             seqlen_k = cross_attention_states.shape[1]
             device = cross_attention_states.device
             if cu_seqlen_prefill is not None:
-                # raise RuntimeError("TODO")
                 offset = 0
                 cu_q = []
                 indices = []
                 for index in image_indices:
                     cu_q.append(offset)
-                    length = seqlen.input_lengths[index]
+                    length = seqlen.input_lengths[index].item()
+                    assert index < seqlen.cu_seqlen_q.shape[0]
                     input_ids_offset = seqlen.cu_seqlen_q[index]
                     indices.extend(range(input_ids_offset, input_ids_offset + length))
                     offset += length
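
In the prefill loop above, each image-carrying sequence contributes a contiguous block of query positions: cu_q accumulates the cumulative offsets of those blocks while indices collects the absolute token positions in the flattened batch. An illustrative-only walk-through with made-up numbers (input_lengths and cu_seqlen_q mimic the fields of the Seqlen object; the closing cu_q append is added here for illustration):

input_lengths = [4, 3, 5]          # tokens per sequence in the batch
cu_seqlen_q = [0, 4, 7, 12]        # cumulative offsets of each sequence
image_indices = [0, 2]             # sequences 0 and 2 carry an image

offset = 0
cu_q = []
indices = []
for index in image_indices:
    cu_q.append(offset)
    length = input_lengths[index]
    input_ids_offset = cu_seqlen_q[index]
    indices.extend(range(input_ids_offset, input_ids_offset + length))
    offset += length
cu_q.append(offset)

print(cu_q)     # [0, 4, 9] -> per-image query ranges
print(indices)  # [0, 1, 2, 3, 7, 8, 9, 10, 11] -> token positions that attend to images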
@@ -1061,8 +1039,6 @@ class MllamaForConditionalGeneration(nn.Module):
             lm_head_indices=lm_head_indices,
             adapter_data=adapter_data,
             cross_attention_states=cross_attention_states,
-            # cross_attention_mask=cross_attention_mask,
-            image_indices=image_indices,
         )

         return outputs

View File

@@ -1,7 +1,7 @@
 from io import BytesIO
 from PIL import Image
 import torch
-from typing import Iterable, List
+from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
 from dataclasses import dataclass
@@ -9,10 +9,13 @@ from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
-from typing import Optional
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+    block_tables_to_ragged,
+)
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from text_generation_server.layers.attention import Seqlen

 tracer = trace.get_tracer(__name__)
@@ -36,11 +39,17 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         image_indices = []
         attention_states = []
         for b in batches:
-            attention_states.append(b.cross_attention_states)
+            if b.cross_attention_states is not None:
+                attention_states.append(b.cross_attention_states)
             image_indices.extend([i + offset for i in b.image_indices])
             offset += len(b.image_indices)
-        batch.cross_attention_states = torch.cat(attention_states, dim=0)
-        batch.image_indices = image_indices
+        if len(attention_states) > 0:
+            assert len(image_indices) > 0
+            batch.cross_attention_states = torch.cat(attention_states, dim=0)
+            batch.image_indices = image_indices
+        else:
+            batch.cross_attention_states = None
+            batch.image_indices = []

         return batch

     @tracer.start_as_current_span("filter")
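
When batches are merged, each batch's image_indices are local to that batch, so they are shifted by the number of images accumulated so far before the cross-attention states are concatenated. An illustrative-only walk-through of the loop above (values are made up):

batch_a_image_indices = [0, 1]   # two images in batch A
batch_b_image_indices = [0]      # one image in batch B

offset = 0
image_indices = []
for b_indices in (batch_a_image_indices, batch_b_image_indices):
    image_indices.extend(i + offset for i in b_indices)
    offset += len(b_indices)

print(image_indices)  # [0, 1, 2]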
@@ -64,8 +73,14 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             prev_i = i
         batch.image_indices = new_image_indices
-        batch.cross_attention_states = self.cross_attention_states[indices]
-        assert offset <= batch.cross_attention_states.shape[0]
+        if len(new_image_indices) > 0:
+            assert max(new_image_indices) < self.cross_attention_states.shape[0]
+            assert offset <= self.cross_attention_states.shape[0]
+            batch.cross_attention_states = self.cross_attention_states[
+                new_image_indices
+            ]
+        else:
+            batch.cross_attention_states = None
         return batch

     @classmethod
@@ -79,23 +94,20 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         for i, r in enumerate(requests):
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             curr_text = ""
-            has_image = False
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     curr_text += chunk.text
                 elif chunk_type == "image":
-                    has_image = True
                     image = Image.open(BytesIO(chunk.image.data))
                     # TODO unsure about BOS
                     curr_text += "<|image|>"
                     image_input = processor.image_processor(image, return_tensors="pt")
                     image_inputs.append(image_input)
+                    image_indices.append(i)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
             texts.append(curr_text)
-            if has_image:
-                image_indices.append(i)

             input_ids = tokenizer(
                 curr_text,
@@ -124,6 +136,9 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         else:
             image_inputs = None

+        if image_inputs is not None:
+            assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
         return batch_tokenized_inputs, image_inputs
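
With this change, one index is recorded per image chunk rather than per request, so image_indices lines up entry-for-entry with the stacked pixel values, which is exactly what the new assert checks. An illustrative-only example where one request carries two images:

request_chunks = [
    ["text"],                      # request 0: no image
    ["image", "text", "image"],    # request 1: two images
    ["image"],                     # request 2: one image
]

image_indices = []
for i, chunks in enumerate(request_chunks):
    for chunk in chunks:
        if chunk == "image":
            image_indices.append(i)

print(image_indices)  # [1, 1, 2] -> one entry per image, each naming its request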
     @classmethod
@@ -162,3 +177,173 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             batch.image_indices = []
         assert batch.image_indices is not None
         return batch
+
+
+class MllamaCausalLM(VlmCausalLM):
+    def forward(
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
+
+            # Add Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            # Only run cuda graphs when there's no images.
+            or batch.cross_attention_states is not None
+        ):
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
+
+                if batch.pixel_values is not None:
+                    cross_attention_states = self.model.vision_forward(
+                        pixel_values=batch.pixel_values,
+                        aspect_ratio_ids=batch.aspect_ratio_ids,
+                        aspect_ratio_mask=batch.aspect_ratio_mask,
+                    )
+                    batch.cross_attention_states = cross_attention_states
+
+                cross_attention_states = batch.cross_attention_states
+
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    cross_attention_states=cross_attention_states,
+                    adapter_data=adapter_data,
+                    image_indices=batch.image_indices[:],
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
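
In short, the eager path above is taken for every prefill, whenever no padded graph exists, and whenever the batch carries images; cuda graph replay is reserved for image-free decode steps. A one-function restatement of that gate, mirroring the condition in the code above:

def use_cuda_graph(cu_seqlen_prefill, cuda_graph, cross_attention_states) -> bool:
    # Replay a captured graph only for decode steps (no prefill), when a
    # padded graph exists, and when the batch has no cross-attention states.
    return (
        cu_seqlen_prefill is None
        and cuda_graph is not None
        and cross_attention_states is None
    )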