Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Include a few fixes:
- Bad degradation on pad tokens for ngram
- Invalid batch discard
- `stopped` value could be incorrect
This commit is contained in:
parent 7b34445457
commit f6958ea6d4
@@ -527,11 +527,6 @@ fn send_responses(
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
     metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
-
-    assert_eq!(n, tokens_.logprobs.len());
-    assert_eq!(n, tokens_.texts.len());
-    assert_eq!(n, tokens_.is_special.len());
-
     let mut iterator = tokens_
         .ids
         .into_iter()
@@ -229,7 +229,7 @@ impl State {
             }

             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens + self.speculate;
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
@@ -237,7 +237,7 @@ impl State {
                         window_size.saturating_sub(entry.request.input_length),
                         entry.request.stopping_parameters.max_new_tokens,
                     ),
-                } + self.speculate;
+                };

                 // pad to block size
                 decode_tokens +=
@@ -245,7 +245,7 @@ impl State {
             }

             if prefill_tokens > prefill_token_budget
-                || (prefill_tokens + decode_tokens) > token_budget
+                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
             {
                 // Entry is over budget
                 // Add it back to the front
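Note on the three `impl State` hunks above: previously `self.speculate` was added to `decode_tokens` once per queued entry, which inflated the estimate and made the router discard batches that actually fit ("Invalid batch discard"). After the change it is counted a single time in the final budget comparison. Below is a schematic Python rendering of the new check, not the Rust source itself; the helper name `fits_in_budget` is made up, and the variable names simply mirror the diff.

def fits_in_budget(prefill_tokens: int, decode_tokens: int, speculate: int,
                   prefill_token_budget: int, token_budget: int) -> bool:
    """Schematic rendering of the router's budget check after this commit.

    `speculate` (the speculative-token allowance per decode step) is counted
    once against the overall token budget instead of being added to
    `decode_tokens` for every entry in the batch.
    """
    if prefill_tokens > prefill_token_budget:
        return False
    # The entry is over budget if prefill + decode + the one-off speculation
    # margin exceeds the total token budget.
    return (prefill_tokens + decode_tokens + speculate) <= token_budget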
@@ -543,7 +543,15 @@ mod tests {
         queue.append(entry1);
         queue.append(entry2);

+        // Budget of 1 is not enough
         assert!(queue.next_batch(None, 1, 1).await.is_none());
+
+        let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+        assert_eq!(entries.len(), 2);
+        assert!(entries.contains_key(&0));
+        assert!(entries.contains_key(&1));
+        assert_eq!(batch.id, 0);
+        assert_eq!(batch.size, 2);
     }

     #[tokio::test]
@@ -162,9 +162,6 @@ class FlashCausalLMBatch(Batch):

             tokenized_input = tokenized_input[-r.truncate :]

-            speculate_ids = []
-
-
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

@@ -806,10 +803,9 @@ class FlashCausalLM(Model):
             del batch
             raise e

-        try:
-            out, speculative_logits = out.logits, out.speculative_logits
-        except Exception:
-            out = out
+        if isinstance(out, tuple):
+            out, speculative_logits = out
+        else:
             speculative_logits = None


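The `try/except` removed above silently swallowed any exception while probing `out` for speculative logits; the replacement states the intent directly: a forward pass that produced a `(logits, speculative_logits)` tuple is unpacked, anything else has no speculative logits. A minimal, self-contained sketch of that normalization (the helper name `split_speculative` is illustrative, not part of the codebase):

import torch

def split_speculative(out):
    """Normalize a forward() result into (logits, speculative_logits).

    Models with a speculative head (e.g. Medusa) return a tuple; plain models
    return just the logits tensor, so speculative_logits is None for them.
    """
    if isinstance(out, tuple):
        logits, speculative_logits = out
    else:
        logits, speculative_logits = out, None
    return logits, speculative_logits

# Example: a plain logits tensor passes through unchanged.
logits, spec = split_speculative(torch.zeros(1, 4, 32000))
assert spec is None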
@@ -829,9 +825,6 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )

-        from loguru import logger
-        logger.info(f"Accepted id {accepted_ids}")
-
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
@@ -843,8 +836,7 @@ class FlashCausalLM(Model):
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            length = len(batch)
-            next_position_ids = batch.position_ids.new_empty(length)
+            next_position_ids = batch.position_ids.new_empty(len(batch))
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -965,6 +957,9 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
+            before = stopping_criteria.current_tokens
+
+            current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
                 next_token_id = next_token_ids[j]
@@ -982,11 +977,12 @@ class FlashCausalLM(Model):
                 )

                 if stop:
-                    stopped = True
                     left = index + n_accepted_ids - j - 1
+                    current_stopped = True
                     break
                 else:
-                    stopped = False
+                    current_stopped = False
+            stopped = stopped and current_stopped

             _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
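These two hunks are the `stopped` fix from the commit message. With speculative decoding an entry can accept several tokens per step, and the old code let the per-token `stop` flag of the last processed request overwrite `stopped` for the whole batch, so the batch could be reported as finished while some requests were still generating. The new code tracks `current_stopped` per request and folds it in with `stopped = stopped and current_stopped`. A reduced sketch of that aggregation, using an illustrative input structure rather than the real batch objects:

def batch_stopped(requests):
    """Return True only if every request in the batch has finished.

    Each element of `requests` is a list of per-token stop flags for the
    tokens accepted this step (an illustrative stand-in for the real batch).
    """
    stopped = True
    for accepted_stop_flags in requests:
        current_stopped = False
        for stop in accepted_stop_flags:
            if stop:
                current_stopped = True
                break  # ignore tokens accepted after the stop condition
            else:
                current_stopped = False
        stopped = stopped and current_stopped
    return stopped

# One request stopped, the other did not: the batch must keep running.
assert batch_stopped([[False, True], [False, False]]) is False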
@@ -997,15 +993,12 @@ class FlashCausalLM(Model):
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
-                    # Remove potentially accepted ids that do not respect
-                    # the stopping_criteria
-                    _ids = all_input_ids
                     output_text, _, _ = self.decode_token(
-                        _ids,
-                        prefix_offset=len(_ids)
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens
                         - 1,
-                        read_offset=len(_ids)
+                        read_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens,
                         skip_special_tokens=True,
                     )
@@ -34,7 +34,7 @@ class MedusaModel(torch.nn.Module):
     def forward(self, x):
         logits = self.lm_head(x)
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-        return Output(logits=logits, speculative_logits=speculative_logits)
+        return logits, speculative_logits


 class MedusaHead(torch.nn.Module):
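With the `isinstance(out, tuple)` check earlier in this diff, the Medusa head no longer needs an `Output` wrapper and can return a plain tuple. A toy, self-contained sketch of a module following that convention; the sizes and single-linear heads are placeholders, not the real TGI Medusa implementation:

import torch

class TinyMedusaModel(torch.nn.Module):
    """Toy stand-in for a Medusa-style LM head with extra speculative heads."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 100, n_heads: int = 2):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.heads = torch.nn.ModuleList(
            torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_heads)
        )

    def forward(self, x):
        logits = self.lm_head(x)
        # One extra set of logits per speculative head, stacked on a new dim.
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        # Plain tuple instead of an Output dataclass, matching the diff above.
        return logits, speculative_logits

out = TinyMedusaModel()(torch.randn(3, 16))
assert isinstance(out, tuple) and out[1].shape == (3, 2, 100)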
@@ -186,6 +186,8 @@ def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, a
     index = 0
     for i, (_input_ids, n_accepted_ids) in enumerate(zip(cpu_input_ids, accepted_ids.tolist())):
         stop = len(_input_ids)
+        # TODO 0 is not necessarily the pad token.
+        # Remove zero padded end.
         for j, _id in enumerate(_input_ids):
             if _id == 0:
                 stop = j
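The two added comments point at the "bad degradation on pad tokens for ngram" issue: rows of `input_ids` are zero-padded, and the n-gram matcher degrades badly if those trailing zeros are treated as real tokens, while `0` is only assumed to be the pad id (hence the TODO). A minimal sketch of trimming one padded row before n-gram matching, under the same 0-as-pad assumption; the helper name is made up:

def trim_zero_padding(_input_ids: list[int]) -> list[int]:
    """Drop the zero-padded tail of one row before n-gram matching.

    Assumes 0 is the pad token, exactly the assumption the diff's TODO
    flags as not guaranteed.
    """
    stop = len(_input_ids)
    for j, _id in enumerate(_input_ids):
        if _id == 0:
            stop = j
            break
    return _input_ids[:stop]

assert trim_zero_padding([5, 9, 2, 0, 0, 0]) == [5, 9, 2]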