Mirror of https://github.com/huggingface/text-generation-inference.git
Commit 9b4545f279 (parent c8a01d7591)
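The change threads per-request top_n_tokens values through GalacticaCausalLMBatch.from_pb: one value is collected per request next to the existing stopping criteria, the list is materialized as an int64 tensor on the batch's device, and both are passed through to the batch constructor.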
@@ -80,6 +80,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         prefix_offsets = []
+        top_n_tokens = []
         read_offsets = []
         requests_idx_mapping = {}

@@ -96,6 +97,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -129,6 +131,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * max_input_length + max_decode_tokens

@@ -146,6 +151,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
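Below is a minimal, self-contained sketch of the pattern the diff applies. It is not the TGI implementation: FakeRequest and collect_top_n_tokens are hypothetical stand-ins for the protobuf Request message and the relevant slice of from_pb, used only to show how the per-request values become a single device tensor.

    from dataclasses import dataclass
    from typing import List

    import torch


    @dataclass
    class FakeRequest:
        # Hypothetical stand-in for the protobuf Request; only the field
        # the diff touches is modeled here.
        top_n_tokens: int


    def collect_top_n_tokens(requests: List[FakeRequest], device: str) -> torch.Tensor:
        # Gather one value per request, mirroring the loop in the second hunk.
        top_n_tokens = [r.top_n_tokens for r in requests]
        # Materialize the list as a single int64 tensor on the target device,
        # mirroring the torch.tensor(...) call added in the third hunk.
        return torch.tensor(top_n_tokens, device=device, dtype=torch.int64)


    requests = [FakeRequest(top_n_tokens=0), FakeRequest(top_n_tokens=5)]
    print(collect_top_n_tokens(requests, "cpu"))  # tensor([0, 5])

Keeping both the Python list (top_n_tokens) and the tensor (top_n_tokens_tensor) on the batch, as the fourth hunk does, lets per-request bookkeeping stay in plain Python while batched decoding code indexes the tensor directly.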