Mirror of https://github.com/huggingface/text-generation-inference.git
Do not schedule decode if max_new_tokens is equal to 1 (#183)
Co-authored-by: Bartosz Kowalski <bkowalski@habana.ai>
commit 0ca54b55f8 (parent 15e5df1cc4)
@@ -985,6 +985,10 @@ class CausalLM(Model):
                 batch.past_key_values,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
+        elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
+            # Don't schedule next forward if max_new_tokens for all requests equals 1
+            # - we've already generated the first and only needed token in the prefill phase
+            pass
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
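For illustration, a minimal standalone sketch of the decision this diff adds. The `should_schedule_decode` helper and the `Request`/`StoppingCriteria` dataclasses are hypothetical stand-ins, not part of the repository; only the attribute path `req.stopping_criteria.max_new_tokens` and the `all(...)` check mirror the change above. The idea: when every request in the batch asks for exactly one new token, the prefill pass has already produced it, so no decode forward pass needs to be scheduled.

# Hypothetical stand-ins for illustration only -- not part of the repository.
from dataclasses import dataclass
from typing import List


@dataclass
class StoppingCriteria:
    max_new_tokens: int


@dataclass
class Request:
    stopping_criteria: StoppingCriteria


def should_schedule_decode(requests: List[Request]) -> bool:
    # Mirrors the new elif branch: if every request needs only one new token,
    # the token generated during prefill is sufficient and the decode forward
    # pass can be skipped entirely.
    return not all(req.stopping_criteria.max_new_tokens == 1 for req in requests)


# A batch where every request wants a single token skips the decode step ...
assert not should_schedule_decode([Request(StoppingCriteria(1)) for _ in range(4)])
# ... while one longer request is enough to keep the decode pass scheduled.
assert should_schedule_decode([Request(StoppingCriteria(1)), Request(StoppingCriteria(8))])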