# text-generation-inference/server/text_generation_server/models/mllama_causal_lm.py

from dataclasses import dataclass, field
from io import BytesIO
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from opentelemetry import trace
from PIL import Image
from transformers import (
    PreTrainedTokenizerBase,
)

from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import Request
# Module-level OpenTelemetry tracer; the batch `concatenate` and `filter`
# methods below are wrapped in spans created from it.
tracer = trace.get_tracer(__name__)
@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
    """Batch state for Mllama (vision + text) generation.

    Adds the image bookkeeping used by the cross-attention layers on top of
    ``VlmCausalLMBatch``. Every method below keeps this invariant:
    ``image_indices[k]`` is the position, within this batch, of the request
    that owns image ``k``, and once the vision tower has run,
    ``cross_attention_states[k]`` holds that image's encoded states
    (assumes the vision output has one row per image -- TODO confirm).
    """

    # Batch position of the request owning each image, one entry per image.
    image_indices: List[int] = field(default_factory=list)
    # Optional image-processor outputs, stacked across images.
    aspect_ratio_ids: Optional[torch.Tensor] = None
    aspect_ratio_mask: Optional[torch.Tensor] = None
    # Vision-tower output; row k pairs with image_indices[k].
    cross_attention_states: Optional[torch.Tensor] = None

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        """Merge ``batches`` into one batch, keeping image bookkeeping aligned.

        Image indices of each input batch are shifted by the number of
        *requests* (``len(b)``) in the batches before it, and cross-attention
        states are stacked in the same order, so row ``k`` of the result still
        belongs to ``image_indices[k]``. Only batches that already carry
        cross-attention states contribute image indices; raw pixel inputs are
        dropped.

        NOTE(review): the request offset assumes the parent concatenation lays
        the requests of ``batches`` out back to back, in order -- confirm
        against ``VlmCausalLMBatch.concatenate``.
        """
        batch = super().concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None

        request_offset = 0
        image_indices = []
        attention_states = []
        for b in batches:
            if b.cross_attention_states is not None:
                attention_states.append(b.cross_attention_states)
                image_indices.extend(i + request_offset for i in b.image_indices)
            # Image indices are request positions, so shift by request count,
            # not by the number of images seen so far.
            request_offset += len(b)

        if attention_states:
            assert len(image_indices) > 0
            batch.cross_attention_states = torch.cat(attention_states, dim=0)
            batch.image_indices = image_indices
        else:
            batch.cross_attention_states = None
            batch.image_indices = []
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only ``request_ids``, remapping image indices and states.

        Each surviving image's index is translated from its position in this
        batch to its position in the filtered batch. ``cross_attention_states``
        keeps exactly the rows of the surviving images, in the same order.
        """
        assert self.image_indices is not None
        batch = super().filter(request_ids)
        if batch is self:
            # The parent returned this very object, so request positions were
            # not remapped and the current image bookkeeping is still valid.
            return batch

        # Position in this batch -> position in the filtered batch.
        # NOTE(review): assumes the parent filter orders the kept requests as
        # in ``request_ids`` -- confirm against the parent implementation.
        new_position = {
            self.requests_idx_mapping[request_id]: position
            for position, request_id in enumerate(request_ids)
        }

        new_image_indices = []
        kept_rows = []
        for row, old_position in enumerate(self.image_indices):
            position = new_position.get(old_position)
            if position is not None:
                new_image_indices.append(position)
                kept_rows.append(row)

        batch.image_indices = new_image_indices
        if kept_rows:
            assert self.cross_attention_states is not None
            # kept_rows is increasing, so its last entry is the maximum.
            assert kept_rows[-1] < self.cross_attention_states.shape[0]
            batch.cross_attention_states = self.cross_attention_states[kept_rows]
        else:
            batch.cross_attention_states = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
        """Tokenize ``requests`` and preprocess their images.

        Each request's chunks are joined into one prompt. Every image chunk
        appends an ``<|image|>`` placeholder to the text and is run through
        ``processor.image_processor``. Only the last image of a request is
        kept, so at most one image per request reaches the model.

        Returns:
            ``(batch_tokenized_inputs, image_inputs)``. The first element holds
            one token-id list per request. ``image_inputs`` is ``None`` when no
            request has an image. Otherwise it is a dict with the concatenated
            ``pixel_values``, ``aspect_ratio_ids`` / ``aspect_ratio_mask`` when
            the image processor produced them, and ``image_indices`` (the
            request position owning each image).

        Raises:
            RuntimeError: if a chunk is neither text nor image.
        """
        per_image_inputs = []
        image_indices = []
        batch_tokenized_inputs = []
        for i, r in enumerate(requests):
            # Each input is a list of chunks, each either text or an encoded image.
            curr_text = ""
            curr_image = None
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    curr_text += chunk.text
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO unsure about BOS
                    curr_text += "<|image|>"
                    # NOTE(review): a later image in the same request replaces
                    # this one, although every image adds a placeholder token.
                    curr_image = processor.image_processor(image, return_tensors="pt")
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
            if curr_image is not None:
                per_image_inputs.append(curr_image)
                image_indices.append(i)
            input_ids = tokenizer(
                curr_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            batch_tokenized_inputs.append(input_ids)

        if not per_image_inputs:
            return batch_tokenized_inputs, None

        first = per_image_inputs[0]
        image_inputs = {
            "pixel_values": torch.cat(
                [img["pixel_values"] for img in per_image_inputs], dim=0
            ),
        }
        # Optional processor outputs: included only if the processor emits them.
        for key in ("aspect_ratio_ids", "aspect_ratio_mask"):
            if key in first:
                image_inputs[key] = torch.cat(
                    [img[key] for img in per_image_inputs], dim=0
                )
        image_inputs["image_indices"] = image_indices
        assert len(image_indices) == image_inputs["pixel_values"].shape[0]
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        """Build a batch from a protobuf ``Batch``, preprocessing its images.

        Token ids are clamped to ``config.text_config.vocab_size - 1``, because
        the ``<|image|>`` id is out of that range and breaks the logits
        processors. Pixel values are moved to ``device`` and cast to ``dtype``.
        Aspect-ratio tensors are moved to ``device``, or left ``None`` when the
        image processor did not produce them.
        """
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
        vocab_limit = config.text_config.vocab_size - 1
        batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(max=vocab_limit)
        if isinstance(batch.input_ids, list):
            if len(batch) > 1:
                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
            else:
                input_ids = batch.input_ids[0]
            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        batch.input_ids = batch.input_ids.clamp(max=vocab_limit)

        if image_inputs is None:
            batch.pixel_values = None
            batch.aspect_ratio_ids = None
            batch.aspect_ratio_mask = None
            batch.image_indices = []
        else:
            batch.pixel_values = image_inputs["pixel_values"].to(
                device=device, dtype=dtype
            )
            aspect_ratio_ids = image_inputs.get("aspect_ratio_ids")
            aspect_ratio_mask = image_inputs.get("aspect_ratio_mask")
            batch.aspect_ratio_ids = (
                None if aspect_ratio_ids is None else aspect_ratio_ids.to(device=device)
            )
            batch.aspect_ratio_mask = (
                None
                if aspect_ratio_mask is None
                else aspect_ratio_mask.to(device=device)
            )
            batch.image_indices = image_inputs["image_indices"]
        assert batch.image_indices is not None
        return batch
class MllamaCausalLM(VlmCausalLM):
    """VlmCausalLM variant whose forward pass also runs the Mllama vision tower."""

    def forward(
        self,
        batch: MllamaCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model step for ``batch`` and return ``(logits, speculative_logits)``.

        The step runs eagerly when this is a prefill step, when no captured
        CUDA graph is large enough, or when the batch carries cross-attention
        states. Otherwise a captured CUDA graph is replayed. In the eager path,
        pending ``batch.pixel_values`` are encoded with
        ``self.model.vision_forward``. The result is stored on
        ``batch.cross_attention_states`` for later steps, and ``pixel_values``
        is cleared after the model call.

        Args:
            batch: the batch to run. The eager path may update its
                ``cross_attention_states``, ``prefill_cache_indices`` and
                ``pixel_values`` in place.
            adapter_data: passed to ``self.model.forward`` in the eager path only.

        Returns:
            ``logits`` and ``speculative_logits``; the latter may be ``None``.
        """
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            # Speculative decoding: each of the B requests is expanded to
            # 1 + speculative_length query positions, flattened to B * new_length rows.
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            # Per-position offsets 0..new_length-1 added to positions, slots and lengths.
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        # Try to find an associated cuda graph: the smallest captured batch
        # size that is >= the current number of rows.
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None

        if (
            cu_seqlen_prefill is not None
            or cuda_graph is None
            # Only run cuda graphs when there's no images.
            or batch.cross_attention_states is not None
        ):
            # NOTE(review): under PREFIX_CACHING this is fed the batch's
            # unexpanded lengths even when block_tables was expanded for
            # speculation above -- confirm this is intended.
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )

                # Encode pending images once; the states are kept on the batch
                # so later steps reuse them instead of re-running the vision tower.
                if batch.pixel_values is not None:
                    cross_attention_states = self.model.vision_forward(
                        pixel_values=batch.pixel_values,
                        aspect_ratio_ids=batch.aspect_ratio_ids,
                        aspect_ratio_mask=batch.aspect_ratio_mask,
                    )
                    batch.cross_attention_states = cross_attention_states

                cross_attention_states = batch.cross_attention_states

                # NOTE(review): image_indices are per-request positions; with
                # speculative expansion there are B * new_length rows -- confirm
                # the model handles that combination.
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    cross_attention_states=cross_attention_states,
                    adapter_data=adapter_data,
                    image_indices=batch.image_indices[:],
                )
                # One-shot inputs: cleared so later steps do not reuse them.
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                if batch.pixel_values is not None:
                    batch.pixel_values = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        # The buffers are reset first, so entries past the live rows are left at 0.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape (drop rows for the padded batch tail)
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits