Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-06-19 15:52:08 +00:00

Commit c11e77411f: improve decode
Parent commit: cdc33ce63c
@@ -34,6 +34,7 @@ class CausalLMBatch(Batch):

    # Lengths of all generations present in the batch
    input_lengths: List[int]
+    offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
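Note: offsets is a new per-request field holding the character position in the decoded text that has already been emitted to the client (None until the first token is decoded). The hunks below thread it through from_pb, concatenate and generate_token in the same way as input_lengths.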
@@ -64,12 +65,14 @@ class CausalLMBatch(Batch):
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
+        offsets = []

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
+            offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer

@@ -113,6 +116,7 @@ class CausalLMBatch(Batch):
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths.tolist(),
+            offsets=offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,

@@ -135,6 +139,7 @@ class CausalLMBatch(Batch):
        # Batch attributes
        requests = []
        input_lengths = []
+        offsets = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

@@ -151,6 +156,7 @@ class CausalLMBatch(Batch):
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

@@ -264,6 +270,7 @@ class CausalLMBatch(Batch):
            past_key_values=past_key_values,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
+            offsets=offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,

@@ -350,6 +357,7 @@ class CausalLM(Model):

        # New values for next forward
        next_batch_input_lengths = []
+        next_batch_offsets = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

@@ -364,6 +372,7 @@ class CausalLM(Model):
        iterator = zip(
            batch.requests,
            batch.input_lengths,
+            batch.offsets,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,

@@ -374,6 +383,7 @@ class CausalLM(Model):
        for i, (
            request,
            input_length,
+            offset,
            logits,
            next_token_chooser,
            stopping_criteria,

@@ -391,10 +401,7 @@ class CausalLM(Model):
            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(
-                all_input_ids[-2, 0],
-                next_token_id_squeezed,
-            )
+            next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
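Note: offset is now threaded through the per-request loop. all_input_ids is a column tensor of shape [length, 1] (hence the old all_input_ids[-2, 0] indexing), so all_input_ids[:, 0] hands the flat token id sequence to decode_token; the updated offset returned here is stored in next_batch_offsets below, so the next forward pass resumes decoding from the same position.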
@@ -424,6 +431,7 @@ class CausalLM(Model):
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
+                next_batch_offsets.append(offset)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, new_input_length
                )

@@ -507,6 +515,7 @@ class CausalLM(Model):
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
@@ -44,6 +44,7 @@ class FlashCausalLMBatch(Batch):

    # Lengths of all generations present in the batch
    input_lengths: List[int]
+    offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]

@@ -67,6 +68,7 @@ class FlashCausalLMBatch(Batch):
        max_seqlen = 0

        input_lengths = []
+        offsets = []
        all_input_ids = []
        all_input_ids_tensor = []

@@ -84,6 +86,7 @@ class FlashCausalLMBatch(Batch):
            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)
+            offsets.append(None)
            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)

@@ -120,6 +123,7 @@ class FlashCausalLMBatch(Batch):
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
+            offsets=offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,

@@ -132,6 +136,7 @@ class FlashCausalLMBatch(Batch):
        # Batch attributes
        requests = []
        input_lengths = []
+        offsets = []
        all_input_ids = []
        all_input_ids_tensor = []
        next_token_choosers = []

@@ -150,6 +155,7 @@ class FlashCausalLMBatch(Batch):
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
            next_token_choosers.extend(batch.next_token_choosers)

@@ -279,6 +285,7 @@ class FlashCausalLM(Model):
        next_batch_max_seqlen = 0
        next_batch_past_key_values = []
        next_batch_input_lengths = []
+        next_batch_offsets = []
        next_batch_all_input_ids = []
        next_batch_all_input_ids_tensor = []

@@ -292,6 +299,7 @@ class FlashCausalLM(Model):
        iterator = zip(
            batch.requests,
            batch.input_lengths,
+            batch.offsets,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,

@@ -302,6 +310,7 @@ class FlashCausalLM(Model):
        for i, (
            request,
            input_length,
+            offset,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,

@@ -334,9 +343,8 @@ class FlashCausalLM(Model):

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text = self.decode_token(
-                all_input_ids[-2],
-                next_token_id_item,
+            next_token_text, offset = self.decode_token(
+                all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
            )

            # Evaluate stopping criteria
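Note: the flash path does not hand the whole sequence to decode_token; it slices off the last stopping_criteria.current_tokens + 1 ids, so batch_decode only ever sees the generated tokens plus one token of left context, presumably to keep per-step decode cost independent of prompt length. The offset is therefore relative to that window.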
@@ -377,6 +385,7 @@ class FlashCausalLM(Model):
                    next_batch_cu_seqlens[-1] + new_input_length
                )
                next_batch_input_lengths.append(new_input_length)
+                next_batch_offsets.append(offset)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

@@ -453,6 +462,7 @@ class FlashCausalLM(Model):
            max_seqlen=next_batch_max_seqlen,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
            all_input_ids=next_batch_all_input_ids,
            all_input_ids_tensor=next_batch_all_input_ids_tensor,
            next_token_choosers=next_batch_next_token_choosers,
@@ -93,23 +93,21 @@ class GalacticaCausalLMBatch(CausalLMBatch):
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
-        input_lengths = []
+        offsets = []

        # Parse batch
        max_truncation = 0
-        max_sequence_length = 0
        padding_right_offset = 0
        for r in pb.requests:
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(escape_custom_split_sequence(r.inputs))
-            input_lengths.append(r.input_length)
+            offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_truncation = max(max_truncation, r.truncate)
-            max_sequence_length = max(max_sequence_length, r.input_length)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

@@ -123,13 +121,17 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            truncation=True,
            max_length=max_truncation,
        ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
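The Galactica batch now derives per-request lengths from the tokenizer output instead of trusting r.input_length from the request, and max_input_length replaces the old max_sequence_length bookkeeping when sizing the preallocated attention_mask. A minimal sketch of that pattern outside this repo, using the gpt2 tokenizer purely as an illustrative stand-in:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

batch = tokenizer(["Hello world", "Hi"], return_tensors="pt", padding=True)
# attention_mask is 1 for real tokens and 0 for padding, so summing over the
# sequence dimension recovers the true length of every request in the batch.
input_lengths = batch["attention_mask"].sum(1)  # e.g. tensor([2, 1])
max_input_length = input_lengths.max()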
@@ -144,10 +146,11 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            past_key_values=None,
            all_input_ids=all_input_ids,
            input_lengths=input_lengths,
+            offsets=offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )
@@ -24,16 +24,26 @@ class Model(ABC):
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError

-    def decode_token(self, previous_token_id: int, token_id: int) -> str:
+    def decode_token(
+        self, all_input_ids: List[int], offset: Optional[int] = None
+    ) -> Tuple[str, Optional[int]]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        # Decode previous token and previous token + token
+
+        # Decode all token minus last one and all tokens
        results = self.tokenizer.batch_decode(
-            [[previous_token_id], [previous_token_id, token_id]],
+            [all_input_ids[:-1], all_input_ids],
            skip_special_tokens=False,
        )

-        if results[0] and results[0][0] == " " and results[1][0] != " ":
-            results[0] = results[0].lstrip()
+        # default offset is only the last token
+        if offset is None:
+            offset = len(results[0])

-        # slice to remove previous token
-        return results[1][len(results[0]):]
+        # get text
+        text = results[1][offset:]
+
+        # if text is utf-8
+        if text and text[-1] != "�":
+            return text, None
+        else:
+            return "", offset
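The rewritten decode_token above is the heart of this commit: instead of decoding only the previous and the current token id, it decodes the running sequence twice (once without the last token, once with it) and returns whatever text appears past the offset, holding output back whenever the newest token only completes part of a multi-byte character (in which case the tokenizer yields the U+FFFD replacement character "�"). A self-contained sketch of the same logic outside the Model class; the gpt2 tokenizer is only an illustrative stand-in:

from typing import List, Optional, Tuple

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def decode_token(
    all_input_ids: List[int], offset: Optional[int] = None
) -> Tuple[str, Optional[int]]:
    # Decode the sequence without its last token, and the full sequence.
    results = tokenizer.batch_decode(
        [all_input_ids[:-1], all_input_ids], skip_special_tokens=False
    )
    # On the first call, everything before the newest token counts as already emitted.
    if offset is None:
        offset = len(results[0])
    # The new text is whatever appears after the already-emitted prefix.
    text = results[1][offset:]
    if text and text[-1] != "�":
        # Complete characters: emit them and reset the offset.
        return text, None
    # The newest token is only a fragment of a multi-byte character: emit nothing
    # and keep the offset so the next call retries from the same position.
    return "", offset


# Streaming usage: feed the growing id list one token at a time.
ids = tokenizer("streaming decode test")["input_ids"]
offset, emitted = None, ""
for i in range(1, len(ids) + 1):
    piece, offset = decode_token(ids[:i], offset)
    emitted += piece

Returning None after a successful emit means the next call re-anchors the offset on the freshly decoded prefix rather than on an absolute position kept across steps.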
@@ -38,6 +38,7 @@ class Seq2SeqLMBatch(Batch):
    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]
+    offsets: List[Optional[int]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]

@@ -71,6 +72,7 @@ class Seq2SeqLMBatch(Batch):

        decoder_input_ids = []
        decoder_input_lengths = []
+        offsets = []

        # Parse batch
        max_truncation = 0

@@ -80,6 +82,7 @@ class Seq2SeqLMBatch(Batch):
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
+            offsets.append(None)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer

@@ -117,6 +120,7 @@ class Seq2SeqLMBatch(Batch):
            past_key_values=None,
            input_lengths=input_lengths.tolist(),
            decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),

@@ -147,6 +151,7 @@ class Seq2SeqLMBatch(Batch):
        requests = []
        input_lengths = []
        decoder_input_lengths = []
+        offsets = []
        next_token_choosers = []
        stopping_criterias = []

@@ -166,6 +171,7 @@ class Seq2SeqLMBatch(Batch):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
+            offsets.extend(batch.offsets)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

@@ -303,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,

@@ -422,6 +429,7 @@ class Seq2SeqLM(Model):

        # New values for next forward
        next_batch_input_lengths = []
+        next_batch_offsets = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

@@ -437,6 +445,7 @@ class Seq2SeqLM(Model):
        iterator = zip(
            batch.requests,
            batch.input_lengths,
+            batch.offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,

@@ -448,6 +457,7 @@ class Seq2SeqLM(Model):
        for i, (
            request,
            input_length,
+            offset,
            decoder_input_length,
            logits,
            next_token_chooser,

@@ -466,10 +476,7 @@ class Seq2SeqLM(Model):
            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(
-                decoder_input_ids[-2],
-                next_token_id_squeezed,
-            )
+            next_token_text, offset = self.decode_token(decoder_input_ids, offset)

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)
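Note: in the encoder-decoder path the generated text lives entirely in decoder_input_ids (which starts from the BOS token appended in from_pb), so the full decoder sequence plus the running offset is what gets handed to decode_token here.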
@@ -496,6 +503,7 @@ class Seq2SeqLM(Model):
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_offsets.append(offset)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )

@@ -581,6 +589,7 @@ class Seq2SeqLM(Model):
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
+            offsets=next_batch_offsets,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,