fix rust and python unit-tests

OlivierDehaene 2024-06-11 17:11:16 +02:00
parent 73c3903214
commit 37266e2dbb
12 changed files with 288 additions and 112 deletions

View File

@@ -16,4 +16,3 @@ jobs:
           fetch-depth: 0
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main

View File

@@ -1,7 +1,8 @@
-use std::sync::{Arc, Mutex};
+use std::fmt::Formatter;
+use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;
 
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub(crate) struct BlockAllocation {
     allocated_blocks: Vec<u32>,
     allocated_slots: Vec<u32>,
@@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
     }
 }
 
-#[derive(Debug, Clone)]
+impl std::fmt::Debug for BlockAllocation {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BlockAllocation")
+            .field("allocated_blocks", &self.allocated_blocks.len())
+            .field("allocated_slots", &self.allocated_slots.len())
+            .field("required_blocks", &self.required_blocks)
+            .field("required_slots", &self.required_slots)
+            .field("block_allocator", &self.block_allocator)
+            .finish()
+    }
+}
+
+#[derive(Clone)]
 pub(crate) struct BlockAllocator {
     free_blocks: Arc<Mutex<Vec<u32>>>,
     block_size: u32,
@@ -129,8 +142,7 @@ impl BlockAllocator {
             Err(AllocationError::NotEnoughPages)
         } else {
             let n_free_blocks = free_blocks.len();
-            let allocated_blocks =
-                free_blocks.split_off(n_free_blocks - clipped_required_blocks);
+            let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
 
             let allocated_blocks = if repeats != 1 {
                 let mut allocated_blocks = allocated_blocks.repeat(repeats);
@@ -140,9 +152,8 @@ impl BlockAllocator {
                 allocated_blocks
             };
 
-            let mut allocated_slots = Vec::with_capacity(
-                allocated_blocks.len() * self.block_size as usize * repeats,
-            );
+            let mut allocated_slots =
+                Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
 
             let required_slots = (prompt_tokens + decode_tokens) as usize;
@@ -166,7 +177,30 @@ impl BlockAllocator {
     }
 
     pub(crate) fn free(&self, blocks: Vec<u32>) {
-        self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
+        self.free_blocks
+            .lock()
+            .expect("Lock could not be acquired. This is a bug.")
+            .extend(blocks)
     }
 }
+
+impl std::fmt::Debug for BlockAllocator {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut d = f.debug_struct("BlockAllocator");
+        d.field("block_size", &self.block_size)
+            .field("window_size", &self.window_size);
+        match self.free_blocks.try_lock() {
+            Ok(guard) => {
+                d.field("free_blocks", &(*guard).len());
+            }
+            Err(TryLockError::Poisoned(err)) => {
+                d.field("free_blocks", &(**err.get_ref()).len());
+            }
+            Err(TryLockError::WouldBlock) => {
+                d.field("free_blocks", &format_args!("<locked>"));
+            }
+        };
+        d.finish()
+    }
+}

View File

@@ -275,7 +275,9 @@ impl State {
             if prefill_tokens > prefill_token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}"
+                );
                 self.entries.push_front((id, entry));
                 break;
             }
@@ -456,7 +458,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: vec![],
-                input_length: 0,
+                input_length: 1,
                 truncate: 0,
                 decoder_input_details: false,
                 parameters: ValidParameters {
@@ -567,7 +569,7 @@ mod tests {
     #[tokio::test]
     async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0, 2);
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -689,7 +691,7 @@ mod tests {
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2, 16);
+        let queue = Queue::new(true, 1, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);

View File

@@ -256,11 +256,7 @@ async fn prefill(
                     .expect("ID not found in entries. This is a bug.");
 
                 // Send intermediate responses
-                if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
-                    tracing::error!("Entry response channel error.");
-                    metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-                    err
-                }) {
+                if send_stream_responses(stream_responses, entry).is_err() {
                     // Sending failed, remove entry
                     entries
                         .remove(&id)
@@ -405,7 +401,7 @@ async fn filter_batch(
         .filter_batch(
             id,
             updated_requests,
-            terminated_entries.keys().map(|v| *v).collect(),
+            terminated_entries.keys().copied().collect(),
         )
         .await
         .unwrap()
@@ -460,11 +456,14 @@ fn send_terminated_generations(
         };
 
         // Send responses
-        if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
+        let send_result = entry.response_tx.send(Ok(response)).map_err(|err| {
             tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
-        }) {
+        });
+
+        if send_result.is_err() {
+            // The channel is dropped, skip the rest of the messages
             continue 'terminated_generations;
         }
     }
@@ -504,11 +503,7 @@ fn filter_send_ended_generations(
         // If the generation has ended for this request, we send the responses to the channel and
        // remove the entry to drop it and free its blocks
         if finished {
-            let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
-                tracing::error!("Entry response channel error.");
-                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-                err
-            });
+            let _ = send_stream_responses(stream_responses, entry);
             // Remove from entries and filter
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
             return None;
@@ -525,7 +520,11 @@ fn send_stream_responses(
     entry: &Entry,
 ) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
     for response in stream_responses {
-        entry.response_tx.send(Ok(response))?;
+        entry.response_tx.send(Ok(response)).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        })?;
     }
     Ok(())
 }
@@ -541,7 +540,7 @@ fn filter_send_update_allocations(
 ) -> (bool, IntMap<u64, Entry>) {
     let mut updated = false;
 
-    let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
+    let ids: Vec<u64> = entries.keys().copied().collect();
     let mut terminated_entries =
         IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
@@ -581,11 +580,7 @@ fn filter_send_update_allocations(
             .expect("ID not found in stream_responses. This is a bug.");
 
         // Send intermediate responses
-        if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }) {
+        if send_stream_responses(stream_response, entry).is_err() {
             // Sending failed, remove entry
             entries
                 .remove(id)

View File

@@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_bloom,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -339,8 +339,10 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
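
The model tests in this file (and the causal-LM and seq2seq variants below) all move to the same call shape: `filter` now takes the model fixture, a list of `generate_pb2.KeptRequest`, and a list of terminated request ids, and returns a `(batch, terminated_generations)` tuple that the tests unpack. The snippet below is a hypothetical, self-contained stand-in (plain Python, no `generate_pb2` or model fixtures) that mimics only that calling convention so the unpacking pattern is easy to see in isolation; `FakeBatch` and `KeptRequest` are invented for the example.

from collections import namedtuple
from typing import List, Optional, Tuple

KeptRequest = namedtuple("KeptRequest", ["id", "blocks", "slots"])

class FakeBatch:
    def __init__(self, ids: List[int]):
        self.ids = ids

    def filter(
        self, model: object, kept: List[KeptRequest], terminated_ids: List[int]
    ) -> Tuple[Optional["FakeBatch"], List[int]]:
        kept_ids = [r.id for r in kept]
        if not kept_ids:
            return None, terminated_ids
        return FakeBatch(kept_ids), terminated_ids

# Same unpacking pattern as the updated unit tests: the second element is
# ignored with `_` when no request was terminated.
next_batch = FakeBatch([0, 1])
next_batch, _ = next_batch.filter(object(), [KeptRequest(id=0, blocks=[], slots=[])], [])
print(next_batch.ids)  # [0]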

View File

@@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -337,15 +337,12 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
    )
 
     for _ in range(

View File

@@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -341,15 +343,13 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -360,8 +360,10 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)

View File

@@ -159,14 +159,48 @@ class CausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["CausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "CausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -262,7 +296,7 @@ class CausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
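
Every non-flash batch class touched by this commit gets the same `filter` refactor: the method now receives the model plus two request lists, builds a `TerminatedGeneration` for each terminated request, and returns a `(batch, terminated_generations)` tuple, with `None` as the batch when nothing is kept. The sketch below is a simplified, hypothetical stand-in (plain Python, no `generate_pb2`, tensors, or real model) that only illustrates that control flow; `ToyBatch` and its fields are invented for the example.

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class ToyBatch:
    # request_id -> text generated so far (stand-in for the real tensors)
    generated: Dict[int, str] = field(default_factory=dict)

    def filter(
        self, kept_ids: List[int], terminated_ids: List[int]
    ) -> Tuple[Optional["ToyBatch"], List[dict]]:
        # 1) Build a record for every terminated request, mirroring the
        #    TerminatedGeneration messages assembled in the diff above.
        terminated = [
            {"id": rid, "text": self.generated[rid], "finish_reason": "terminated"}
            for rid in terminated_ids
        ]
        # 2) If nothing is kept, the batch itself disappears.
        if not kept_ids:
            return None, terminated
        # 3) Otherwise prune the batch down to the kept requests.
        return ToyBatch({rid: self.generated[rid] for rid in kept_ids}), terminated

batch = ToyBatch({0: "hello", 1: "world", 2: "!"})
batch, terminated = batch.filter(kept_ids=[0], terminated_ids=[1, 2])
print(batch)       # ToyBatch(generated={0: 'hello'})
print(terminated)  # records for ids 1 and 2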

View File

@@ -215,15 +215,51 @@ class IdeficsCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["IdeficsCausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "IdeficsCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[
+        Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration]
+    ]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        # It deletes requests from the batch. For instance when client lost connection
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -330,7 +366,7 @@ class IdeficsCausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")

View File

@@ -196,14 +196,48 @@ class MambaBatch(Batch):
     )
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["MambaBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Mamba",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -278,7 +312,7 @@ class MambaBatch(Batch):
             :, indices
         ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
-        return self
+        return self, terminated_generations
 
     @classmethod
     def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":

View File

@@ -167,14 +167,49 @@ class Seq2SeqLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Seq2SeqLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_decoder_input_ids = self.all_decoder_input_ids[idx]
+            decoder_input_length = self.decoder_input_lengths[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_decoder_input_ids,
+                prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -281,7 +316,7 @@ class Seq2SeqLMBatch(Batch):
         self.padding_right_offset = padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
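
The text of a terminated request is recovered through `model.decode_token`, whose window is derived from two offsets: the causal batches compute them from `stopping_criteria.current_tokens`, while the seq2seq batch above uses `decoder_input_length`. A purely illustrative computation with made-up numbers, only to show which slice of the sequence the two offsets select:

# Hypothetical values: 12 decoder ids in total, 5 of them produced by decoding.
all_decoder_input_ids = list(range(12))   # stand-in for the real id tensor
decoder_input_length = 5

prefix_offset = len(all_decoder_input_ids) - decoder_input_length - 1  # 6
read_offset = len(all_decoder_input_ids) - decoder_input_length        # 7

# Roughly: the ids from read_offset onward are the tail that gets turned back
# into text, while the id between prefix_offset and read_offset gives the
# detokenizer a little prefix context.
print(all_decoder_input_ids[read_offset:])               # [7, 8, 9, 10, 11]
print(all_decoder_input_ids[prefix_offset:read_offset])  # [6]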

View File

@@ -123,13 +123,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["VlmCausalLMBatch"]:
-        batch = super().filter(updated_requests)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
+        self,
+        model: "VlmCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        batch, terminated_generations = super().filter(
+            model, kept_requests, terminated_request_ids
+        )
+        if batch is not None:
+            batch.pixel_values = None
+            batch.pixel_attention_mask = None
+            batch.image_sizes = None
+        return batch, terminated_generations
 
     @classmethod
     def batch_tokenized_inputs(
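
Because the parent `filter` can now return `None` for the batch (when no request is kept), the VLM subclass only clears its image tensors behind an `is not None` guard. A minimal, hypothetical illustration of that guard with stub classes (not the real TGI types):

from typing import List, Optional, Tuple

class BaseBatch:
    def filter(
        self, kept_ids: List[int], terminated_ids: List[int]
    ) -> Tuple[Optional["BaseBatch"], List[int]]:
        # Stand-in for the parent behaviour: no kept requests -> no batch.
        if not kept_ids:
            return None, terminated_ids
        return self, terminated_ids

class VlmBatch(BaseBatch):
    pixel_values = "large image tensor"

    def filter(self, kept_ids, terminated_ids):
        batch, terminated = super().filter(kept_ids, terminated_ids)
        # Without this guard, `batch.pixel_values = None` would raise an
        # AttributeError whenever the whole batch was filtered away.
        if batch is not None:
            batch.pixel_values = None
        return batch, terminated

print(VlmBatch().filter(kept_ids=[], terminated_ids=[3]))  # (None, [3])
print(VlmBatch().filter(kept_ids=[1], terminated_ids=[]))  # (<VlmBatch ...>, [])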