mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix launcher
This commit is contained in:
parent
ea4b739a9f
commit
08953c5975
@ -179,18 +179,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.await
|
||||
.expect("Unable to clear cache");
|
||||
|
||||
// Warmup shard
|
||||
let max_batch_size = batch_size.iter().max().unwrap();
|
||||
sharded_client
|
||||
.warmup(
|
||||
sequence_length,
|
||||
sequence_length * max_batch_size,
|
||||
(sequence_length + decode_length) * max_batch_size,
|
||||
Some(*max_batch_size as usize),
|
||||
)
|
||||
.await
|
||||
.expect("Unable to warmup");
|
||||
|
||||
tracing::info!("Connected");
|
||||
|
||||
// Run app
|
||||
|
@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> {
|
||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_input_tokens
|
||||
)));
|
||||
}
|
||||
|
||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||
@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> {
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
|
@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch):
|
||||
# Will be set by `generate_token` and reset after each prefill forward
|
||||
prefill_logprob_tokens: List[Optional[Tokens]]
|
||||
|
||||
# Prefixes
|
||||
prefix_ids: List[List[int]]
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
all_input_ids_tensor: torch.Tensor
|
||||
@ -259,7 +256,6 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
all_postfix_ids = []
|
||||
prefix_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
@ -297,7 +293,6 @@ class FlashCausalLMBatch(Batch):
|
||||
assert get_support_chunking()
|
||||
assert input_length > 0
|
||||
|
||||
prefix_ids.append(tokenized_input[:cache_length])
|
||||
postfix_ids = tokenized_input[cache_length : cache_length + input_length]
|
||||
|
||||
assert (
|
||||
@ -400,7 +395,6 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
@ -464,7 +458,6 @@ class FlashCausalLMBatch(Batch):
|
||||
requests = []
|
||||
block_tables = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
input_ids = []
|
||||
|
||||
prompt_lengths = []
|
||||
@ -505,7 +498,6 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
prefix_ids.append(self.prefix_ids[idx])
|
||||
|
||||
prompt_lengths.append(self.prompt_lengths[idx])
|
||||
input_lengths.append(request_input_length)
|
||||
@ -621,7 +613,6 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
@ -718,7 +709,6 @@ class FlashCausalLMBatch(Batch):
|
||||
block_tables = []
|
||||
cache_lengths = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
|
||||
prompt_lengths = []
|
||||
input_lengths = []
|
||||
@ -802,7 +792,6 @@ class FlashCausalLMBatch(Batch):
|
||||
block_tables.extend(batch.block_tables)
|
||||
cache_lengths.extend(batch.cache_lengths)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
prefix_ids.extend(batch.prefix_ids)
|
||||
|
||||
prompt_lengths.extend(batch.prompt_lengths)
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
@ -873,7 +862,6 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
@ -1839,6 +1827,8 @@ class FlashCausalLM(Model):
|
||||
batch.input_lengths,
|
||||
batch.all_input_ids,
|
||||
accepted_ids,
|
||||
current_prefilling_mask,
|
||||
batch.prefilling_mask,
|
||||
)
|
||||
|
||||
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
|
||||
@ -1855,6 +1845,8 @@ class FlashCausalLM(Model):
|
||||
input_length,
|
||||
all_input_ids,
|
||||
n_accepted_ids,
|
||||
request_was_prefilling,
|
||||
request_is_prefilling,
|
||||
) in enumerate(iterator):
|
||||
# Indexing metadata
|
||||
start_index = cumulative_length
|
||||
@ -1864,7 +1856,6 @@ class FlashCausalLM(Model):
|
||||
# Indexing metadata
|
||||
out_start_index = batch.prefill_cu_outlens[i]
|
||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
out_length = out_end_index - out_start_index
|
||||
|
||||
if finished_prefilling:
|
||||
# Initialize position_ids
|
||||
@ -1880,21 +1871,25 @@ class FlashCausalLM(Model):
|
||||
# Used to gather prefill logprobs
|
||||
# Copy batch.input_ids to prefill_token_indices
|
||||
if prefill_logprobs:
|
||||
# If the request was prefilling and cache_length == 0, the first token is a bogus token
|
||||
# and needs to be removed. We do so by incrementing the start_index
|
||||
if request_was_prefilling and cache_length == 0:
|
||||
start_index += 1
|
||||
|
||||
# If the request was prefilling, and it is done prefilling, the last token was generated and is
|
||||
# therefore not part of the prefill. We remove it by decrementing out_end_index
|
||||
if request_was_prefilling and not request_is_prefilling:
|
||||
out_end_index -= 1
|
||||
|
||||
if len(batch) > 1:
|
||||
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
||||
batch.input_ids[start_index + 1 : start_index + out_length]
|
||||
prefill_tokens_indices[out_start_index:out_end_index] = (
|
||||
batch.input_ids[start_index:end_index]
|
||||
)
|
||||
else:
|
||||
# Set prefill_tokens_indices to the correct slice
|
||||
prefill_tokens_indices = batch.input_ids[
|
||||
start_index + 1 : start_index + out_length
|
||||
]
|
||||
prefill_tokens_indices = batch.input_ids[start_index:end_index]
|
||||
|
||||
# Represent whether this request is still prefilling
|
||||
# If it is, the tokens we decoded should be ignored
|
||||
accept_tokens = cache_length + input_length >= prompt_length
|
||||
|
||||
if accept_tokens:
|
||||
if not request_is_prefilling:
|
||||
# Only save tokens if we are done prefilling for this request
|
||||
for j in range(n_accepted_ids):
|
||||
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
|
||||
@ -1995,7 +1990,6 @@ class FlashCausalLM(Model):
|
||||
batch.read_offsets,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.prefix_ids,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
batch.top_n_tokens,
|
||||
@ -2019,7 +2013,6 @@ class FlashCausalLM(Model):
|
||||
read_offset,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
prefix_ids,
|
||||
do_sample,
|
||||
seed,
|
||||
top_n_tokens,
|
||||
@ -2039,19 +2032,30 @@ class FlashCausalLM(Model):
|
||||
out_start_index = batch.prefill_cu_outlens[i]
|
||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
|
||||
# log_master(logger.info, f"{prefill_logprobs}")
|
||||
|
||||
if not request_is_prefilling:
|
||||
# If the request is done prefilling, then the last logprob is a generated token
|
||||
# We need to remove it
|
||||
out_end_index -= 1
|
||||
|
||||
request_prefill_logprobs = prefill_logprobs[
|
||||
out_start_index : out_end_index - 1
|
||||
out_start_index:out_end_index
|
||||
]
|
||||
prefill_token_ids = all_input_ids[
|
||||
cache_length : cache_length + input_length
|
||||
]
|
||||
prefill_token_ids = all_input_ids[:-1]
|
||||
|
||||
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
|
||||
|
||||
if past_prefill_logprob_tokens is None:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
request_prefill_logprobs = [float("nan")] * (
|
||||
len(prefix_ids) + 1
|
||||
cache_length + 1
|
||||
) + request_prefill_logprobs
|
||||
prefill_token_ids = prefix_ids + prefill_token_ids
|
||||
prefill_token_ids = (
|
||||
all_input_ids[:cache_length] + prefill_token_ids
|
||||
)
|
||||
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
@ -2059,6 +2063,10 @@ class FlashCausalLM(Model):
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
# log_master(logger.info, f"{prefill_token_ids}")
|
||||
# log_master(logger.info, f"{request_prefill_logprobs}")
|
||||
# log_master(logger.info, f"{prefill_texts}")
|
||||
|
||||
prefill_logprob_tokens = Tokens(
|
||||
prefill_token_ids,
|
||||
request_prefill_logprobs,
|
||||
|
Loading…
Reference in New Issue
Block a user