Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

fix launcher

commit 08953c5975 (parent ea4b739a9f)
@@ -179,18 +179,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .await
         .expect("Unable to clear cache");
 
-    // Warmup shard
-    let max_batch_size = batch_size.iter().max().unwrap();
-    sharded_client
-        .warmup(
-            sequence_length,
-            sequence_length * max_batch_size,
-            (sequence_length + decode_length) * max_batch_size,
-            Some(*max_batch_size as usize),
-        )
-        .await
-        .expect("Unable to warmup");
-
     tracing::info!("Connected");
 
     // Run app
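The removed block is a warmup call whose arguments were derived from `sequence_length`, `decode_length`, and the largest configured batch size. A small sketch of that arithmetic, in plain Python (the quantity names are my reading of the removed call, not taken from the source):

# Sketch of the budget the removed .warmup(...) call computed from its parameters
# (names are illustrative; the mapping to warmup's parameters is an assumption).
def warmup_budget(sequence_length: int, decode_length: int, batch_sizes: list[int]) -> dict:
    max_batch_size = max(batch_sizes)
    return {
        "sequence_length": sequence_length,
        "prefill_tokens": sequence_length * max_batch_size,
        "total_tokens": (sequence_length + decode_length) * max_batch_size,
        "max_batch_size": max_batch_size,
    }

print(warmup_budget(sequence_length=512, decode_length=8, batch_sizes=[1, 2, 4]))
# {'sequence_length': 512, 'prefill_tokens': 2048, 'total_tokens': 2080, 'max_batch_size': 4}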
@@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> {
             "`max_input_tokens must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
 
     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
         tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> {
     }
 
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_batch_prefill_tokens, max_batch_total_tokens
-            )));
-        }
         if max_total_tokens as u32 > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
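Together, the two launcher hunks drop the checks that coupled `max_batch_prefill_tokens` to `max_input_tokens` and to `max_batch_total_tokens`; presumably a prefill no longer has to cover a whole input in one pass once chunking is supported (see the `get_support_chunking` assertion further down). A minimal Python sketch of the validations that remain after this commit (illustrative only, not the launcher's Rust code; names mirror the CLI arguments):

# Illustrative re-statement of the surviving launcher argument checks (assumed, simplified).
def validate_args(max_input_tokens: int, max_total_tokens: int,
                  max_batch_prefill_tokens: int, max_batch_total_tokens: int | None = None):
    if max_input_tokens >= max_total_tokens:
        raise ValueError("`max_input_tokens` must be < `max_total_tokens`")
    # Dropped by this commit: max_batch_prefill_tokens >= max_input_tokens
    # Dropped by this commit: max_batch_prefill_tokens <= max_batch_total_tokens
    if max_batch_total_tokens is not None and max_total_tokens > max_batch_total_tokens:
        raise ValueError(
            f"`max_total_tokens` must be <= `max_batch_total_tokens`. "
            f"Given: {max_total_tokens} and {max_batch_total_tokens}"
        )

# A configuration that is now accepted even though max_batch_prefill_tokens < max_input_tokens:
validate_args(max_input_tokens=8192, max_total_tokens=16384,
              max_batch_prefill_tokens=2048, max_batch_total_tokens=32768)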
@@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward
     prefill_logprob_tokens: List[Optional[Tokens]]
 
-    # Prefixes
-    prefix_ids: List[List[int]]
-
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
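The `prefix_ids` field, and every list that mirrored it in the hunks below, is dropped because the same tokens are always recoverable from `all_input_ids` and `cache_length`; later hunks switch the remaining users to an `all_input_ids[:cache_length]` slice. A toy illustration of the equivalence (made-up token ids):

# Illustrative only: shows why a separate prefix_ids list is redundant.
tokenized_input = [101, 7592, 2088, 2003, 3376, 102]   # full prompt token ids (made up)
cache_length = 2                                        # tokens already in the KV cache
input_length = 3                                        # tokens being prefilled this step

# Old bookkeeping: store the prefix explicitly.
prefix_ids = tokenized_input[:cache_length]             # [101, 7592]

# New bookkeeping: keep only all_input_ids and slice on demand.
all_input_ids = tokenized_input
assert all_input_ids[:cache_length] == prefix_ids
postfix_ids = tokenized_input[cache_length : cache_length + input_length]  # [2088, 2003, 3376]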
@@ -259,7 +256,6 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         all_postfix_ids = []
-        prefix_ids = []
         requests_idx_mapping = {}
 
         next_token_chooser_parameters = []
@@ -297,7 +293,6 @@ class FlashCausalLMBatch(Batch):
                 assert get_support_chunking()
                 assert input_length > 0
 
-            prefix_ids.append(tokenized_input[:cache_length])
             postfix_ids = tokenized_input[cache_length : cache_length + input_length]
 
             assert (
@@ -400,7 +395,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -464,7 +458,6 @@ class FlashCausalLMBatch(Batch):
         requests = []
         block_tables = []
         all_input_ids = []
-        prefix_ids = []
         input_ids = []
 
         prompt_lengths = []
@@ -505,7 +498,6 @@ class FlashCausalLMBatch(Batch):
             )
 
             all_input_ids.append(self.all_input_ids[idx])
-            prefix_ids.append(self.prefix_ids[idx])
 
             prompt_lengths.append(self.prompt_lengths[idx])
             input_lengths.append(request_input_length)
@@ -621,7 +613,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -718,7 +709,6 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         cache_lengths = []
         all_input_ids = []
-        prefix_ids = []
 
         prompt_lengths = []
         input_lengths = []
@@ -802,7 +792,6 @@ class FlashCausalLMBatch(Batch):
             block_tables.extend(batch.block_tables)
             cache_lengths.extend(batch.cache_lengths)
             all_input_ids.extend(batch.all_input_ids)
-            prefix_ids.extend(batch.prefix_ids)
 
             prompt_lengths.extend(batch.prompt_lengths)
             input_lengths.extend(batch.input_lengths)
@@ -873,7 +862,6 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
-            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -1839,6 +1827,8 @@ class FlashCausalLM(Model):
             batch.input_lengths,
             batch.all_input_ids,
             accepted_ids,
+            current_prefilling_mask,
+            batch.prefilling_mask,
         )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -1855,6 +1845,8 @@ class FlashCausalLM(Model):
             input_length,
             all_input_ids,
             n_accepted_ids,
+            request_was_prefilling,
+            request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -1864,7 +1856,6 @@ class FlashCausalLM(Model):
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
-                out_length = out_end_index - out_start_index
 
                 if finished_prefilling:
                     # Initialize position_ids
@@ -1880,21 +1871,25 @@ class FlashCausalLM(Model):
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
+                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
+                    # and needs to be removed. We do so by incrementing the start_index
+                    if request_was_prefilling and cache_length == 0:
+                        start_index += 1
+
+                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
+                    # therefore not part of the prefill. We remove it by decrementing out_end_index
+                    if request_was_prefilling and not request_is_prefilling:
+                        out_end_index -= 1
+
                     if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
-                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index:out_end_index] = (
+                            batch.input_ids[start_index:end_index]
                         )
                     else:
                         # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[
-                            start_index + 1 : start_index + out_length
-                        ]
+                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
 
-                # Represent whether this request is still prefilling
-                # If it is, the tokens we decoded should be ignored
-                accept_tokens = cache_length + input_length >= prompt_length
-
-                if accept_tokens:
+                if not request_is_prefilling:
                     # Only save tokens if we are done prefilling for this request
                     for j in range(n_accepted_ids):
                         batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
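The two new guards adjust the prefill-logprob bookkeeping at chunk boundaries: when a request starts prefilling with an empty cache, the first position holds a bogus token and the slice start moves forward by one; when a request has just finished prefilling, the last position is the first generated token and the slice end moves back by one. A self-contained sketch of just those conditionals (toy values; in the real code `start_index` indexes `batch.input_ids` while `out_end_index` indexes the prefill output buffer):

# Toy illustration of the boundary adjustments added above; not the actual tensor code.
def adjust_prefill_bounds(start_index: int, out_end_index: int, cache_length: int,
                          request_was_prefilling: bool, request_is_prefilling: bool):
    if request_was_prefilling and cache_length == 0:
        # First token of a from-scratch prefill is bogus: skip it.
        start_index += 1
    if request_was_prefilling and not request_is_prefilling:
        # Prefill just finished: the last position is a generated token, not prefill.
        out_end_index -= 1
    return start_index, out_end_index

print(adjust_prefill_bounds(0, 10, 0, True, False))  # (1, 9): single-chunk prefill
print(adjust_prefill_bounds(0, 10, 4, True, True))   # (0, 10): middle chunk, nothing trimmed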
@@ -1995,7 +1990,6 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
-            batch.prefix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -2019,7 +2013,6 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
-            prefix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -2039,19 +2032,30 @@ class FlashCausalLM(Model):
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
+                    # log_master(logger.info, f"{prefill_logprobs}")
+
+                    if not request_is_prefilling:
+                        # If the request is done prefilling, then the last logprob is a generated token
+                        # We need to remove it
+                        out_end_index -= 1
+
                     request_prefill_logprobs = prefill_logprobs[
-                        out_start_index : out_end_index - 1
+                        out_start_index:out_end_index
+                    ]
+                    prefill_token_ids = all_input_ids[
+                        cache_length : cache_length + input_length
                     ]
-                    prefill_token_ids = all_input_ids[:-1]
 
                     past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
 
                     if past_prefill_logprob_tokens is None:
                         # Remove generated token to only have prefill and add nan for first prompt token
                         request_prefill_logprobs = [float("nan")] * (
-                            len(prefix_ids) + 1
+                            cache_length + 1
                         ) + request_prefill_logprobs
-                        prefill_token_ids = prefix_ids + prefill_token_ids
+                        prefill_token_ids = (
+                            all_input_ids[:cache_length] + prefill_token_ids
+                        )
 
                     prefill_texts = self.tokenizer.batch_decode(
                         prefill_token_ids,
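With `prefix_ids` gone, the per-request prefill tokens are now sliced directly out of `all_input_ids` using `cache_length` and `input_length`, and on the first call the logprobs are left-padded with NaN for the cached prefix plus the first prompt token. A toy reconstruction of that assembly for the simple non-cached case (plain Python lists with made-up values; the real code works on the batch tensors):

import math

# Made-up per-request state for illustration (simple case: no prefix cache hit).
all_input_ids = [11, 12, 13, 14, 15]       # prompt token ids
cache_length = 0                           # nothing reused from the KV cache
input_length = 5                           # whole prompt prefilled in this step
request_prefill_logprobs = [-1.2, -0.7, -2.1, -0.3]   # one logprob per prompt token after the first

prefill_token_ids = all_input_ids[cache_length : cache_length + input_length]  # [11, 12, 13, 14, 15]

past_prefill_logprob_tokens = None         # first generate_token call for this request
if past_prefill_logprob_tokens is None:
    # NaN for the cached prefix (empty here) plus the first prompt token, which has no logprob.
    request_prefill_logprobs = [math.nan] * (cache_length + 1) + request_prefill_logprobs
    prefill_token_ids = all_input_ids[:cache_length] + prefill_token_ids

assert len(prefill_token_ids) == len(request_prefill_logprobs)   # 5 tokens, 5 logprobs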
@@ -2059,6 +2063,10 @@ class FlashCausalLM(Model):
                         skip_special_tokens=False,
                     )
 
+                    # log_master(logger.info, f"{prefill_token_ids}")
+                    # log_master(logger.info, f"{request_prefill_logprobs}")
+                    # log_master(logger.info, f"{prefill_texts}")
+
                     prefill_logprob_tokens = Tokens(
                         prefill_token_ids,
                         request_prefill_logprobs,