Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
fix: update v3 scheduler and ensure max_batch_size > 0
parent 6497ae61e2
commit bec657973d
@@ -168,7 +168,8 @@ pub(crate) async fn batching_task(
             };

             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+            let max_size =
+                max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
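Why the change matters: `max_size - batch_size as usize` underflows whenever the running batch is already larger than the configured cap, panicking in debug builds and wrapping around in release builds. A minimal sketch (not from the patch, values made up) of the fixed behavior:

    fn main() {
        let max_batch_size: Option<usize> = Some(4);
        let batch_size: u32 = 6; // current batch already exceeds the cap

        // Old form: `max_size - batch_size as usize` would underflow here.
        // New form: saturates at zero instead of panicking or wrapping.
        let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
        assert_eq!(max_size, Some(0)); // no room left for new requests
    }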
@@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> {
         }
     }

+    if let Some(max_batch_size) = max_batch_size {
+        if max_batch_size == 0 {
+            return Err(RouterError::ArgumentValidation(
+                "`max_batch_size` must be > 0".to_string(),
+            ));
+        }
+    }
+
     let (backend, _backend_info) = connect_backend(
         max_input_tokens,
         max_total_tokens,
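The router now rejects `--max-batch-size 0` at startup rather than silently scheduling nothing. A standalone sketch of the same fail-fast pattern, with a hypothetical `ConfigError` standing in for `RouterError`:

    #[derive(Debug)]
    enum ConfigError {
        ArgumentValidation(String), // stands in for RouterError::ArgumentValidation
    }

    fn validate(max_batch_size: Option<usize>) -> Result<(), ConfigError> {
        // `None` means "no cap", which is valid; only an explicit 0 is rejected.
        if let Some(max_batch_size) = max_batch_size {
            if max_batch_size == 0 {
                return Err(ConfigError::ArgumentValidation(
                    "`max_batch_size` must be > 0".to_string(),
                ));
            }
        }
        Ok(())
    }

    fn main() {
        assert!(validate(None).is_ok());
        assert!(validate(Some(32)).is_ok());
        assert!(validate(Some(0)).is_err()); // rejected at startup
    }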
@@ -226,6 +226,13 @@ impl State {
             }
         }

+        if let Some(max_size) = max_size {
+            if max_size == 0 {
+                tracing::debug!("No capacity");
+                return None;
+            }
+        }
+
         // Pad prefill_token_budget to be a multiple of block size
         let prefill_token_budget =
             ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
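The new guard makes `next_batch` bail out early when no slots remain, instead of doing the padding work below for a batch that can hold nothing. For context, the pre-existing padding line rounds the budget up to the next multiple of `block_size` with the usual `((n + b - 1) / b) * b` ceiling trick; a quick check with made-up numbers:

    fn main() {
        let block_size: u32 = 16;
        for (budget, expected) in [(1000, 1008), (1024, 1024), (1, 16)] {
            // Round up to the next multiple of block_size.
            let padded = ((budget + block_size - 1) / block_size) * block_size;
            assert_eq!(padded, expected);
        }
    }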
@@ -304,7 +304,7 @@ impl State {
                 batch_entries.insert(id, entry);

                 // Check if max_size
-                if Some(batch_requests.len()) >= max_size {
+                if Some(batch_requests.len()) == max_size {
                     break;
                 }
             }
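This one-character change fixes a real bug: `Option<usize>` orders `None` below every `Some`, so with no cap configured (`max_size == None`) the old `>=` comparison was always true and the loop broke after a single request. A sketch demonstrating the ordering, not from the patch:

    fn main() {
        let max_size: Option<usize> = None; // no cap configured

        // Option's derived ordering puts None below every Some value,
        // so the old condition was unconditionally true when max_size was None:
        assert!(Some(1) >= max_size);

        // The fixed condition only trips when the cap is actually reached:
        assert_eq!(Some(1) == max_size, false);
        assert_eq!(Some(4) == Some(4), true);
    }

Equality suffices here because the batch grows one request at a time, so `batch_requests.len()` cannot skip past the cap, and the startup validation above guarantees the cap is never zero.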
@@ -161,11 +161,8 @@ pub(crate) async fn batching_task(
             };

             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            let max_size = max_batch_size.map(|max_size| {
-                if batch_size as usize > max_size { 0 } else { max_size - batch_size as usize }
-            });
-
-
+            let max_size =
+                max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
                 .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
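The removed closure and `saturating_sub` compute the same thing; the standard library method just says it more directly. A quick equivalence check, not from the patch:

    fn main() {
        for (max_size, batch_size) in [(8usize, 3u32), (2, 5), (4, 4)] {
            let old = if batch_size as usize > max_size {
                0
            } else {
                max_size - batch_size as usize
            };
            let new = max_size.saturating_sub(batch_size as usize);
            assert_eq!(old, new);
        }
    }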