diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 8280795d..d43f789e 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -158,8 +158,8 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
-            prefix_len: 0,
-            postfix_len: truncate,
+            cache_len: 0,
+            chunk_len: None,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index 39e99776..854a5895 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -246,8 +246,8 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-            prefix_len: 0,
-            postfix_len: 1,
+            cache_len: 0,
+            chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index 804c77d4..fe810f24 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -158,8 +158,8 @@ impl Client {
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
             slots: vec![],
-            prefix_len: 0,
-            postfix_len: truncate,
+            cache_len: 0,
+            chunk_len: None,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index e25bf71e..e181cd28 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -235,9 +235,9 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-            prefix_len: 0,
+            cache_len: 0,
            adapter_id: None,
-            postfix_len: 1,
+            chunk_len: None,
        };
        let batch = Batch {
            id: u64::MAX,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index a07c725c..36fbed87 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -280,7 +280,7 @@ impl State {
                 continue;
             }
 
-            let (block_allocation, postfix_len) = match &self.block_allocator {
+            let block_allocation = match &self.block_allocator {
                 None => {
                     // We pad to max input length in the Python shards
                     // We need to take these padding tokens into the equation
@@ -297,7 +297,7 @@ impl State {
                         self.entries.push_front((id, entry));
                         break 'entry_loop;
                     }
-                    (None, entry.request.input_length)
+                    None
                 }
                 Some(block_allocator) => {
                     // If users wants the prefill logprobs, we cannot reuse the cache.
@@ -337,7 +337,7 @@ impl State {
                        }
                    };
 
-                    let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
+                    let postfix_len = entry.request.input_length - block_allocation.prefix_len;
 
                    // Check equality too as if we don't we might end up with a postfix_len = 0
                    // in the next iteration of the loop
@@ -345,9 +345,9 @@ impl State {
                        // Entry is over budget
                        if self.support_chunking {
                            // We support chunking, just set postfix_len to exactly match prefill_token_budget
-                            postfix_len = prefill_token_budget - prefill_tokens;
+                            let chunk_len = prefill_token_budget - prefill_tokens;
                            // Push this entry inside the batch
-                            batch.push((id, entry, Some(block_allocation), postfix_len));
+                            batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
                            break 'entry_loop;
                        } else {
                            // We don't support chunking, this entry needs to go back to the buffer
@@ -363,10 +363,10 @@ impl State {
 
                    prefill_tokens += postfix_len;
 
-                    (Some(block_allocation), postfix_len)
+                    Some(block_allocation)
                }
            };
-            batch.push((id, entry, block_allocation, postfix_len));
+            batch.push((id, entry, block_allocation, None));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -395,7 +395,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
 
-        for (id, mut entry, block_allocation, postfix_len) in batch {
+        for (id, mut entry, block_allocation, chunk_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -447,9 +447,9 @@ impl State {
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
-                prefix_len,
+                cache_len: prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
-                postfix_len,
+                chunk_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -582,7 +582,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_append() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
 
        assert_eq!(state.next_id, 0);
@@ -598,7 +598,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
 
        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -606,7 +606,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -638,7 +638,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, false, None, 0, 16);
+        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -658,7 +658,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, false, None, 0, 2);
+        let mut state = State::new(false, 1, false, None, 0, 2, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -691,14 +691,14 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_append() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }
 
    #[tokio::test]
    async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
 
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -706,7 +706,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -739,7 +739,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -755,7 +755,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -780,7 +780,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, false, None, 2, 16);
+        let queue = Queue::new(false, 1, false, None, 2, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -799,7 +799,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, false, None, 0, 16);
+        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _) = default_entry();
        queue.append(entry);
 
diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 43a84e70..63fc7808 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -158,8 +158,8 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
-            prefix_len: 0,
-            postfix_len: sequence_length,
+            cache_len: 0,
+            chunk_len: None,
            adapter_id: None,
        })
        .collect();
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index dbe69244..dfbff7e5 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -9,13 +9,16 @@ import subprocess
 import sys
 import tempfile
 import time
-from typing import Dict, List, Optional
-
 import docker
 import pytest
+import base64
+
+from pathlib import Path
+from typing import Dict, List, Optional
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 from syrupy.extensions.json import JSONSnapshotExtension
+
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
@@ -639,3 +642,22 @@ def generate_multi():
            return responses
 
    return generate_load_inner
+
+
+# TODO fix the server parser to count inline image tokens correctly
+@pytest.fixture
+def chicken():
+    path = Path(__file__).parent / "images" / "chicken_on_money.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture
+def cow_beach():
+    path = Path(__file__).parent / "images" / "cow_beach.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py
index 52ecaed4..93962eb3 100644
--- a/integration-tests/models/test_flash_pali_gemma.py
+++ b/integration-tests/models/test_flash_pali_gemma.py
@@ -1,5 +1,4 @@
 import pytest
-import base64
 
 
 @pytest.fixture(scope="module")
@@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
    return flash_pali_gemma_handle.client
 
 
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
-    cow = get_cow_beach()
-    inputs = f"![]({cow})Where is the cow standing?\n"
+async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
+    inputs = f"![]({cow_beach})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
 
    assert response.generated_text == "beach"
@@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_pali_gemma_two_images(
+    flash_pali_gemma, response_snapshot, chicken, cow_beach
+):
    response = await flash_pali_gemma.generate(
        f"caption![]({chicken})![]({cow_beach})\n",
        max_new_tokens=20,
diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py
index eb573385..e5d08bb7 100644
--- a/integration-tests/models/test_idefics.py
+++ b/integration-tests/models/test_idefics.py
@@ -1,5 +1,4 @@
 import pytest
-import base64
 
 
 @pytest.fixture(scope="module")
@@ -16,22 +15,8 @@ async def idefics(idefics_handle):
    return idefics_handle.client
 
 
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
-async def test_idefics(idefics, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics(idefics, response_snapshot, chicken):
    response = await idefics.generate(
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
        max_new_tokens=10,
@@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_idefics_two_images(idefics, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
    response = await idefics.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:",
        max_new_tokens=20,
@@ -63,8 +46,7 @@
 
 @pytest.mark.release
 @pytest.mark.asyncio
-async def test_idefics_load(idefics, generate_load, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
    responses = await generate_load(
        idefics,
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py
index c5f48da3..881e37f9 100644
--- a/integration-tests/models/test_idefics2.py
+++ b/integration-tests/models/test_idefics2.py
@@ -1,18 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
 @pytest.fixture(scope="module")
@@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_idefics2_next_simple(
+    flash_idefics2_next, response_snapshot, chicken
+):
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})Write me a short story \nAssistant:",
        max_new_tokens=10,
@@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_idefics2_two_images(
+    flash_idefics2_next, response_snapshot, chicken, cow_beach
+):
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:",
        max_new_tokens=20,
@@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_load(
-    flash_idefics2_next, generate_load, response_snapshot
+    flash_idefics2_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
    responses = await generate_load(
        flash_idefics2_next,
        f"User:![]({chicken})Write me a short story \nAssistant:",
diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py
index ea277d71..1ac8f172 100644
--- a/integration-tests/models/test_llava_next.py
+++ b/integration-tests/models/test_llava_next.py
@@ -1,12 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
 @pytest.fixture(scope="module")
@@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
    response = await flash_llava_next.generate(
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
        max_new_tokens=10,
@@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llava_next_load(
-    flash_llava_next, generate_load, response_snapshot
+    flash_llava_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
    responses = await generate_load(
        flash_llava_next,
        f"User:![]({chicken})Can you tell me a very short story based on the image?",
diff --git a/integration-tests/models/test_mllama.py b/integration-tests/models/test_mllama.py
index 1b4264aa..02781707 100644
--- a/integration-tests/models/test_mllama.py
+++ b/integration-tests/models/test_mllama.py
@@ -1,5 +1,4 @@
 import pytest
-import base64
 import asyncio
 
 
@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
    return mllama_handle.client
 
 
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
 async def test_mllama_simpl(mllama, response_snapshot):
-    # chicken = get_chicken()
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index e4dfefef..c91e7cc4 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -139,12 +139,14 @@ message Request {
   repeated uint32 slots = 10;
   /// LORA adapter index
   optional string adapter_id = 11;
-  /// Prefix length that can be retrieved from the KV cache.
-  uint32 prefix_len = 12;
+  /// Tokens that can be retrieved from the KV cache.
+  /// This value is set for the first prefill and never reset
+  uint32 cache_len = 12;
   /// Context truncation
   bool add_special_tokens = 13;
-  /// Postfix length for prefill chunking
-  uint32 postfix_len = 14;
+  /// Chunk of tokens that must be computed for the first prefill
+  /// This value is set for the first prefill and never reset
+  optional uint32 chunk_len = 14;
 }
 
 message Batch {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7e256dcf..8222722a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -280,24 +280,36 @@ class FlashCausalLMBatch(Batch):
            prompt_length = len(tokenized_input)
            prompt_lengths.append(prompt_length)
 
-            cache_length = r.prefix_len
-            input_length = r.postfix_len
+            cache_length = r.cache_len
+
            assert (
                cache_length <= prompt_length
            ), f"Prefix {cache_length} vs input {prompt_length}"
            if cache_length == prompt_length:
                assert False, "unreachable"
 
-            if cache_length + input_length < prompt_length:
-                # FIXME: speculate is not supported for context chunking at the moment
-                assert speculate == 0
-                assert get_support_chunking()
-                assert input_length > 0
-            postfix_ids = tokenized_input[cache_length : cache_length + input_length]
+            # `chunk_len` is an optional field in the protobuf
+            # It is only set if the model supports chunking
+            if r.HasField("chunk_len"):
+                input_length = r.chunk_len
+
+                if cache_length + input_length < prompt_length:
+                    # FIXME: speculate is not supported for context chunking at the moment
+                    assert speculate == 0
+                    assert get_support_chunking()
+                    assert input_length > 0
+
+                postfix_ids = tokenized_input[
+                    cache_length : cache_length + input_length
+                ]
+                assert (
+                    len(postfix_ids) == input_length
+                ), "Rust and Python tokenizers are not aligned"
+            else:
+                # Use all the remaining ids
+                postfix_ids = tokenized_input[cache_length:]
+                input_length = len(postfix_ids)
 
-            assert (
-                len(postfix_ids) == input_length
-            ), "Rust and Python tokenizers are not aligned"
            input_lengths.append(input_length)
 
            prefix_offsets.append(prompt_length - 5)
@@ -1097,6 +1109,7 @@ class FlashCausalLM(Model):
        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
        kv_cache_dtype: Optional[torch.dtype] = None,
+        support_chunking: bool = True,
    ):
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
@@ -1224,7 +1237,7 @@ class FlashCausalLM(Model):
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
-            support_chunking=True,
+            support_chunking=support_chunking,
        )
 
    @property
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index 3aa475c3..83e44039 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -1,14 +1,17 @@
-from io import BytesIO
-from PIL import Image
 import torch
+
+import numpy as np
+
 from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
-
+from io import BytesIO
+from PIL import Image
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
+
 from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
            batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
                max=config.text_config.vocab_size - 1
            )
+            if isinstance(batch.input_ids, list):
+                if len(batch) > 1:
+                    input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+                else:
+                    input_ids = batch.input_ids[0]
+                batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
            batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
 
        if image_inputs is not None:
@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 class MllamaCausalLM(VlmCausalLM):
    def forward(
        self,
-        batch: VlmCausalLMBatch,
+        batch: MllamaCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices
            speculative_ids = batch.speculative_ids
 
@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)
 
            # Add Copy the block tables for all members
@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices
 
        if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -269,38 +278,25 @@ class MllamaCausalLM(VlmCausalLM):
            # Only run cuda graphs when there's no images.
            or batch.cross_attention_states is not None
        ):
-            input_lengths = input_lengths + prefix_lens_tensor
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
            ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
                seqlen = Seqlen(
                    input_lengths=input_lengths,
-                    cache_lengths=prefix_lens_tensor,
+                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                )
-
-                if batch.pixel_values is not None:
-                    cross_attention_states = self.model.vision_forward(
-                        pixel_values=batch.pixel_values,
-                        aspect_ratio_ids=batch.aspect_ratio_ids,
-                        aspect_ratio_mask=batch.aspect_ratio_mask,
-                    )
-                    batch.cross_attention_states = cross_attention_states
-
-                cross_attention_states = batch.cross_attention_states
-
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
@@ -312,14 +308,18 @@ class MllamaCausalLM(VlmCausalLM):
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
-                    cross_attention_states=cross_attention_states,
-                    adapter_data=adapter_data,
-                    image_indices=batch.image_indices[:],
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                if batch.pixel_values is not None:
                    batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
                return logits, speculative_logits
 
        # Copy inputs to the static inputs of the cuda graph
@@ -330,22 +330,34 @@ class MllamaCausalLM(VlmCausalLM):
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                cache_lengths=batch.cache_lengths,
            )
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor - # Replay the graph - cuda_graph["graph"].replay() + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() # Slice output to the correct shape speculative_logits = ( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index a06add13..150cf0d0 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, + # FIXME: VLM do not work with context chunking yet + support_chunking=False, **kwargs, ) @@ -356,7 +358,7 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, @@ -368,13 +370,12 @@ class VlmCausalLM(FlashCausalLM): input_lengths_tensor=input_lengths, cache_lengths_tensor=cache_lengths_tensor, ): - max_k = (input_lengths + cache_lengths_tensor).max().item() seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + max_q=batch.max_input_length, + max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -416,7 +417,10 @@ class VlmCausalLM(FlashCausalLM): cuda_graph["block_tables"][ : block_tables.shape[0], : block_tables.shape[1] ] = block_tables - cuda_graph["slots"].fill_(-1) + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. + cuda_graph["slots"].fill_(0) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths