Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Use next token chooser

commit 476d8fc379 (parent d5685656a4)
@@ -39,7 +39,7 @@ class VectorizedCausalLMBatch(Batch):
     token_offsets: List[Optional[int]]
 
     # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: "VectorizedNextTokenChooser"
     stopping_criterias: List[StoppingCriteria]
 
     # Metadata used for padding
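
The batch dataclass now carries a single chooser for the whole batch instead of a list with one NextTokenChooser per request, so token selection can run as batched tensor operations rather than a Python loop. The sketch below only illustrates what such a vectorized chooser could look like: the class name comes from the diff, but the fields (temperature, do_sample), tensor shapes, and sampling logic are assumptions, not the implementation introduced by this commit.

from dataclasses import dataclass

import torch


# Illustrative sketch only; not the class added by this commit.
@dataclass
class VectorizedNextTokenChooser:
    # Hypothetical per-request parameters stacked into batch tensors.
    temperature: torch.Tensor  # (batch_size,) floats, 1.0 = no scaling
    do_sample: torch.Tensor    # (batch_size,) bools, False = greedy

    def __call__(self, input_ids: torch.Tensor, last_logits: torch.Tensor):
        # input_ids would be needed for e.g. a repetition penalty; unused in this sketch.
        # Scale every request's logits in one operation instead of looping per request.
        scores = last_logits / self.temperature.unsqueeze(1)
        logprobs = torch.log_softmax(scores, dim=-1)
        greedy = torch.argmax(scores, dim=-1)
        sampled = torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1).squeeze(1)
        next_token_ids = torch.where(self.do_sample, sampled, greedy)
        # Log-probability of each chosen token, one value per request.
        chosen_logprobs = logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
        return next_token_ids, chosen_logprobs
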
@@ -93,6 +93,8 @@ class VectorizedCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
 
+        next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
+
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
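
from_pb now builds that single chooser from the per-request generation parameters carried in the protobuf requests, once, at batch-construction time. A possible shape for that step is sketched below as a free function against the sketch class above; the parameter field names are assumptions about the protobuf schema, not taken from this commit.

import torch


def vectorized_chooser_from_pb(pb_parameters, device):
    # Stack per-request generation parameters into batch tensors instead of
    # building one NextTokenChooser object per request.
    # Assumes each parameter set exposes `temperature` and `do_sample`
    # (an assumption about the schema, not taken from this commit).
    temperature = torch.tensor(
        [p.temperature for p in pb_parameters], dtype=torch.float32, device=device
    )
    do_sample = torch.tensor(
        [bool(p.do_sample) for p in pb_parameters], dtype=torch.bool, device=device
    )
    # Reuses the VectorizedNextTokenChooser sketch shown above.
    return VectorizedNextTokenChooser(temperature=temperature, do_sample=do_sample)
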
@@ -132,7 +134,7 @@ class VectorizedCausalLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             max_tokens=max_tokens,
@@ -381,15 +383,16 @@ class VectorizedCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
         key_length=batch.max_input_length
         query_length=key_length if batch.past_key_values is None else 1
+        input_ids=batch.input_ids[:, key_length-query_length: key_length]
 
         outputs = self.model.forward(
-            input_ids=batch.input_ids[:, key_length-query_length: key_length],
+            input_ids=input_ids,
             attention_mask=batch.attention_mask[:, : key_length],
             position_ids=batch.position_ids[:, key_length-query_length: key_length],
             past_key_values=batch.past_key_values,
         )
         # TODO: Post-processing
-        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
 
         # Update batch
         # TODO: Why do we need all input ids?
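
The decode step previously took a bare argmax over the last-position logits and discarded the probabilities. The chooser call keeps the same batched shape but also returns the log-probability of each selected token, which the generation loop below needs. A small standalone illustration of that contract, using greedy selection only (the real chooser may also sample):

import torch

# Stand-in for outputs.logits[:, -1, :] on a batch of 3 requests (random values, illustration only).
batch_size, vocab_size = 3, 32
last_logits = torch.randn(batch_size, vocab_size)

# Greedy version of the (next_token_ids, logprobs) contract.
full_logprobs = torch.log_softmax(last_logits, dim=-1)                      # (batch, vocab)
next_token_ids = torch.argmax(last_logits, dim=-1)                          # (batch,)
logprobs = full_logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)  # (batch,)

print(next_token_ids.shape, logprobs.shape)  # torch.Size([3]) torch.Size([3])
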
@@ -402,6 +405,9 @@ class VectorizedCausalLM(Model):
         next_token_ids=next_token_ids.cpu().tolist()
         next_token_texts=self.tokenizer.batch_decode(next_token_ids)
 
+        # TODO: Why do we need logprobs?
+        logprobs=logprobs.cpu().tolist()
+
         # TODO: Vectorize some of this?
 
         generations: List[Generation] = []
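
The new logprobs tensor is handled the same way as next_token_ids: one device-to-host transfer for the whole batch per decode step, then plain list indexing inside the per-request loop. A tiny illustration of that alignment (the values are made up):

import torch

next_token_ids = torch.tensor([42, 7, 13])      # chosen token id per request
logprobs = torch.tensor([-0.10, -1.25, -0.03])  # log-prob of that token per request

# One .cpu().tolist() per tensor per step; the lists stay aligned by batch index.
next_token_ids = next_token_ids.cpu().tolist()
logprobs = logprobs.cpu().tolist()

for i, token_id in enumerate(next_token_ids):
    print(token_id, logprobs[i])  # the per-request values later passed to Generation
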
@@ -409,7 +415,6 @@ class VectorizedCausalLM(Model):
 
         for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
             stopping_criterias=batch.stopping_criterias[i]
-            next_token_chooser=batch.next_token_choosers[i]
             stop, reason = stopping_criterias(
                 next_token_id,
                 next_token_text,
@@ -420,14 +425,9 @@ class VectorizedCausalLM(Model):
                 output_text = self.decode(
                     batch.input_ids[i, -stopping_criterias.current_tokens :]
                 )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
+                # TODO: Seed
                 generated_text = GeneratedText(
-                    output_text, stopping_criterias.current_tokens, reason, seed
+                    output_text, stopping_criterias.current_tokens, reason, seed=None
                 )
             else:
                 # Keep request in the batch
@@ -437,9 +437,9 @@ class VectorizedCausalLM(Model):
 
             generation = Generation(
                 batch.requests[i].id,
-                None,
+                None, # TODO: Prefill tokens
                 next_token_id,
-                0,
+                logprobs[i],
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
@@ -448,3 +448,4 @@ class VectorizedCausalLM(Model):
             generations.append(generation)
 
         return generations, next_batch
+
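
Taken together, the decode step on this branch now does one forward pass and one vectorized chooser call per batch, leaving only per-request bookkeeping (stopping criteria, text decoding, Generation objects) in Python. A compressed sketch of that flow, following the diff above; the helper name is ours and error handling is omitted:

from typing import List, Tuple


def decode_step_sketch(model, batch) -> Tuple[List[int], List[float]]:
    # One forward pass over the newest position of every sequence in the batch.
    key_length = batch.max_input_length
    query_length = key_length if batch.past_key_values is None else 1
    input_ids = batch.input_ids[:, key_length - query_length: key_length]
    outputs = model.forward(
        input_ids=input_ids,
        attention_mask=batch.attention_mask[:, :key_length],
        position_ids=batch.position_ids[:, key_length - query_length: key_length],
        past_key_values=batch.past_key_values,
    )
    # One vectorized selection for the whole batch: token ids plus their log-probs.
    next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
    # Move results to the host once; the per-request loop then works on plain Python lists.
    return next_token_ids.cpu().tolist(), logprobs.cpu().tolist()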