improve decode

OlivierDehaene 2023-04-04 18:31:26 +02:00
parent cdc33ce63c
commit c11e77411f
5 changed files with 66 additions and 25 deletions

View File

@@ -34,6 +34,7 @@ class CausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -64,12 +65,14 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
offsets = []
# Parse batch
max_truncation = 0
padding_right_offset = 0
for r in pb.requests:
inputs.append(r.inputs)
offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -113,6 +116,7 @@ class CausalLMBatch(Batch):
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths.tolist(),
offsets=offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
@@ -135,6 +139,7 @@ class CausalLMBatch(Batch):
# Batch attributes
requests = []
input_lengths = []
offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
@@ -151,6 +156,7 @@ class CausalLMBatch(Batch):
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -264,6 +270,7 @@ class CausalLMBatch(Batch):
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@@ -350,6 +357,7 @@ class CausalLM(Model):
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_input_ids = []
next_batch_all_input_ids = []
@@ -364,6 +372,7 @@ class CausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
batch.offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@@ -374,6 +383,7 @@ class CausalLM(Model):
for i, (
request,
input_length,
offset,
logits,
next_token_chooser,
stopping_criteria,
@@ -391,10 +401,7 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.decode_token(
all_input_ids[-2, 0],
next_token_id_squeezed,
)
next_token_text, offset = self.decode_token(all_input_ids[:, 0], offset)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
@@ -424,6 +431,7 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
@@ -507,6 +515,7 @@ class CausalLM(Model):
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
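
In the CausalLM hunks above, the per-request sequence is indexed as a 2-D column, which is why the old code read only the previous id with all_input_ids[-2, 0] while the new code hands the whole column all_input_ids[:, 0] plus the carried offset to decode_token. A minimal shape sketch; the ids and the concatenation step are illustrative assumptions, only the indexing mirrors the diff:

    import torch

    # all_input_ids kept as a (seq_len, 1) column so each freshly sampled id can
    # be appended along dim 0 (ids here are made up).
    all_input_ids = torch.tensor([[15496], [11], [995]])       # shape (3, 1)
    next_token_id = torch.tensor([[340]])                      # shape (1, 1)
    all_input_ids = torch.cat([all_input_ids, next_token_id])  # shape (4, 1)

    previous_id = all_input_ids[-2, 0]   # what the old decode_token call read
    full_sequence = all_input_ids[:, 0]  # what the new call passes, shape (4,)
    print(previous_id.item(), full_sequence.tolist())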

View File

@@ -44,6 +44,7 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -67,6 +68,7 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0
input_lengths = []
offsets = []
all_input_ids = []
all_input_ids_tensor = []
@@ -84,6 +86,7 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
offsets.append(None)
all_input_ids.append(tokenized_input)
tokenized_input = torch.tensor(tokenized_input, device=device)
@@ -120,6 +123,7 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=None,
input_lengths=input_lengths,
offsets=offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
@@ -132,6 +136,7 @@ class FlashCausalLMBatch(Batch):
# Batch attributes
requests = []
input_lengths = []
offsets = []
all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = []
@@ -150,6 +155,7 @@ class FlashCausalLMBatch(Batch):
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
next_token_choosers.extend(batch.next_token_choosers)
@@ -279,6 +285,7 @@ class FlashCausalLM(Model):
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_all_input_ids = []
next_batch_all_input_ids_tensor = []
@@ -292,6 +299,7 @@ class FlashCausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
batch.offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
@@ -302,6 +310,7 @@ class FlashCausalLM(Model):
for i, (
request,
input_length,
offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
@@ -334,9 +343,8 @@ class FlashCausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id_item]
next_token_text = self.decode_token(
all_input_ids[-2],
next_token_id_item,
next_token_text, offset = self.decode_token(
all_input_ids[-(stopping_criteria.current_tokens + 1) :], offset
)
# Evaluate stopping criteria
@@ -377,6 +385,7 @@ class FlashCausalLM(Model):
next_batch_cu_seqlens[-1] + new_input_length
)
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_all_input_ids.append(all_input_ids)
next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
@@ -453,6 +462,7 @@ class FlashCausalLM(Model):
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
all_input_ids=next_batch_all_input_ids,
all_input_ids_tensor=next_batch_all_input_ids_tensor,
next_token_choosers=next_batch_next_token_choosers,
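
In the FlashCausalLM hunks, decode_token is not handed the whole sequence: the call slices all_input_ids[-(stopping_criteria.current_tokens + 1) :], so only the generated tail is re-decoded each step and the returned offset indexes into the text decoded from that window. A small sketch of the slice arithmetic, assuming current_tokens counts tokens generated on earlier steps and that the freshly sampled id has already been appended (the ids are made up):

    # Made-up token ids; only the slicing is the point.
    prompt_ids = [5338, 318, 262, 3139, 286]   # 5 prompt tokens
    generated_ids = [4881, 30]                 # one earlier token plus the id just sampled

    all_input_ids = prompt_ids + generated_ids
    current_tokens = len(generated_ids) - 1    # tokens generated on earlier steps

    window = all_input_ids[-(current_tokens + 1):]
    assert window == generated_ids             # under these assumptions, the prompt is never re-decoded
    print(window)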

View File

@@ -93,23 +93,21 @@ class GalacticaCausalLMBatch(CausalLMBatch):
inputs = []
next_token_choosers = []
stopping_criterias = []
input_lengths = []
offsets = []
# Parse batch
max_truncation = 0
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests:
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
input_lengths.append(r.input_length)
offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
@@ -123,13 +121,17 @@ class GalacticaCausalLMBatch(CausalLMBatch):
truncation=True,
max_length=max_truncation,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -144,10 +146,11 @@ class GalacticaCausalLMBatch(CausalLMBatch):
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max_sequence_length,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
)
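
The Galactica batch also stops taking lengths from the request: input_lengths is now measured from the tokenizer's attention_mask (presumably because truncation can shorten the tokenized prompt), and the padded attention mask is sized from the measured max_input_length plus padding_right_offset. A small torch sketch of that bookkeeping; the tensor values are made up and padding_right_offset stands in for the largest max_new_tokens in the batch:

    import torch

    # Two left-padded requests whose true lengths are 3 and 5.
    tokenizer_attention_mask = torch.tensor(
        [[0, 0, 1, 1, 1],
         [1, 1, 1, 1, 1]]
    )
    padding_right_offset = 4  # largest max_new_tokens across the batch

    # Lengths measured from the mask instead of taken from r.input_length.
    input_lengths = tokenizer_attention_mask.sum(1)   # tensor([3, 5])
    max_input_length = int(input_lengths.max())       # 5

    # Allocate the attention mask once, with room for every future generated
    # token, then copy the tokenizer's mask into the left-hand part.
    attention_mask = tokenizer_attention_mask.new_zeros(
        (tokenizer_attention_mask.shape[0], max_input_length + padding_right_offset)
    )
    attention_mask[:, :max_input_length] = tokenizer_attention_mask
    print(input_lengths.tolist(), tuple(attention_mask.shape))  # [3, 5] (2, 9)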

View File

@@ -24,16 +24,26 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def decode_token(self, previous_token_id: int, token_id: int) -> str:
def decode_token(
self, all_input_ids: List[int], offset: Optional[int] = None
) -> Tuple[str, Optional[int]]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# Decode previous token and previous token + token
# Decode all token minus last one and all tokens
results = self.tokenizer.batch_decode(
[[previous_token_id], [previous_token_id, token_id]],
[all_input_ids[:-1], all_input_ids],
skip_special_tokens=False,
)
if results[0] and results[0][0] == " " and results[1][0] != " ":
results[0] = results[0].lstrip()
# default offset is only the last token
if offset is None:
offset = len(results[0])
# slice to remove previous token
return results[1][len(results[0]) :]
# get text
text = results[1][offset:]
# if text is utf-8
if text and text[-1] != "�":
return text, None
else:
return "", offset

View File

@@ -38,6 +38,7 @@ class Seq2SeqLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
@@ -71,6 +72,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = []
decoder_input_lengths = []
offsets = []
# Parse batch
max_truncation = 0
@@ -80,6 +82,7 @@ class Seq2SeqLMBatch(Batch):
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
@@ -117,6 +120,7 @@ class Seq2SeqLMBatch(Batch):
past_key_values=None,
input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths,
offsets=offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
@@ -147,6 +151,7 @@ class Seq2SeqLMBatch(Batch):
requests = []
input_lengths = []
decoder_input_lengths = []
offsets = []
next_token_choosers = []
stopping_criterias = []
@@ -166,6 +171,7 @@ class Seq2SeqLMBatch(Batch):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
offsets.extend(batch.offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -303,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
offsets=offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@@ -422,6 +429,7 @@ class Seq2SeqLM(Model):
# New values for next forward
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
@@ -437,6 +445,7 @@ class Seq2SeqLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
batch.offsets,
batch.decoder_input_lengths,
logits,
batch.next_token_choosers,
@@ -448,6 +457,7 @@ class Seq2SeqLM(Model):
for i, (
request,
input_length,
offset,
decoder_input_length,
logits,
next_token_chooser,
@@ -466,10 +476,7 @@ class Seq2SeqLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.decode_token(
decoder_input_ids[-2],
next_token_id_squeezed,
)
next_token_text, offset = self.decode_token(decoder_input_ids, offset)
# Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text)
@@ -496,6 +503,7 @@ class Seq2SeqLM(Model):
next_batch_size += 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_offsets.append(offset)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
@@ -581,6 +589,7 @@ class Seq2SeqLM(Model):
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
offsets=next_batch_offsets,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
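
Seq2SeqLM gets the same treatment: the offset travels with the batch and decode_token now runs over the full decoder_input_ids instead of the last two decoder ids. Across all call sites, the hold-back behaviour rests on plain UTF-8: a token boundary can fall inside a multi-byte character, and decoding the truncated bytes ends in the replacement character that decode_token checks for. A tokenizer-free illustration:

    # A four-byte character split across steps: every incomplete byte prefix
    # decodes to a string ending in U+FFFD, so decode_token would return
    # ("", offset) and emit the text once the final byte arrives.
    emoji_bytes = "🤗".encode("utf-8")   # f0 9f a4 97

    for end in range(1, len(emoji_bytes) + 1):
        text = emoji_bytes[:end].decode("utf-8", errors="replace")
        complete = bool(text) and text[-1] != "�"
        print(end, repr(text), "emit" if complete else "hold back")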