mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

commit b7a1280f25 (parent f85a308ef1)

    fix tests
@@ -158,8 +158,8 @@ impl Client {
            // Blocks and slots will be set on the server side if we use paged attention
            blocks: vec![],
            slots: vec![],
-           prefix_len: 0,
-           postfix_len: truncate,
+           cache_len: 0,
+           chunk_len: None,
            // Set sampling parameters to also take these ops into account in the max memory
            parameters: Some(NextTokenChooserParameters {
                temperature: 0.9,
@@ -246,8 +246,8 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-           prefix_len: 0,
-           postfix_len: 1,
+           cache_len: 0,
+           chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
@@ -158,8 +158,8 @@ impl Client {
            // Blocks and slots will be set on the server side if we use paged attention
            blocks: vec![],
            slots: vec![],
-           prefix_len: 0,
-           postfix_len: truncate,
+           cache_len: 0,
+           chunk_len: None,
            // Set sampling parameters to also take these ops into account in the max memory
            parameters: Some(NextTokenChooserParameters {
                temperature: 0.9,
@@ -235,9 +235,9 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-           prefix_len: 0,
+           cache_len: 0,
            adapter_id: None,
-           postfix_len: 1,
+           chunk_len: None,
        };
        let batch = Batch {
            id: u64::MAX,
@@ -280,7 +280,7 @@ impl State {
                continue;
            }

-           let (block_allocation, postfix_len) = match &self.block_allocator {
+           let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
@@ -297,7 +297,7 @@ impl State {
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
-                   (None, entry.request.input_length)
+                   None
                }
                Some(block_allocator) => {
                    // If users wants the prefill logprobs, we cannot reuse the cache.
@@ -337,7 +337,7 @@ impl State {
                        }
                    };

-                   let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
+                   let postfix_len = entry.request.input_length - block_allocation.prefix_len;

                    // Check equality too as if we don't we might end up with a postfix_len = 0
                    // in the next iteration of the loop
@@ -345,9 +345,9 @@ impl State {
                        // Entry is over budget
                        if self.support_chunking {
                            // We support chunking, just set postfix_len to exactly match prefill_token_budget
-                           postfix_len = prefill_token_budget - prefill_tokens;
+                           let chunk_len = prefill_token_budget - prefill_tokens;
                            // Push this entry inside the batch
-                           batch.push((id, entry, Some(block_allocation), postfix_len));
+                           batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
                            break 'entry_loop;
                        } else {
                            // We don't support chunking, this entry needs to go back to the buffer
@@ -363,10 +363,10 @@ impl State {

                    prefill_tokens += postfix_len;

-                   (Some(block_allocation), postfix_len)
+                   Some(block_allocation)
                }
            };
-           batch.push((id, entry, block_allocation, postfix_len));
+           batch.push((id, entry, block_allocation, None));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -395,7 +395,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-       for (id, mut entry, block_allocation, postfix_len) in batch {
+       for (id, mut entry, block_allocation, chunk_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -447,9 +447,9 @@ impl State {
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
-               prefix_len,
+               cache_len: prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
-               postfix_len,
+               chunk_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
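
Note: as a rough illustration of the scheduling rule these router hunks implement, when an entry would overflow the prefill token budget a chunking-capable backend receives a partial chunk (chunk_len), while other backends send the entry back to the queue. This is a simplified Python sketch, not the actual Rust code; every name and number below is illustrative.

# Simplified sketch (not the router's Rust code) of the chunking decision above.
from typing import Optional, Tuple


def plan_prefill(
    input_length: int,
    prefix_len: int,
    prefill_tokens: int,
    prefill_token_budget: int,
    support_chunking: bool,
) -> Tuple[Optional[int], bool]:
    """Return (chunk_len, accepted) for one queued entry."""
    # Tokens that still have to be prefilled for this entry
    postfix_len = input_length - prefix_len

    if prefill_tokens + postfix_len > prefill_token_budget:
        if support_chunking:
            # Fill the remaining budget exactly; later batches compute the rest.
            chunk_len = prefill_token_budget - prefill_tokens
            return chunk_len, True
        # No chunking support: the entry has to wait for a future batch.
        return None, False

    # The whole remainder fits, so no explicit chunk_len is sent (None).
    return None, True


if __name__ == "__main__":
    # 60 prompt tokens, 20 already cached, 50 of 64 budget tokens already used:
    # only a 14-token chunk is scheduled now.
    print(plan_prefill(60, 20, 50, 64, support_chunking=True))   # (14, True)
    print(plan_prefill(60, 20, 50, 64, support_chunking=False))  # (None, False)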
@@ -582,7 +582,7 @@ mod tests {

    #[tokio::test]
    async fn test_append() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
@@ -598,7 +598,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_empty() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -606,7 +606,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_min_size() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -638,7 +638,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_max_size() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -658,7 +658,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_token_budget() {
-       let mut state = State::new(false, 1, false, None, 0, 2);
+       let mut state = State::new(false, 1, false, None, 0, 2, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -691,14 +691,14 @@ mod tests {

    #[tokio::test]
    async fn test_queue_append() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -706,7 +706,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -739,7 +739,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -755,7 +755,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -780,7 +780,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
-       let queue = Queue::new(false, 1, false, None, 2, 16);
+       let queue = Queue::new(false, 1, false, None, 2, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -799,7 +799,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _) = default_entry();
        queue.append(entry);

@@ -158,8 +158,8 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
-           prefix_len: 0,
-           postfix_len: sequence_length,
+           cache_len: 0,
+           chunk_len: None,
            adapter_id: None,
        })
        .collect();
@@ -9,13 +9,16 @@ import subprocess
 import sys
 import tempfile
 import time
-from typing import Dict, List, Optional
-
 import docker
 import pytest
+import base64
+
+from pathlib import Path
+from typing import Dict, List, Optional
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 from syrupy.extensions.json import JSONSnapshotExtension

 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
@@ -639,3 +642,22 @@ def generate_multi():
            return responses

        return generate_load_inner
+
+
+# TODO fix the server parsser to count inline image tokens correctly
+@pytest.fixture
+def chicken():
+    path = Path(__file__).parent / "images" / "chicken_on_money.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture
+def cow_beach():
+    path = Path(__file__).parent / "images" / "cow_beach.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
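
Note: with these fixtures the image data URIs are injected by pytest instead of being built by per-file get_chicken()/get_cow_beach() helpers, as the test-file hunks below show. A minimal usage sketch follows; the client fixture, the test name and the inline-markdown prompt form are assumptions, only the chicken/cow_beach fixture names come from the conftest change above.

# Hypothetical test consuming the new fixtures; `some_model_client` is made up.
import pytest


@pytest.mark.asyncio
async def test_uses_inline_images(some_model_client, chicken, cow_beach):
    # Each fixture yields a "data:image/png;base64,..." string that the test
    # can embed in its prompt (shown here in markdown-image form).
    prompt = f"![]({chicken})![]({cow_beach})Describe both images."
    response = await some_model_client.generate(prompt, max_new_tokens=20)
    assert response.generated_text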
@@ -1,5 +1,4 @@
 import pytest
-import base64


 @pytest.fixture(scope="module")
@@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
     return flash_pali_gemma_handle.client


-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
-    cow = get_cow_beach()
-    inputs = f"Where is the cow standing?\n"
+async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
+    inputs = f"Where is the cow standing?\n"
     response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

     assert response.generated_text == "beach"
@@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_pali_gemma_two_images(
+    flash_pali_gemma, response_snapshot, chicken, cow_beach
+):
     response = await flash_pali_gemma.generate(
         f"caption\n",
         max_new_tokens=20,
@@ -1,5 +1,4 @@
 import pytest
-import base64


 @pytest.fixture(scope="module")
@@ -16,22 +15,8 @@ async def idefics(idefics_handle):
     return idefics_handle.client


-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
-async def test_idefics(idefics, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics(idefics, response_snapshot, chicken):
     response = await idefics.generate(
         f"User:Can you tell me a very short story based on the image?",
         max_new_tokens=10,
@@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_idefics_two_images(idefics, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
     response = await idefics.generate(
         f"User:Where are the cow and chicken?<end_of_utterance> \nAssistant:",
         max_new_tokens=20,
@@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):

 @pytest.mark.release
 @pytest.mark.asyncio
-async def test_idefics_load(idefics, generate_load, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
     responses = await generate_load(
         idefics,
         f"User:Can you tell me a very short story based on the image?",
@@ -1,18 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


 @pytest.fixture(scope="module")
@@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):

 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_idefics2_next_simple(
+    flash_idefics2_next, response_snapshot, chicken
+):
     response = await flash_idefics2_next.generate(
         f"User:Write me a short story<end_of_utterance> \nAssistant:",
         max_new_tokens=10,
@@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot

 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_idefics2_two_images(
+    flash_idefics2_next, response_snapshot, chicken, cow_beach
+):
     response = await flash_idefics2_next.generate(
         f"User:Where are the cow and chicken?<end_of_utterance> \nAssistant:",
         max_new_tokens=20,
@@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_load(
-    flash_idefics2_next, generate_load, response_snapshot
+    flash_idefics2_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
     responses = await generate_load(
         flash_idefics2_next,
         f"User:Write me a short story<end_of_utterance> \nAssistant:",
@@ -1,12 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


 @pytest.fixture(scope="module")
@@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
     response = await flash_llava_next.generate(
         f"User:Can you tell me a very short story based on the image?",
         max_new_tokens=10,
@@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llava_next_load(
-    flash_llava_next, generate_load, response_snapshot
+    flash_llava_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
     responses = await generate_load(
         flash_llava_next,
         f"User:Can you tell me a very short story based on the image?",
@@ -1,5 +1,4 @@
 import pytest
-import base64
 import asyncio


@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
     return mllama_handle.client


-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
 async def test_mllama_simpl(mllama, response_snapshot):
-    # chicken = get_chicken()
     response = await mllama.chat(
         max_tokens=10,
         temperature=0.0,
@@ -139,12 +139,14 @@ message Request {
     repeated uint32 slots = 10;
     /// LORA adapter index
     optional string adapter_id = 11;
-    /// Prefix length that can be retrieved from the KV cache.
-    uint32 prefix_len = 12;
+    /// Tokens that can be retrieved from the KV cache.
+    /// This value is set for the first prefill and never reset
+    uint32 cache_len = 12;
     /// Context truncation
     bool add_special_tokens = 13;
-    /// Postfix length for prefill chunking
-    uint32 postfix_len = 14;
+    /// Chunk of tokens that must be computed for the first prefill
+    /// This value is set for the first prefill and never reset
+    optional uint32 chunk_len = 14;
 }

 message Batch {
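
Note: a rough sketch of how a Python shard can read the renamed fields off the generated Request message. Since chunk_len is declared optional, presence has to be checked with HasField, which is exactly what the flash_causal_lm hunk further down does; the helper name below is made up.

# Rough sketch, assuming the generated text_generation_server.pb.generate_pb2 module.
from text_generation_server.pb import generate_pb2


def resolve_prefill_window(r: generate_pb2.Request, prompt_length: int):
    cache_length = r.cache_len  # tokens already present in the KV cache
    if r.HasField("chunk_len"):
        # Only part of the remaining prompt is computed in this prefill pass.
        input_length = r.chunk_len
    else:
        # No chunking requested: compute everything that is not cached yet.
        input_length = prompt_length - cache_length
    return cache_length, input_length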
@@ -280,24 +280,36 @@ class FlashCausalLMBatch(Batch):
         prompt_length = len(tokenized_input)
         prompt_lengths.append(prompt_length)

-        cache_length = r.prefix_len
-        input_length = r.postfix_len
+        cache_length = r.cache_len
+
         assert (
             cache_length <= prompt_length
         ), f"Prefix {cache_length} vs input {prompt_length}"
         if cache_length == prompt_length:
             assert False, "unreachable"
+
+        # `chunk_len` is an optional field in the protobuf
+        # It is only set if the model support chunking
+        if r.HasField("chunk_len"):
+            input_length = r.chunk_len
+
             if cache_length + input_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0
+                assert get_support_chunking()
+                assert input_length > 0

-        postfix_ids = tokenized_input[cache_length : cache_length + input_length]
-
+            postfix_ids = tokenized_input[
+                cache_length : cache_length + input_length
+            ]
+            assert (
+                len(postfix_ids) == input_length
+            ), "Rust and Python tokenizers are not aligned"
+        else:
+            # Use all the remaining ids
+            postfix_ids = tokenized_input[cache_length:]
+            input_length = len(postfix_ids)

         input_lengths.append(input_length)

         prefix_offsets.append(prompt_length - 5)
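
Note: a small worked example, in plain Python with made-up numbers, of the slicing behaviour the new branch implements.

# Worked example of the postfix slicing introduced above.
tokenized_input = list(range(100, 112))  # 12 prompt token ids
cache_length = 4                         # tokens already in the KV cache
chunk_len = 5                            # set because the model supports chunking

postfix_ids = tokenized_input[cache_length : cache_length + chunk_len]
assert postfix_ids == [104, 105, 106, 107, 108]
assert len(postfix_ids) == chunk_len

# Without chunk_len, the whole remainder of the prompt is used instead.
postfix_ids_full = tokenized_input[cache_length:]
assert len(postfix_ids_full) == 12 - cache_length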
@@ -1097,6 +1109,7 @@ class FlashCausalLM(Model):
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
         kv_cache_dtype: Optional[torch.dtype] = None,
+        support_chunking: bool = True,
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -1224,7 +1237,7 @@ class FlashCausalLM(Model):
             rank=rank,
             world_size=world_size,
             sliding_window=config.sliding_window,
-            support_chunking=True,
+            support_chunking=support_chunking,
         )

     @property
@@ -1,14 +1,17 @@
-from io import BytesIO
-from PIL import Image
 import torch
+
+import numpy as np

 from typing import Iterable, Optional, Tuple, List, Dict
-from text_generation_server.pb.generate_pb2 import Request
+
+from io import BytesIO
+from PIL import Image
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )

 from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
             max=config.text_config.vocab_size - 1
         )
+        if isinstance(batch.input_ids, list):
+            if len(batch) > 1:
+                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+            else:
+                input_ids = batch.input_ids[0]
+            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)

         if image_inputs is not None:
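
Note: a standalone sketch of the list-to-tensor normalization added above. The helper name is made up and it keys off len(input_ids) rather than len(batch), but the concatenate/tensor logic mirrors the new branch.

# Standalone sketch of the list -> tensor normalization.
import numpy as np
import torch


def normalize_input_ids(input_ids, device="cpu"):
    # Concatenated batches can still carry a Python list of per-request id
    # arrays; a single request skips the concatenate.
    if isinstance(input_ids, list):
        if len(input_ids) > 1:
            merged = np.concatenate(input_ids, dtype=np.int64)
        else:
            merged = input_ids[0]
        input_ids = torch.tensor(merged, dtype=torch.int64, device=device)
    return input_ids


print(normalize_input_ids([np.array([1, 2]), np.array([3])]))  # tensor([1, 2, 3])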
@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 class MllamaCausalLM(VlmCausalLM):
     def forward(
         self,
-        batch: VlmCausalLMBatch,
+        batch: MllamaCausalLMBatch,
         adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

             speculative_ids = batch.speculative_ids
@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

             if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -269,38 +278,25 @@ class MllamaCausalLM(VlmCausalLM):
             # Only run cuda graphs when there's no images.
             or batch.cross_attention_states is not None
         ):
-            input_lengths = input_lengths + prefix_lens_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
-                    cache_lengths=prefix_lens_tensor,
+                    cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )

                 if batch.pixel_values is not None:
                     cross_attention_states = self.model.vision_forward(
                         pixel_values=batch.pixel_values,
                         aspect_ratio_ids=batch.aspect_ratio_ids,
                         aspect_ratio_mask=batch.aspect_ratio_mask,
                     )
                     batch.cross_attention_states = cross_attention_states

                 cross_attention_states = batch.cross_attention_states

                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -312,14 +308,18 @@ class MllamaCausalLM(VlmCausalLM):
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
                     cross_attention_states=cross_attention_states,
                     adapter_data=adapter_data,
                     image_indices=batch.image_indices[:],
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
                 )
                 if batch.prefill_cache_indices is not None:
                     batch.prefill_cache_indices = None
                 if batch.pixel_values is not None:
                     batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
                 return logits, speculative_logits

             # Copy inputs to the static inputs of the cuda graph
@@ -330,20 +330,32 @@ class MllamaCausalLM(VlmCausalLM):
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
                 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
             else:
                 cuda_graph["block_tables"][
                     : block_tables.shape[0], : block_tables.shape[1]
                 ] = block_tables

+            # XXX: This is working only because block 0 is reserved for the healthcheck
+            # so it doesn't matter if we override it with bogus values.
+            cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-                input_lengths + prefix_lens_tensor
-            )
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["cache_lengths"].zero_()
+            cuda_graph["cache_lengths"][
+                : cache_lengths_tensor.shape[0]
+            ] = cache_lengths_tensor

             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
                 cu_seqlen_prefill=None,
                 input_lengths_tensor=cuda_graph["input_lengths"],
+                cache_lengths_tensor=cuda_graph["cache_lengths"],
                 state=cuda_graph["state"],
             ):
                 # Replay the graph
                 cuda_graph["graph"].replay()
@@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
             model_id=model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
+            # FIXME: VLM do not work with context chunking yet
+            support_chunking=False,
             **kwargs,
         )

@@ -356,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            if PREFIX_CACHING:
+            if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
@@ -368,13 +370,12 @@ class VlmCausalLM(FlashCausalLM):
                 input_lengths_tensor=input_lengths,
                 cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + cache_lengths_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
                     cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
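
Note: the Seqlen bounds now come from batch-level maxima instead of being recomputed from the tensors on every forward. The intended relationship, with illustrative numbers and on the assumption that max_current_length tracks cached plus newly computed tokens per request, is:

# Illustrative relationship behind the max_q / max_k change above.
input_lengths = [5, 12, 3]   # tokens computed in this pass, per request
cache_lengths = [20, 0, 7]   # tokens already in the KV cache, per request

max_input_length = max(input_lengths)  # what max_q is fed with
max_current_length = max(i + c for i, c in zip(input_lengths, cache_lengths))  # max_k
assert (max_input_length, max_current_length) == (12, 25)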
@@ -416,7 +417,10 @@ class VlmCausalLM(FlashCausalLM):
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
-            cuda_graph["slots"].fill_(-1)
+
+            # XXX: This is working only because block 0 is reserved for the healthcheck
+            # so it doesn't matter if we override it with bogus values.
+            cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
             cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
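
Note: a tiny sketch of the slot-padding change above. Unused positions in the static CUDA-graph buffer now point at slot 0 rather than -1, which is safe only because block 0 is reserved for the health check, as the new comment states.

# Tiny sketch of padding the static slots buffer with a valid slot id.
import torch

static_slots = torch.empty(8, dtype=torch.int64)
slots = torch.tensor([17, 18, 19])  # real slots for the current batch

static_slots.fill_(0)               # bogus but valid padding (block 0 is reserved)
static_slots[: slots.shape[0]] = slots
print(static_slots)                 # tensor([17, 18, 19, 0, 0, 0, 0, 0])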