Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Fixing infer iterator.

Commit 9bf31fe388, parent 09839b05f4.
@@ -527,13 +527,14 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    for (i, (((id, logprob), text), special)) in tokens_
+    let mut iterator = tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
         .zip(tokens_.is_special.into_iter())
-        .enumerate()
+        .enumerate().peekable();
+    while let Some((i, (((id, logprob), text), special))) = iterator.next()
     {
         let token = Token {
             id,
@@ -557,9 +558,9 @@ fn send_responses(
                 .collect()
         } else {
             vec![]
         };
-        match (&generation.generated_text, i) {
-            (Some(generated_text), i) if i == n - 1 => {
+        match (&generation.generated_text, iterator.peek()) {
+            (Some(generated_text), None) => {
                 // Generation has ended
                 stopped = true;
                 // Send message
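The two Rust hunks above (fn send_responses) replace an index-based check for the last token (`i == n - 1`) with a peekable iterator, so the loop can simply ask whether another token follows via `iterator.peek()`. Below is a minimal, self-contained sketch of that pattern with invented token data; the vectors, the `is_last` flag, and the `println!` are stand-ins for the router's real `Token` struct and response handling, not its actual code.

// Standalone sketch of the peekable-iterator pattern used in the diff above.
// The data is made up; only the iterator shape mirrors the router code.
fn main() {
    let ids: Vec<u32> = vec![5, 9, 42];
    let logprobs: Vec<f32> = vec![-0.1, -0.7, -0.3];
    let texts: Vec<&str> = vec!["Hello", " world", "!"];
    let is_special: Vec<bool> = vec![false, false, true];

    // Zip the parallel vectors, enumerate, and make the iterator peekable so
    // the loop body can detect the final element without tracking `n` itself.
    let mut iterator = ids
        .into_iter()
        .zip(logprobs.into_iter())
        .zip(texts.into_iter())
        .zip(is_special.into_iter())
        .enumerate()
        .peekable();

    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        // peek() returns None exactly when the current item is the last one,
        // which is what replaces the old `i == n - 1` comparison.
        let is_last = iterator.peek().is_none();
        println!("{i}: id={id} logprob={logprob} text={text:?} special={special} last={is_last}");
    }
}

The design point is that `peek()` makes "is this the last token?" a property of the iterator itself rather than of a separately tracked length.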
@@ -101,7 +101,7 @@ def get_model(
     if speculate is not None:
         set_speculate(speculate)
     else:
-        set_speculate(2)
+        set_speculate(0)

     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -159,7 +159,10 @@ def get_model(
         method = "medusa"
     else:
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")

     model_type = config_dict["model_type"]

@@ -960,9 +960,6 @@ class FlashCausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
-            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-
             next_token_texts = []
             left = 0
             for j in range(index, index + n_accepted_ids):
@@ -983,12 +980,14 @@ class FlashCausalLM(Model):

                 if stop:
                     stopped = True
-                    left = len(_next_token_ids) - 1 - j
+                    left = n_accepted_ids - 1 - j
                     break
                 else:
                     stopped = False

-            _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
+            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
+            index += n_accepted_ids

             # Shard generations
             # All generations will be appended in the rust sharded client
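The two FlashCausalLM hunks above stop pre-slicing `next_token_ids` / `next_token_logprobs` before the stop-criteria loop and instead slice them afterwards as `index : index + n_accepted_ids - left`, where `left` counts the speculated tokens that came after the token that triggered a stop. A rough Rust sketch of that slicing arithmetic follows, for illustration only: it uses a window-relative stop offset and invented data, not the Python method's actual variables or loop.

// Illustration of the `left` trimming arithmetic from the Python hunk above.
// `next_token_ids` is a flat buffer holding `n_accepted_ids` speculated tokens
// per request, starting at `index`. If a stop criterion fires at offset `j`
// inside that window (0-based here, unlike the Python loop's absolute index),
// the `left = n_accepted_ids - 1 - j` trailing tokens are dropped.
fn accepted_slice(
    next_token_ids: &[u32],
    index: usize,
    n_accepted_ids: usize,
    stop_offset: Option<usize>,
) -> &[u32] {
    let left = match stop_offset {
        Some(j) => n_accepted_ids - 1 - j, // speculated tokens past the stop token
        None => 0,                         // no stop: keep the whole window
    };
    &next_token_ids[index..index + n_accepted_ids - left]
}

fn main() {
    // One request whose window starts at index 2 with 4 accepted ids,
    // and a stop token found at offset 1 of that window.
    let next_token_ids = vec![11, 12, 21, 22, 23, 24, 31];
    let kept = accepted_slice(&next_token_ids, 2, 4, Some(1));
    assert_eq!(kept, &[21, 22]); // the stop token is kept, the 2 leftover ids are dropped
    println!("kept ids: {kept:?}");
}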
@@ -1085,8 +1084,6 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        if prefill:
-            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch