Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00

Fixing infer iterator.

Commit: 9bf31fe388
Parent: 09839b05f4
@@ -527,13 +527,14 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-    for (i, (((id, logprob), text), special)) in tokens_
+    let mut iterator = tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
         .zip(tokens_.is_special.into_iter())
-        .enumerate()
+        .enumerate().peekable();
+    while let Some((i, (((id, logprob), text), special))) = iterator.next()
     {
         let token = Token {
             id,
@@ -557,9 +558,9 @@ fn send_responses(
                 .collect()
         } else {
             vec![]
         };
-        match (&generation.generated_text, i) {
-            (Some(generated_text), i) if i == n - 1 => {
+        match (&generation.generated_text, iterator.peek()) {
+            (Some(generated_text), None) => {
                 // Generation has ended
                 stopped = true;
                 // Send message
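
The router change above replaces the index test `i == n - 1` with `iterator.peek()`: the end-of-generation branch now fires when the iterator is exhausted rather than when a precomputed length is reached. A rough Python analog of that lookahead pattern, useful for seeing why peeking is equivalent to (and sturdier than) counting; the helper name `with_last_flag` is purely illustrative and does not exist in the repository:

def with_last_flag(items):
    """Yield (item, is_last) pairs by holding one element of lookahead,
    mirroring Rust's Peekable: is_last is True exactly when peek() would be None."""
    it = iter(items)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for upcoming in it:
        yield current, False
        current = upcoming
    yield current, True  # nothing left to peek at -> last item

# The last yielded token is where "Generation has ended" would be handled.
for token, is_last in with_last_flag(["Hello", ",", " world", "!"]):
    marker = "end of generation" if is_last else "intermediate token"
    print(f"{token!r}: {marker}")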
@@ -101,7 +101,7 @@ def get_model(
     if speculate is not None:
         set_speculate(speculate)
     else:
-        set_speculate(2)
+        set_speculate(0)

     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -159,7 +159,10 @@ def get_model(
         method = "medusa"
     else:
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
+
+    speculate = get_speculate()
+    if speculate > 0:
+        logger.info(f"Using speculation {method} with {speculate} input ids.")

     model_type = config_dict["model_type"]
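
The two `get_model` hunks make speculation opt-in: the default number of speculative input ids drops from 2 to 0, and the log line is only emitted when speculation is actually enabled. A self-contained sketch of that behaviour, assuming `set_speculate`/`get_speculate` are a plain module-level setter/getter; the `_SPECULATE` variable and the `configure_speculation` wrapper are illustrative stand-ins, not the repository's real helpers:

_SPECULATE = 0

def set_speculate(speculate: int) -> None:
    global _SPECULATE
    _SPECULATE = speculate

def get_speculate() -> int:
    return _SPECULATE

def configure_speculation(speculate=None, method="n-gram"):
    # New default: speculation is off unless explicitly requested.
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    # Only announce speculation when it is actually in use.
    if get_speculate() > 0:
        print(f"Using speculation {method} with {get_speculate()} input ids.")

configure_speculation()             # silent: speculation disabled by default
configure_speculation(speculate=2)  # prints: Using speculation n-gram with 2 input ids.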
@@ -960,9 +960,6 @@ class FlashCausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
-            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-
             next_token_texts = []
             left = 0
             for j in range(index, index + n_accepted_ids):
@@ -983,12 +980,14 @@ class FlashCausalLM(Model):

                 if stop:
                     stopped = True
-                    left = len(_next_token_ids) - 1 - j
+                    left = n_accepted_ids - 1 - j
                     break
             else:
                 stopped = False

+            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
             index += n_accepted_ids
-            _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
+
             # Shard generations
             # All generations will be appended in the rust sharded client
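
The two FlashCausalLM hunks delay slicing `next_token_ids`/`next_token_logprobs` until after the stop check, so that `left` (the count of accepted speculative tokens that must be discarded once a stopping criterion fires) can shrink the window in a single slice. A simplified, self-contained sketch of that bookkeeping; `accepted_slice` and `stops_after_offset` are made-up names for illustration only:

def accepted_slice(next_token_ids, next_token_logprobs, index, n_accepted_ids,
                   stops_after_offset=None):
    """Keep only the accepted speculative tokens for one request, dropping
    everything that comes after the token on which a stop criterion fired."""
    left = 0
    for offset in range(n_accepted_ids):
        if stops_after_offset is not None and offset == stops_after_offset:
            left = n_accepted_ids - 1 - offset  # tokens rejected after the stop
            break
    ids = next_token_ids[index: index + n_accepted_ids - left]
    logprobs = next_token_logprobs[index: index + n_accepted_ids - left]
    return ids, logprobs, index + n_accepted_ids  # advance past the whole window

ids, logprobs, new_index = accepted_slice(
    next_token_ids=[7, 8, 9, 10, 11, 12],
    next_token_logprobs=[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6],
    index=1,                 # this request's window starts at position 1
    n_accepted_ids=4,        # four speculative tokens were accepted
    stops_after_offset=1,    # a stop criterion fires on the second of them
)
assert ids == [8, 9] and new_index == 5  # the two tokens after the stop are dropped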
@@ -1085,8 +1084,6 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        if prefill:
-            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch