mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00

fix rust and python unit-tests

This commit is contained in:
parent 73c3903214
commit 37266e2dbb
.github/workflows/trufflehog.yml (vendored): 1 line changed
@@ -16,4 +16,3 @@ jobs:
           fetch-depth: 0
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
-
@@ -1,7 +1,8 @@
-use std::sync::{Arc, Mutex};
+use std::fmt::Formatter;
+use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;
 
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub(crate) struct BlockAllocation {
     allocated_blocks: Vec<u32>,
     allocated_slots: Vec<u32>,
@@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
     }
 }
 
-#[derive(Debug, Clone)]
+impl std::fmt::Debug for BlockAllocation {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BlockAllocation")
+            .field("allocated_blocks", &self.allocated_blocks.len())
+            .field("allocated_slots", &self.allocated_slots.len())
+            .field("required_blocks", &self.required_blocks)
+            .field("required_slots", &self.required_slots)
+            .field("block_allocator", &self.block_allocator)
+            .finish()
+    }
+}
+
+#[derive(Clone)]
 pub(crate) struct BlockAllocator {
     free_blocks: Arc<Mutex<Vec<u32>>>,
     block_size: u32,
@@ -129,8 +142,7 @@ impl BlockAllocator {
             Err(AllocationError::NotEnoughPages)
         } else {
             let n_free_blocks = free_blocks.len();
-            let allocated_blocks =
-                free_blocks.split_off(n_free_blocks - clipped_required_blocks);
+            let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
 
             let allocated_blocks = if repeats != 1 {
                 let mut allocated_blocks = allocated_blocks.repeat(repeats);
@@ -140,9 +152,8 @@ impl BlockAllocator {
                 allocated_blocks
             };
 
-            let mut allocated_slots = Vec::with_capacity(
-                allocated_blocks.len() * self.block_size as usize * repeats,
-            );
+            let mut allocated_slots =
+                Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
 
             let required_slots = (prompt_tokens + decode_tokens) as usize;
 
@@ -166,7 +177,30 @@ impl BlockAllocator {
     }
 
     pub(crate) fn free(&self, blocks: Vec<u32>) {
-        self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
+        self.free_blocks
+            .lock()
+            .expect("Lock could not be acquired. This is a bug.")
+            .extend(blocks)
+    }
+}
+
+impl std::fmt::Debug for BlockAllocator {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut d = f.debug_struct("BlockAllocator");
+        d.field("block_size", &self.block_size)
+            .field("window_size", &self.window_size);
+        match self.free_blocks.try_lock() {
+            Ok(guard) => {
+                d.field("free_blocks", &(*guard).len());
+            }
+            Err(TryLockError::Poisoned(err)) => {
+                d.field("free_blocks", &(**err.get_ref()).len());
+            }
+            Err(TryLockError::WouldBlock) => {
+                d.field("free_blocks", &format_args!("<locked>"));
+            }
+        };
+        d.finish()
     }
 }
 
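Note on the hunks above: the hand-written Debug impls keep formatting from ever blocking on the allocator's own mutex (try_lock instead of lock) and print summary sizes instead of whole vectors. A minimal, self-contained sketch of the same pattern; the Pool and free names are made up for illustration, not taken from the repository:

use std::fmt;
use std::sync::{Arc, Mutex, TryLockError};

#[derive(Clone)]
struct Pool {
    free: Arc<Mutex<Vec<u32>>>,
    block_size: u32,
}

impl fmt::Debug for Pool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("Pool");
        d.field("block_size", &self.block_size);
        // try_lock never blocks: report the length if we can, a marker otherwise.
        match self.free.try_lock() {
            Ok(guard) => {
                d.field("free", &guard.len());
            }
            Err(TryLockError::Poisoned(err)) => {
                // A poisoned lock still exposes its data through the error.
                d.field("free", &err.get_ref().len());
            }
            Err(TryLockError::WouldBlock) => {
                d.field("free", &format_args!("<locked>"));
            }
        };
        d.finish()
    }
}

fn main() {
    let pool = Pool {
        free: Arc::new(Mutex::new(vec![0, 1, 2])),
        block_size: 16,
    };
    println!("{pool:?}"); // Pool { block_size: 16, free: 3 }
}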
@@ -275,7 +275,9 @@ impl State {
             if prefill_tokens > prefill_token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}"
+                );
                 self.entries.push_front((id, entry));
                 break;
             }
@@ -456,7 +458,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: vec![],
-                input_length: 0,
+                input_length: 1,
                 truncate: 0,
                 decoder_input_details: false,
                 parameters: ValidParameters {
@@ -567,7 +569,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0, 2);
+        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -689,7 +691,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2, 16);
+        let queue = Queue::new(true, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -256,11 +256,7 @@ async fn prefill(
                 .expect("ID not found in entries. This is a bug.");
 
             // Send intermediate responses
-            if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
-                tracing::error!("Entry response channel error.");
-                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-                err
-            }) {
+            if send_stream_responses(stream_responses, entry).is_err() {
                 // Sending failed, remove entry
                 entries
                     .remove(&id)
@@ -405,7 +401,7 @@ async fn filter_batch(
         .filter_batch(
             id,
             updated_requests,
-            terminated_entries.keys().map(|v| *v).collect(),
+            terminated_entries.keys().copied().collect(),
         )
         .await
         .unwrap()
@@ -460,11 +456,14 @@ fn send_terminated_generations(
         };
 
         // Send responses
-        if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
+        let send_result = entry.response_tx.send(Ok(response)).map_err(|err| {
            tracing::error!("Entry response channel error.");
            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
            err
-        }) {
+        });
+
+        if send_result.is_err() {
+            // The channel is dropped, skip the rest of the messages
            continue 'terminated_generations;
        }
    }
@@ -504,11 +503,7 @@ fn filter_send_ended_generations(
        // If the generation has ended for this request, we send the responses to the channel and
        // remove the entry to drop it and free its blocks
        if finished {
-            let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
-                tracing::error!("Entry response channel error.");
-                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-                err
-            });
+            let _ = send_stream_responses(stream_responses, entry);
            // Remove from entries and filter
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
            return None;
@@ -525,7 +520,11 @@ fn send_stream_responses(
    entry: &Entry,
) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
    for response in stream_responses {
-        entry.response_tx.send(Ok(response))?;
+        entry.response_tx.send(Ok(response)).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        })?;
    }
    Ok(())
}
@@ -541,7 +540,7 @@ fn filter_send_update_allocations(
) -> (bool, IntMap<u64, Entry>) {
    let mut updated = false;
 
-    let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
+    let ids: Vec<u64> = entries.keys().copied().collect();
    let mut terminated_entries =
        IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
 
@@ -581,11 +580,7 @@ fn filter_send_update_allocations(
            .expect("ID not found in stream_responses. This is a bug.");
 
        // Send intermediate responses
-        if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }) {
+        if send_stream_responses(stream_response, entry).is_err() {
            // Sending failed, remove entry
            entries
                .remove(id)
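Note on the error-handling hunks above: logging and the tgi_request_failure metric now live inside send_stream_responses itself, so call sites only branch on is_err() and clean up. A rough standalone sketch of that shape, using std::sync::mpsc and made-up names (Response, send_all) in place of the router's async channel and types:

use std::sync::mpsc;

#[derive(Debug)]
struct Response(u32);

// The helper owns the error reporting and just surfaces the Result.
fn send_all(
    tx: &mpsc::Sender<Response>,
    responses: Vec<Response>,
) -> Result<(), mpsc::SendError<Response>> {
    for response in responses {
        tx.send(response).map_err(|err| {
            // Centralized logging/metrics would go here (tracing::error!, metrics, ...).
            eprintln!("response channel error");
            err
        })?;
    }
    Ok(())
}

fn main() {
    let (tx, rx) = mpsc::channel();
    if send_all(&tx, vec![Response(1), Response(2)]).is_err() {
        // Receiver dropped: the caller only cleans up, logging already happened.
        eprintln!("receiver dropped, removing entry");
    }
    // Drain what was sent; prints Response(1) and Response(2).
    for response in rx.try_iter() {
        println!("{response:?}");
    }
}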
@@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_bloom,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -339,8 +339,10 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -337,15 +337,12 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -341,15 +343,13 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -360,8 +360,10 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -159,14 +159,48 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["CausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "CausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -262,7 +296,7 @@ class CausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -215,15 +215,51 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["IdeficsCausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "IdeficsCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[
+        Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration]
+    ]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        # It deletes requests from the batch. For instance when client lost connection
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -330,7 +366,7 @@ class IdeficsCausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -196,14 +196,48 @@ class MambaBatch(Batch):
     )
 
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["MambaBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Mamba",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -278,7 +312,7 @@ class MambaBatch(Batch):
             :, indices
         ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
-        return self
+        return self, terminated_generations
 
     @classmethod
     def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
@@ -167,14 +167,49 @@ class Seq2SeqLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Seq2SeqLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_decoder_input_ids = self.all_decoder_input_ids[idx]
+            decoder_input_length = self.decoder_input_lengths[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_decoder_input_ids,
+                prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -281,7 +316,7 @@ class Seq2SeqLMBatch(Batch):
         self.padding_right_offset = padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -123,13 +123,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["VlmCausalLMBatch"]:
-        batch = super().filter(updated_requests)
+        self,
+        model: "VlmCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        batch, terminated_generations = super().filter(
+            model, kept_requests, terminated_request_ids
+        )
+        if batch is not None:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
             batch.image_sizes = None
-        return batch
+        return batch, terminated_generations
 
     @classmethod
     def batch_tokenized_inputs(