fix launcher

This commit is contained in:
OlivierDehaene 2024-10-08 19:23:45 +02:00
parent ea4b739a9f
commit 08953c5975
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
3 changed files with 37 additions and 53 deletions

View File

@ -179,18 +179,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.expect("Unable to clear cache");
// Warmup shard
let max_batch_size = batch_size.iter().max().unwrap();
sharded_client
.warmup(
sequence_length,
sequence_length * max_batch_size,
(sequence_length + decode_length) * max_batch_size,
Some(*max_batch_size as usize),
)
.await
.expect("Unable to warmup");
tracing::info!("Connected");
// Run app

View File

@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",

View File

@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch):
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[Tokens]]
# Prefixes
prefix_ids: List[List[int]]
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
@ -259,7 +256,6 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
all_input_ids = []
all_postfix_ids = []
prefix_ids = []
requests_idx_mapping = {}
next_token_chooser_parameters = []
@ -297,7 +293,6 @@ class FlashCausalLMBatch(Batch):
assert get_support_chunking()
assert input_length > 0
prefix_ids.append(tokenized_input[:cache_length])
postfix_ids = tokenized_input[cache_length : cache_length + input_length]
assert (
@ -400,7 +395,6 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -464,7 +458,6 @@ class FlashCausalLMBatch(Batch):
requests = []
block_tables = []
all_input_ids = []
prefix_ids = []
input_ids = []
prompt_lengths = []
@ -505,7 +498,6 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
prompt_lengths.append(self.prompt_lengths[idx])
input_lengths.append(request_input_length)
@ -621,7 +613,6 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -718,7 +709,6 @@ class FlashCausalLMBatch(Batch):
block_tables = []
cache_lengths = []
all_input_ids = []
prefix_ids = []
prompt_lengths = []
input_lengths = []
@ -802,7 +792,6 @@ class FlashCausalLMBatch(Batch):
block_tables.extend(batch.block_tables)
cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths)
@ -873,7 +862,6 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -1839,6 +1827,8 @@ class FlashCausalLM(Model):
batch.input_lengths,
batch.all_input_ids,
accepted_ids,
current_prefilling_mask,
batch.prefilling_mask,
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@ -1855,6 +1845,8 @@ class FlashCausalLM(Model):
input_length,
all_input_ids,
n_accepted_ids,
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
@ -1864,7 +1856,6 @@ class FlashCausalLM(Model):
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
if finished_prefilling:
# Initialize position_ids
@ -1880,21 +1871,25 @@ class FlashCausalLM(Model):
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
# If the request was prefilling and cache_length == 0, the first token is a bogus token
# and needs to be removed. We do so by incrementing the start_index
if request_was_prefilling and cache_length == 0:
start_index += 1
# If the request was prefilling, and it is done prefilling, the last token was generated and is
# therefore not part of the prefill. We remove it by decrementing out_end_index
if request_was_prefilling and not request_is_prefilling:
out_end_index -= 1
if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
batch.input_ids[start_index + 1 : start_index + out_length]
prefill_tokens_indices[out_start_index:out_end_index] = (
batch.input_ids[start_index:end_index]
)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
prefill_tokens_indices = batch.input_ids[start_index:end_index]
# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
accept_tokens = cache_length + input_length >= prompt_length
if accept_tokens:
if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
@ -1995,7 +1990,6 @@ class FlashCausalLM(Model):
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
@ -2019,7 +2013,6 @@ class FlashCausalLM(Model):
read_offset,
stopping_criteria,
all_input_ids,
prefix_ids,
do_sample,
seed,
top_n_tokens,
@ -2039,19 +2032,30 @@ class FlashCausalLM(Model):
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# log_master(logger.info, f"{prefill_logprobs}")
if not request_is_prefilling:
# If the request is done prefilling, then the last logprob is a generated token
# We need to remove it
out_end_index -= 1
request_prefill_logprobs = prefill_logprobs[
out_start_index : out_end_index - 1
out_start_index:out_end_index
]
prefill_token_ids = all_input_ids[
cache_length : cache_length + input_length
]
prefill_token_ids = all_input_ids[:-1]
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] * (
len(prefix_ids) + 1
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = prefix_ids + prefill_token_ids
prefill_token_ids = (
all_input_ids[:cache_length] + prefill_token_ids
)
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
@ -2059,6 +2063,10 @@ class FlashCausalLM(Model):
skip_special_tokens=False,
)
# log_master(logger.info, f"{prefill_token_ids}")
# log_master(logger.info, f"{request_prefill_logprobs}")
# log_master(logger.info, f"{prefill_texts}")
prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,