fix tests

OlivierDehaene 2024-10-10 14:52:09 +02:00
parent f85a308ef1
commit b7a1280f25
16 changed files with 162 additions and 180 deletions

View File

@ -158,8 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: truncate,
cache_len: 0,
chunk_len: None,
// Set sampling parameters so these ops are also taken into account in the max memory computation
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@ -246,8 +246,8 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
postfix_len: 1,
cache_len: 0,
chunk_len: None,
adapter_id: None,
};
let batch = Batch {

View File

@ -158,8 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: truncate,
cache_len: 0,
chunk_len: None,
// Set sampling parameters so these ops are also taken into account in the max memory computation
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@ -235,9 +235,9 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
cache_len: 0,
adapter_id: None,
postfix_len: 1,
chunk_len: None,
};
let batch = Batch {
id: u64::MAX,

View File

@ -280,7 +280,7 @@ impl State {
continue;
}
let (block_allocation, postfix_len) = match &self.block_allocator {
let block_allocation = match &self.block_allocator {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into account
@ -297,7 +297,7 @@ impl State {
self.entries.push_front((id, entry));
break 'entry_loop;
}
(None, entry.request.input_length)
None
}
Some(block_allocator) => {
// If the user wants the prefill logprobs, we cannot reuse the cache.
@ -337,7 +337,7 @@ impl State {
}
};
let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
// Check equality too: if we don't, we might end up with a postfix_len = 0
// in the next iteration of the loop
@ -345,9 +345,9 @@ impl State {
// Entry is over budget
if self.support_chunking {
// We support chunking, just set chunk_len to exactly match prefill_token_budget
postfix_len = prefill_token_budget - prefill_tokens;
let chunk_len = prefill_token_budget - prefill_tokens;
// Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), postfix_len));
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
break 'entry_loop;
} else {
// We don't support chunking, this entry needs to go back to the buffer
@ -363,10 +363,10 @@ impl State {
prefill_tokens += postfix_len;
(Some(block_allocation), postfix_len)
Some(block_allocation)
}
};
batch.push((id, entry, block_allocation, postfix_len));
batch.push((id, entry, block_allocation, None));
if Some(batch.len()) == max_size {
break;
}
@ -395,7 +395,7 @@ impl State {
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation, postfix_len) in batch {
for (id, mut entry, block_allocation, chunk_len) in batch {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
@ -447,9 +447,9 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
prefix_len,
cache_len: prefix_len,
adapter_id: entry.request.adapter_id.clone(),
postfix_len,
chunk_len,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -582,7 +582,7 @@ mod tests {
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@ -598,7 +598,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -606,7 +606,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -638,7 +638,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -658,7 +658,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, false, None, 0, 2);
let mut state = State::new(false, 1, false, None, 0, 2, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -691,14 +691,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -706,7 +706,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -739,7 +739,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -755,7 +755,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -780,7 +780,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, false, None, 2, 16);
let queue = Queue::new(false, 1, false, None, 2, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -799,7 +799,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _) = default_entry();
queue.append(entry);
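
In short, the scheduling hunks at the top of this file replace "shrink postfix_len to fit" with "emit an optional chunk_len". A minimal Python sketch of that decision, with hypothetical names (the real logic is the Rust loop above):

# Sketch of the per-entry budgeting rule. `remaining` is the number of
# prompt tokens the entry still needs to prefill.
def plan_prefill(remaining, prefill_tokens, budget, support_chunking):
    """Return ("full", None), ("chunk", n) or ("requeue", None)."""
    # Hitting the budget exactly also counts as over budget, so we never
    # schedule a zero-length chunk on the next iteration.
    if prefill_tokens + remaining < budget:
        return ("full", None)  # chunk_len stays None in the batch
    if support_chunking:
        # Compute exactly enough tokens to fill the remaining budget; the
        # rest of the prompt is prefilled by a later batch.
        return ("chunk", budget - prefill_tokens)
    # Without chunking, the entry must go back to the buffer.
    return ("requeue", None)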

View File

@ -158,8 +158,8 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: sequence_length,
cache_len: 0,
chunk_len: None,
adapter_id: None,
})
.collect();

View File

@ -9,13 +9,16 @@ import subprocess
import sys
import tempfile
import time
from typing import Dict, List, Optional
import docker
import pytest
import base64
from pathlib import Path
from typing import Dict, List, Optional
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
BestOfSequence,
@ -639,3 +642,22 @@ def generate_multi():
return responses
return generate_load_inner
# TODO fix the server parser to count inline image tokens correctly
@pytest.fixture
def chicken():
path = Path(__file__).parent / "images" / "chicken_on_money.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture
def cow_beach():
path = Path(__file__).parent / "images" / "cow_beach.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"

View File

@ -1,5 +1,4 @@
import pytest
import base64
@pytest.fixture(scope="module")
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
return flash_pali_gemma_handle.client
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
cow = get_cow_beach()
inputs = f"![]({cow})Where is the cow standing?\n"
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
inputs = f"![]({cow_beach})Where is the cow standing?\n"
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
assert response.generated_text == "beach"
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_flash_pali_gemma_two_images(
flash_pali_gemma, response_snapshot, chicken, cow_beach
):
response = await flash_pali_gemma.generate(
f"caption![]({chicken})![]({cow_beach})\n",
max_new_tokens=20,

View File

@ -1,5 +1,4 @@
import pytest
import base64
@pytest.fixture(scope="module")
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
return idefics_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot):
chicken = get_chicken()
async def test_idefics(idefics, response_snapshot, chicken):
response = await idefics.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
response = await idefics.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()
async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
responses = await generate_load(
idefics,
f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,18 +1,4 @@
import pytest
import base64
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
async def test_flash_idefics2_next_simple(
flash_idefics2_next, response_snapshot, chicken
):
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_flash_idefics2_two_images(
flash_idefics2_next, response_snapshot, chicken, cow_beach
):
response = await flash_idefics2_next.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
flash_idefics2_next, generate_load, response_snapshot, chicken
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",

View File

@ -1,12 +1,4 @@
import pytest
import base64
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
chicken = get_chicken()
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
response = await flash_llava_next.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(
flash_llava_next, generate_load, response_snapshot
flash_llava_next, generate_load, response_snapshot, chicken
):
chicken = get_chicken()
responses = await generate_load(
flash_llava_next,
f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,5 +1,4 @@
import pytest
import base64
import asyncio
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
return mllama_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,

View File

@ -139,12 +139,14 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Context truncation
bool add_special_tokens = 13;
/// Postfix length for prefill chunking
uint32 postfix_len = 14;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
}
message Batch {
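
Together, the two new fields describe a window of the prompt. A hedged Python sketch of the slice they select, with `tokens` standing in for the full tokenized prompt:

# cache_len skips tokens already in the KV cache; chunk_len, when set,
# bounds how many of the remaining tokens this prefill pass computes.
def prefill_window(tokens, cache_len, chunk_len=None):
    if chunk_len is None:
        return tokens[cache_len:]  # unset: compute everything remaining
    return tokens[cache_len : cache_len + chunk_len]

# e.g. a 10-token prompt with 4 tokens already cached and a 3-token chunk
# computes tokens[4:7] on this pass.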

View File

@ -280,24 +280,36 @@ class FlashCausalLMBatch(Batch):
prompt_length = len(tokenized_input)
prompt_lengths.append(prompt_length)
cache_length = r.prefix_len
input_length = r.postfix_len
cache_length = r.cache_len
assert (
cache_length <= prompt_length
), f"Prefix {cache_length} vs input {prompt_length}"
if cache_length == prompt_length:
assert False, "unreachable"
if cache_length + input_length < prompt_length:
# FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
assert get_support_chunking()
assert input_length > 0
postfix_ids = tokenized_input[cache_length : cache_length + input_length]
# `chunk_len` is an optional field in the protobuf
# It is only set if the model supports chunking
if r.HasField("chunk_len"):
input_length = r.chunk_len
if cache_length + input_length < prompt_length:
# FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
assert get_support_chunking()
assert input_length > 0
postfix_ids = tokenized_input[
cache_length : cache_length + input_length
]
assert (
len(postfix_ids) == input_length
), "Rust and Python tokenizers are not aligned"
else:
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
input_length = len(postfix_ids)
assert (
len(postfix_ids) == input_length
), "Rust and Python tokenizers are not aligned"
input_lengths.append(input_length)
prefix_offsets.append(prompt_length - 5)
@ -1097,6 +1109,7 @@ class FlashCausalLM(Model):
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
support_chunking: bool = True,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
@ -1224,7 +1237,7 @@ class FlashCausalLM(Model):
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
support_chunking=True,
support_chunking=support_chunking,
)
@property
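
The new flag is a plain constructor default that subclasses can opt out of. A simplified sketch of the pattern (the class bodies are illustrative, not the real ones):

# FlashCausalLM now defaults to chunking support; a subclass that cannot
# chunk yet passes False through to the base class (see the VlmCausalLM
# hunk further down).
class FlashCausalLM:
    def __init__(self, support_chunking: bool = True, **kwargs):
        self.support_chunking = support_chunking

class VlmCausalLM(FlashCausalLM):
    def __init__(self, **kwargs):
        # VLMs do not work with context chunking yet
        super().__init__(support_chunking=False, **kwargs)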

View File

@ -1,14 +1,17 @@
from io import BytesIO
from PIL import Image
import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM):
def forward(
self,
batch: VlmCausalLMBatch,
batch: MllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Copy the block tables for all members
@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@ -269,38 +278,25 @@ class MllamaCausalLM(VlmCausalLM):
# Only run cuda graphs when there are no images.
or batch.cross_attention_states is not None
):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -312,14 +308,18 @@ class MllamaCausalLM(VlmCausalLM):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices[:],
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
@ -330,22 +330,34 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
# XXX: This works only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
# Replay the graph
cuda_graph["graph"].replay()
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
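
With chunking, the number of tokens computed in a pass (the query length) and the number of tokens attention can see (cache plus query) diverge per request, which is why the Seqlen above now takes batch.max_input_length and batch.max_current_length separately. A hedged sketch of the bookkeeping those attributes are assumed to reflect:

# max_q bounds the tokens computed this pass; max_k bounds cache +
# computed tokens, i.e. what attention can look at.
def batch_maxima(cache_lengths, input_lengths):
    max_input_length = max(input_lengths)  # max_q
    max_current_length = max(
        c + i for c, i in zip(cache_lengths, input_lengths)
    )  # max_k
    return max_input_length, max_current_length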

View File

@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLMs do not work with context chunking yet
support_chunking=False,
**kwargs,
)
@ -356,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
if PREFIX_CACHING:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
@ -368,13 +370,12 @@ class VlmCausalLM(FlashCausalLM):
input_lengths_tensor=input_lengths,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + cache_lengths_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@ -416,7 +417,10 @@ class VlmCausalLM(FlashCausalLM):
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
# XXX: This works only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths