mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
parent
c8a01d7591
commit
9b4545f279
@ -80,6 +80,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
prefix_offsets = []
|
||||
top_n_tokens = []
|
||||
read_offsets = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
@ -96,6 +97,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
@ -129,6 +131,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
max_tokens = len(inputs) * max_input_length + max_decode_tokens
|
||||
|
||||
@ -146,6 +151,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user