Use next token chooser

Joel Lamy-Poirier 2023-05-04 11:52:11 -04:00
parent d5685656a4
commit 476d8fc379


@@ -39,7 +39,7 @@ class VectorizedCausalLMBatch(Batch):
     token_offsets: List[Optional[int]]
 
     # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: "VectorizedNextTokenChooser"
     stopping_criterias: List[StoppingCriteria]
 
     # Metadata used for padding
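
The batch previously carried one NextTokenChooser per request and applied them one at a time; the single vectorized chooser lets that per-request work become one batched tensor operation. A toy illustration of the difference, not code from this commit, using plain argmax as the pre-existing code below did:

import torch

logits = torch.randn(4, 10)  # 4 requests, vocabulary of 10

# Per-request processing (the old List[NextTokenChooser] shape): one Python call per row.
per_request_ids = [int(torch.argmax(logits[i])) for i in range(logits.shape[0])]

# Batched processing (the new single-chooser shape): one tensor op for the whole batch.
batched_ids = torch.argmax(logits, dim=-1)

assert per_request_ids == batched_ids.tolist()
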
@@ -93,6 +93,8 @@ class VectorizedCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
 
+        next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
+
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
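
VectorizedNextTokenChooser itself is not part of this diff. A minimal sketch of what its from_pb constructor could look like, assuming the protobuf request parameters expose per-request temperature, top_k, top_p and do_sample fields; the real parameter set and class internals may differ:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class VectorizedNextTokenChooser:
    # Per-request sampling parameters, stacked into batch-sized tensors on the target device.
    temperature: torch.Tensor  # [batch_size], float
    top_k: torch.Tensor        # [batch_size], int
    top_p: torch.Tensor        # [batch_size], float
    do_sample: torch.Tensor    # [batch_size], bool

    @classmethod
    def from_pb(cls, pb_parameters: List, device: torch.device) -> "VectorizedNextTokenChooser":
        # One entry per request, in request order.
        return cls(
            temperature=torch.tensor([p.temperature for p in pb_parameters], dtype=torch.float32, device=device),
            top_k=torch.tensor([p.top_k for p in pb_parameters], dtype=torch.int64, device=device),
            top_p=torch.tensor([p.top_p for p in pb_parameters], dtype=torch.float32, device=device),
            do_sample=torch.tensor([p.do_sample for p in pb_parameters], dtype=torch.bool, device=device),
        )
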
@@ -132,7 +134,7 @@ class VectorizedCausalLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             max_tokens=max_tokens,
@@ -381,15 +383,16 @@ class VectorizedCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
         key_length=batch.max_input_length
         query_length=key_length if batch.past_key_values is None else 1
+        input_ids=batch.input_ids[:, key_length-query_length: key_length]
         outputs = self.model.forward(
-            input_ids=batch.input_ids[:, key_length-query_length: key_length],
+            input_ids=input_ids,
             attention_mask=batch.attention_mask[:, : key_length],
             position_ids=batch.position_ids[:, key_length-query_length: key_length],
             past_key_values=batch.past_key_values,
         )
 
         # TODO: Post-processing
-        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
 
         # Update batch
         # TODO: Why do we need all input ids?
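
The plain torch.argmax over the last-position logits is replaced by a call to the batched chooser, which also returns per-token log-probabilities. A minimal sketch of what such a vectorized selection could do, assuming greedy vs. multinomial sampling is decided per row and only temperature scaling is applied; top-k/top-p filtering, repetition penalty and per-request seeding are omitted, and the function name and signature are illustrative, not from this commit:

from typing import Tuple

import torch


def choose_next_tokens(
    logits: torch.Tensor,       # [batch_size, vocab_size], logits at the last position
    temperature: torch.Tensor,  # [batch_size], float
    do_sample: torch.Tensor,    # [batch_size], bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Batched temperature scaling; a temperature of 1.0 leaves the logits unchanged.
    logprobs = torch.log_softmax(logits / temperature.unsqueeze(1), dim=-1)

    # Greedy candidate and sampled candidate for every row,
    # then pick per row according to the do_sample flag.
    greedy_ids = torch.argmax(logprobs, dim=-1)
    sampled_ids = torch.multinomial(torch.exp(logprobs), num_samples=1).squeeze(1)
    next_token_ids = torch.where(do_sample, sampled_ids, greedy_ids)

    # Log-probability of the chosen token, one scalar per request.
    chosen_logprobs = logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
    return next_token_ids, chosen_logprobs


# Tiny usage example with dummy data (2 requests, vocabulary of 5).
logits = torch.randn(2, 5)
next_ids, next_logprobs = choose_next_tokens(
    logits, temperature=torch.tensor([1.0, 0.7]), do_sample=torch.tensor([False, True])
)

Returning one scalar log-probability per request matches how the result is used further down: logprobs.cpu().tolist() followed by logprobs[i] in the Generation constructor.
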
@@ -402,6 +405,9 @@ class VectorizedCausalLM(Model):
         next_token_ids=next_token_ids.cpu().tolist()
         next_token_texts=self.tokenizer.batch_decode(next_token_ids)
+
+        # TODO: Why do we need logprobs?
+        logprobs=logprobs.cpu().tolist()
 
         # TODO: Vectorize some of this?
 
         generations: List[Generation] = []
@@ -409,7 +415,6 @@ class VectorizedCausalLM(Model):
         for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
             stopping_criterias=batch.stopping_criterias[i]
-            next_token_chooser=batch.next_token_choosers[i]
 
             stop, reason = stopping_criterias(
                 next_token_id,
                 next_token_text,
@@ -420,14 +425,9 @@ class VectorizedCausalLM(Model):
                 output_text = self.decode(
                     batch.input_ids[i, -stopping_criterias.current_tokens :]
                 )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
+                # TODO: Seed
                 generated_text = GeneratedText(
-                    output_text, stopping_criterias.current_tokens, reason, seed
+                    output_text, stopping_criterias.current_tokens, reason, seed=None
                 )
             else:
                 # Keep request in the batch
@@ -437,9 +437,9 @@ class VectorizedCausalLM(Model):
 
             generation = Generation(
                 batch.requests[i].id,
-                None,
+                None, # TODO: Prefill tokens
                 next_token_id,
-                0,
+                logprobs[i],
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
@@ -448,3 +448,4 @@ class VectorizedCausalLM(Model):
 
             generations.append(generation)
         return generations, next_batch