Fixing infer iterator.

Nicolas Patry 2023-12-05 20:48:05 +00:00
parent 09839b05f4
commit 9bf31fe388
3 changed files with 15 additions and 14 deletions


@@ -527,13 +527,14 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    for (i, (((id, logprob), text), special)) in tokens_
+    let mut iterator = tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
         .zip(tokens_.is_special.into_iter())
-        .enumerate()
+        .enumerate().peekable();
+    while let Some( (i, (((id, logprob), text), special))) = iterator.next()
     {
         let token = Token {
             id,
@@ -558,8 +559,8 @@ fn send_responses(
         } else {
             vec![]
         };
-        match (&generation.generated_text, i) {
-            (Some(generated_text), i) if i == n - 1 => {
+        match (&generation.generated_text, iterator.peek()) {
+            (Some(generated_text), None) => {
                 // Generation has ended
                 stopped = true;
                 // Send message
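
Note on the two hunks above: the end-of-generation check no longer compares the enumeration index against n - 1; instead the iterator is made peekable, and the token that was just yielded is the last one exactly when iterator.peek() returns None. A minimal, self-contained sketch of that pattern (illustration only, not the router code; the names and data are made up):

    fn main() {
        let ids = vec![17u32, 42, 7];
        let logprobs = vec![-0.1f32, -0.5, -0.2];

        // Same shape as the router's chain: zip the per-token vectors,
        // enumerate, then make the iterator peekable.
        let mut iterator = ids
            .into_iter()
            .zip(logprobs.into_iter())
            .enumerate()
            .peekable();

        while let Some((i, (id, logprob))) = iterator.next() {
            // peek() is None only after the final item has been yielded,
            // which is the condition used above to mark the end of generation.
            let is_last = iterator.peek().is_none();
            println!("token {i}: id={id} logprob={logprob} last={is_last}");
        }
    }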


@@ -101,7 +101,7 @@ def get_model(
     if speculate is not None:
         set_speculate(speculate)
     else:
-        set_speculate(2)
+        set_speculate(0)
 
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -159,7 +159,10 @@ def get_model(
         method = "medusa"
     else:
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")
 
     model_type = config_dict["model_type"]
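
The two hunks above change the speculation default from 2 to 0 (speculation stays off unless explicitly requested) and only emit the informational log line when speculation is actually enabled. A rough sketch of that behaviour, in Rust to match the earlier example (hypothetical function and names, not the server code):

    // Speculation is disabled (0 speculative ids) unless a value was requested,
    // and the log line is only printed when it is actually active.
    fn resolve_speculate(requested: Option<usize>, method: &str) -> usize {
        let speculate = requested.unwrap_or(0);
        if speculate > 0 {
            println!("Using speculation {method} with {speculate} input ids.");
        }
        speculate
    }

    fn main() {
        assert_eq!(resolve_speculate(None, "n-gram"), 0); // new default: off
        assert_eq!(resolve_speculate(Some(2), "medusa"), 2); // explicit opt-in
    }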


@@ -960,9 +960,6 @@ class FlashCausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
-            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-
             next_token_texts = []
             left = 0
             for j in range(index, index + n_accepted_ids):
@@ -983,12 +980,14 @@
                 if stop:
                     stopped = True
-                    left = len(_next_token_ids) - 1 - j
+                    left = n_accepted_ids - 1 - j
                     break
                 else:
                     stopped = False
 
+            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
             index += n_accepted_ids
-            _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
 
             # Shard generations
             # All generations will be appended in the rust sharded client
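
In the hunk above, the _next_token_ids / _next_token_logprobs slices are now built after the stopping loop, so the `left` trailing tokens that follow a stop inside the accepted speculative block are never included, instead of being trimmed off afterwards. A hypothetical Rust sketch of that bookkeeping (same language as the earlier examples; the stop predicate and names are made up, this is not the server code):

    // After accepting `n_accepted` speculative tokens for one request (starting
    // at `index` in the flat id buffer), keep only the tokens up to and including
    // the one that triggered a stop; `left` counts the discarded trailing tokens.
    fn kept_tokens(ids: &[u32], index: usize, n_accepted: usize, is_stop: impl Fn(u32) -> bool) -> &[u32] {
        let accepted = &ids[index..index + n_accepted];
        let mut left = 0;
        for (j, &id) in accepted.iter().enumerate() {
            if is_stop(id) {
                left = n_accepted - 1 - j;
                break;
            }
        }
        &accepted[..n_accepted - left]
    }

    fn main() {
        // Positions 4..9 were accepted for this request; 42 is a made-up stop id.
        let ids = vec![1, 2, 3, 4, 5, 9, 42, 8, 8];
        let kept = kept_tokens(&ids, 4, 5, |id| id == 42);
        assert_eq!(kept, &[5, 9, 42]); // the two tokens after the stop are dropped
        println!("kept {} of 5 accepted tokens", kept.len());
    }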
@@ -1085,8 +1084,6 @@
             batch.prefill_cu_outlens = None
             batch.prefill_head_indices = None
             batch.prefill_next_token_indices = None
-        if prefill:
-            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch