Fix missing arguments in Galactica's from_pb

Fixes #1004
This commit is contained in:
Vincent Brouwers 2023-09-14 08:40:19 +00:00
parent c8a01d7591
commit 9b4545f279

View File

@ -80,6 +80,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
top_n_tokens = []
read_offsets = []
requests_idx_mapping = {}
@ -96,6 +97,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -129,6 +131,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * max_input_length + max_decode_tokens
@ -146,6 +151,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,