Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit b7a1280f25 ("fix tests"), parent f85a308ef1.
@@ -158,8 +158,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
-                postfix_len: truncate,
+                cache_len: 0,
+                chunk_len: None,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -246,8 +246,8 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-           prefix_len: 0,
-           postfix_len: 1,
+           cache_len: 0,
+           chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
@@ -158,8 +158,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
-                postfix_len: truncate,
+                cache_len: 0,
+                chunk_len: None,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -235,9 +235,9 @@ impl Health for ShardedClient {
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
-           prefix_len: 0,
+           cache_len: 0,
            adapter_id: None,
-           postfix_len: 1,
+           chunk_len: None,
        };
        let batch = Batch {
            id: u64::MAX,
@@ -280,7 +280,7 @@ impl State {
                continue;
            }

-           let (block_allocation, postfix_len) = match &self.block_allocator {
+           let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
@@ -297,7 +297,7 @@ impl State {
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
-                   (None, entry.request.input_length)
+                   None
                }
                Some(block_allocator) => {
                    // If users wants the prefill logprobs, we cannot reuse the cache.
@@ -337,7 +337,7 @@ impl State {
                        }
                    };

-                   let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
+                   let postfix_len = entry.request.input_length - block_allocation.prefix_len;

                    // Check equality too as if we don't we might end up with a postfix_len = 0
                    // in the next iteration of the loop
@@ -345,9 +345,9 @@ impl State {
                        // Entry is over budget
                        if self.support_chunking {
                            // We support chunking, just set postfix_len to exactly match prefill_token_budget
-                           postfix_len = prefill_token_budget - prefill_tokens;
+                           let chunk_len = prefill_token_budget - prefill_tokens;
                            // Push this entry inside the batch
-                           batch.push((id, entry, Some(block_allocation), postfix_len));
+                           batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
                            break 'entry_loop;
                        } else {
                            // We don't support chunking, this entry needs to go back to the buffer
@@ -363,10 +363,10 @@ impl State {

                    prefill_tokens += postfix_len;

-                   (Some(block_allocation), postfix_len)
+                   Some(block_allocation)
                }
            };
-           batch.push((id, entry, block_allocation, postfix_len));
+           batch.push((id, entry, block_allocation, None));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -395,7 +395,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-       for (id, mut entry, block_allocation, postfix_len) in batch {
+       for (id, mut entry, block_allocation, chunk_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -447,9 +447,9 @@ impl State {
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
-               prefix_len,
+               cache_len: prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
-               postfix_len,
+               chunk_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -582,7 +582,7 @@ mod tests {

    #[tokio::test]
    async fn test_append() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
@@ -598,7 +598,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_empty() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -606,7 +606,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_min_size() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -638,7 +638,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_max_size() {
-       let mut state = State::new(false, 1, false, None, 0, 16);
+       let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -658,7 +658,7 @@ mod tests {

    #[tokio::test]
    async fn test_next_batch_token_budget() {
-       let mut state = State::new(false, 1, false, None, 0, 2);
+       let mut state = State::new(false, 1, false, None, 0, 2, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -691,14 +691,14 @@ mod tests {

    #[tokio::test]
    async fn test_queue_append() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -706,7 +706,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -739,7 +739,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -755,7 +755,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -780,7 +780,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
-       let queue = Queue::new(false, 1, false, None, 2, 16);
+       let queue = Queue::new(false, 1, false, None, 2, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -799,7 +799,7 @@ mod tests {

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
-       let queue = Queue::new(false, 1, false, None, 0, 16);
+       let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _) = default_entry();
        queue.append(entry);

@@ -158,8 +158,8 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
-           prefix_len: 0,
-           postfix_len: sequence_length,
+           cache_len: 0,
+           chunk_len: None,
            adapter_id: None,
        })
        .collect();
@@ -9,13 +9,16 @@ import subprocess
 import sys
 import tempfile
 import time
-from typing import Dict, List, Optional
 
 import docker
 import pytest
+import base64
+
+from pathlib import Path
+from typing import Dict, List, Optional
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 from syrupy.extensions.json import JSONSnapshotExtension
 
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
@@ -639,3 +642,22 @@ def generate_multi():
         return responses
 
     return generate_load_inner
+
+
+# TODO fix the server parsser to count inline image tokens correctly
+@pytest.fixture
+def chicken():
+    path = Path(__file__).parent / "images" / "chicken_on_money.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture
+def cow_beach():
+    path = Path(__file__).parent / "images" / "cow_beach.png"
+
+    with open(path, "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
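
The model tests consume these fixtures through pytest's dependency injection instead of the old module-level helpers. A minimal sketch of the pattern (the `some_vlm` fixture and the prompt wording are placeholders, not part of this commit):

    import pytest


    @pytest.mark.asyncio
    async def test_uses_image_fixture(some_vlm, response_snapshot, chicken):
        # `chicken` is provided by the fixture above and is already a
        # base64 data URI, ready to embed in the model's prompt markup.
        response = await some_vlm.generate(
            f"User:{chicken}Describe the image.",
            max_new_tokens=10,
        )
        assert response == response_snapshot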
@@ -1,5 +1,4 @@
 import pytest
-import base64
 
 
 @pytest.fixture(scope="module")
@@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
     return flash_pali_gemma_handle.client
 
 
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
-    cow = get_cow_beach()
-    inputs = f"Where is the cow standing?\n"
+async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
+    inputs = f"Where is the cow standing?\n"
     response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
 
     assert response.generated_text == "beach"
@@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_pali_gemma_two_images(
+    flash_pali_gemma, response_snapshot, chicken, cow_beach
+):
     response = await flash_pali_gemma.generate(
         f"caption\n",
         max_new_tokens=20,
@@ -1,5 +1,4 @@
 import pytest
-import base64
 
 
 @pytest.fixture(scope="module")
@@ -16,22 +15,8 @@ async def idefics(idefics_handle):
     return idefics_handle.client
 
 
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
-async def test_idefics(idefics, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics(idefics, response_snapshot, chicken):
     response = await idefics.generate(
         f"User:Can you tell me a very short story based on the image?",
         max_new_tokens=10,
@@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_idefics_two_images(idefics, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
     response = await idefics.generate(
         f"User:Where are the cow and chicken?<end_of_utterance> \nAssistant:",
         max_new_tokens=20,
@@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
 
 @pytest.mark.release
 @pytest.mark.asyncio
-async def test_idefics_load(idefics, generate_load, response_snapshot):
-    chicken = get_chicken()
+async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
     responses = await generate_load(
         idefics,
         f"User:Can you tell me a very short story based on the image?",
@@ -1,18 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
 @pytest.fixture(scope="module")
@@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_idefics2_next_simple(
+    flash_idefics2_next, response_snapshot, chicken
+):
     response = await flash_idefics2_next.generate(
         f"User:Write me a short story<end_of_utterance> \nAssistant:",
         max_new_tokens=10,
@@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
-    chicken = get_chicken()
-    cow_beach = get_cow_beach()
+async def test_flash_idefics2_two_images(
+    flash_idefics2_next, response_snapshot, chicken, cow_beach
+):
     response = await flash_idefics2_next.generate(
         f"User:Where are the cow and chicken?<end_of_utterance> \nAssistant:",
         max_new_tokens=20,
@@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_load(
-    flash_idefics2_next, generate_load, response_snapshot
+    flash_idefics2_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
     responses = await generate_load(
         flash_idefics2_next,
         f"User:Write me a short story<end_of_utterance> \nAssistant:",
@@ -1,12 +1,4 @@
 import pytest
-import base64
-
-
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
 @pytest.fixture(scope="module")
@@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
-    chicken = get_chicken()
+async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
     response = await flash_llava_next.generate(
         f"User:Can you tell me a very short story based on the image?",
         max_new_tokens=10,
@@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llava_next_load(
-    flash_llava_next, generate_load, response_snapshot
+    flash_llava_next, generate_load, response_snapshot, chicken
 ):
-    chicken = get_chicken()
     responses = await generate_load(
         flash_llava_next,
         f"User:Can you tell me a very short story based on the image?",
@@ -1,5 +1,4 @@
 import pytest
-import base64
 import asyncio
 
 
@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
     return mllama_handle.client
 
 
-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
 async def test_mllama_simpl(mllama, response_snapshot):
-    # chicken = get_chicken()
     response = await mllama.chat(
         max_tokens=10,
         temperature=0.0,
@@ -139,12 +139,14 @@ message Request {
     repeated uint32 slots = 10;
     /// LORA adapter index
     optional string adapter_id = 11;
-    /// Prefix length that can be retrieved from the KV cache.
-    uint32 prefix_len = 12;
+    /// Tokens that can be retrieved from the KV cache.
+    /// This value is set for the first prefill and never reset
+    uint32 cache_len = 12;
     /// Context truncation
     bool add_special_tokens = 13;
-    /// Postfix length for prefill chunking
-    uint32 postfix_len = 14;
+    /// Chunk of tokens that must be computed for the first prefill
+    /// This value is set for the first prefill and never reset
+    optional uint32 chunk_len = 14;
 }
 
 message Batch {
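
For illustration, a minimal sketch (assuming the Python stubs are regenerated from this proto; the values are made up) of how the renamed fields are read on the server side, mirroring the flash_causal_lm change below:

    from text_generation_server.pb import generate_pb2

    # `cache_len` replaces `prefix_len`; `chunk_len` replaces `postfix_len` and is
    # now optional, so field presence (not a sentinel value) signals a chunked prefill.
    req = generate_pb2.Request(cache_len=0)

    if req.HasField("chunk_len"):
        input_length = req.chunk_len  # prefill only this chunk of the prompt
    else:
        input_length = None  # prefill everything that is not already cached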
@@ -280,24 +280,36 @@ class FlashCausalLMBatch(Batch):
             prompt_length = len(tokenized_input)
             prompt_lengths.append(prompt_length)
 
-            cache_length = r.prefix_len
-            input_length = r.postfix_len
+            cache_length = r.cache_len
             assert (
                 cache_length <= prompt_length
             ), f"Prefix {cache_length} vs input {prompt_length}"
             if cache_length == prompt_length:
                 assert False, "unreachable"
-            if cache_length + input_length < prompt_length:
-                # FIXME: speculate is not supported for context chunking at the moment
-                assert speculate == 0
-                assert get_support_chunking()
-                assert input_length > 0
-
-            postfix_ids = tokenized_input[cache_length : cache_length + input_length]
-            assert (
-                len(postfix_ids) == input_length
-            ), "Rust and Python tokenizers are not aligned"
+            # `chunk_len` is an optional field in the protobuf
+            # It is only set if the model support chunking
+            if r.HasField("chunk_len"):
+                input_length = r.chunk_len
+
+                if cache_length + input_length < prompt_length:
+                    # FIXME: speculate is not supported for context chunking at the moment
+                    assert speculate == 0
+                    assert get_support_chunking()
+                    assert input_length > 0
+
+                postfix_ids = tokenized_input[
+                    cache_length : cache_length + input_length
+                ]
+                assert (
+                    len(postfix_ids) == input_length
+                ), "Rust and Python tokenizers are not aligned"
+            else:
+                # Use all the remaining ids
+                postfix_ids = tokenized_input[cache_length:]
+                input_length = len(postfix_ids)
+
             input_lengths.append(input_length)
 
             prefix_offsets.append(prompt_length - 5)
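
To make the new slicing concrete, a small self-contained sketch with made-up numbers (not part of the commit):

    tokenized_input = list(range(30))  # 30 prompt token ids (illustrative)
    cache_length = 10                  # tokens already present in the KV cache

    # Chunked prefill: the request carries chunk_len, so only that many ids are fed.
    chunk_len = 4
    postfix_ids = tokenized_input[cache_length : cache_length + chunk_len]
    assert postfix_ids == [10, 11, 12, 13]

    # No chunk_len: everything that is not cached is prefilled in one go.
    postfix_ids = tokenized_input[cache_length:]
    assert len(postfix_ids) == 20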
@@ -1097,6 +1109,7 @@ class FlashCausalLM(Model):
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
         kv_cache_dtype: Optional[torch.dtype] = None,
+        support_chunking: bool = True,
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -1224,7 +1237,7 @@ class FlashCausalLM(Model):
             rank=rank,
             world_size=world_size,
             sliding_window=config.sliding_window,
-            support_chunking=True,
+            support_chunking=support_chunking,
         )
 
     @property
@@ -1,14 +1,17 @@
-from io import BytesIO
-from PIL import Image
 import torch
 
+import numpy as np
+
 from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
+from io import BytesIO
+from PIL import Image
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
 
 from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
             batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
                 max=config.text_config.vocab_size - 1
             )
+            if isinstance(batch.input_ids, list):
+                if len(batch) > 1:
+                    input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+                else:
+                    input_ids = batch.input_ids[0]
+                batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
             batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
 
             if image_inputs is not None:
@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 class MllamaCausalLM(VlmCausalLM):
     def forward(
         self,
-        batch: VlmCausalLMBatch,
+        batch: MllamaCausalLMBatch,
         adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
             speculative_ids = batch.speculative_ids
@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
                 input_lengths = (
                     input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
                 ).view(-1)
-                prefix_lens_tensor = (
-                    batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+                cache_lengths_tensor = (
+                    batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
                 ).reshape(-1)
 
                 # Add Copy the block tables for all members
@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)
 
-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -269,38 +278,25 @@ class MllamaCausalLM(VlmCausalLM):
             # Only run cuda graphs when there's no images.
             or batch.cross_attention_states is not None
         ):
-            input_lengths = input_lengths + prefix_lens_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
-                    cache_lengths=prefix_lens_tensor,
+                    cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )
 
-                if batch.pixel_values is not None:
-                    cross_attention_states = self.model.vision_forward(
-                        pixel_values=batch.pixel_values,
-                        aspect_ratio_ids=batch.aspect_ratio_ids,
-                        aspect_ratio_mask=batch.aspect_ratio_mask,
-                    )
-                    batch.cross_attention_states = cross_attention_states
-
-                cross_attention_states = batch.cross_attention_states
-
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -312,14 +308,18 @@ class MllamaCausalLM(VlmCausalLM):
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
-                    cross_attention_states=cross_attention_states,
-                    adapter_data=adapter_data,
-                    image_indices=batch.image_indices[:],
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
                 )
                 if batch.prefill_cache_indices is not None:
                     batch.prefill_cache_indices = None
                 if batch.pixel_values is not None:
                     batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
                 return logits, speculative_logits
 
             # Copy inputs to the static inputs of the cuda graph
@@ -330,22 +330,34 @@ class MllamaCausalLM(VlmCausalLM):
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
                 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
             else:
                 cuda_graph["block_tables"][
                     : block_tables.shape[0], : block_tables.shape[1]
                 ] = block_tables
 
+            # XXX: This is working only because block 0 is reserved for the healthcheck
+            # so it doesn't matter if we override it with bogus values.
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-                input_lengths + prefix_lens_tensor
-            )
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["cache_lengths"].zero_()
+            cuda_graph["cache_lengths"][
+                : cache_lengths_tensor.shape[0]
+            ] = cache_lengths_tensor
 
-            # Replay the graph
-            cuda_graph["graph"].replay()
+            with self._forward_context(
+                block_tables=cuda_graph["block_tables"],
+                cu_seqlen_prefill=None,
+                input_lengths_tensor=cuda_graph["input_lengths"],
+                cache_lengths_tensor=cuda_graph["cache_lengths"],
+                state=cuda_graph["state"],
+            ):
+                # Replay the graph
+                cuda_graph["graph"].replay()
 
             # Slice output to the correct shape
             speculative_logits = (
@@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
             model_id=model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
+            # FIXME: VLM do not work with context chunking yet
+            support_chunking=False,
             **kwargs,
         )
 
@@ -356,7 +358,7 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            if PREFIX_CACHING:
+            if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
@@ -368,13 +370,12 @@ class VlmCausalLM(FlashCausalLM):
                 input_lengths_tensor=input_lengths,
                 cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + cache_lengths_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
                     cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
@@ -416,7 +417,10 @@ class VlmCausalLM(FlashCausalLM):
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
-            cuda_graph["slots"].fill_(-1)
+            # XXX: This is working only because block 0 is reserved for the healthcheck
+            # so it doesn't matter if we override it with bogus values.
+            cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
             cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths