From 476d8fc379037c7f545afa1b7874e97957fe8a5b Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 4 May 2023 11:52:11 -0400
Subject: [PATCH] Use next token chooser

---
 .../models/vectorized_causal_lm.py | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index 5a11ef2b..72232d64 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -39,7 +39,7 @@ class VectorizedCausalLMBatch(Batch):
     token_offsets: List[Optional[int]]
 
     # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
+    next_token_chooser: "VectorizedNextTokenChooser"
     stopping_criterias: List[StoppingCriteria]
 
     # Metadata used for padding
@@ -93,6 +93,8 @@ class VectorizedCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
 
+        next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
+
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
@@ -132,7 +134,7 @@ class VectorizedCausalLMBatch(Batch):
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
+            next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
             max_tokens=max_tokens,
@@ -381,15 +383,16 @@ class VectorizedCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
         key_length=batch.max_input_length
         query_length=key_length if batch.past_key_values is None else 1
+        input_ids=batch.input_ids[:, key_length-query_length: key_length]
         outputs = self.model.forward(
-            input_ids=batch.input_ids[:, key_length-query_length: key_length],
+            input_ids=input_ids,
             attention_mask=batch.attention_mask[:, : key_length],
             position_ids=batch.position_ids[:, key_length-query_length: key_length],
             past_key_values=batch.past_key_values,
         )
 
         # TODO: Post-processing
-        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+        next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits[:, -1, :])
 
         # Update batch
         # TODO: Why do we need all input ids?
@@ -402,6 +405,9 @@ class VectorizedCausalLM(Model):
         next_token_ids=next_token_ids.cpu().tolist()
         next_token_texts=self.tokenizer.batch_decode(next_token_ids)
 
+        # TODO: Why do we need logprobs?
+        logprobs=logprobs.cpu().tolist()
+
         # TODO: Vectorize some of this?
         generations: List[Generation] = []
 
@@ -409,7 +415,6 @@
 
         for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
             stopping_criterias=batch.stopping_criterias[i]
-            next_token_chooser=batch.next_token_choosers[i]
             stop, reason = stopping_criterias(
                 next_token_id,
                 next_token_text,
@@ -420,14 +425,9 @@
                 output_text = self.decode(
                     batch.input_ids[i, -stopping_criterias.current_tokens :]
                 )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
+                # TODO: Seed
                 generated_text = GeneratedText(
-                    output_text, stopping_criterias.current_tokens, reason, seed
+                    output_text, stopping_criterias.current_tokens, reason, seed=None
                 )
             else:
                 # Keep request in the batch
@@ -437,9 +437,9 @@
 
             generation = Generation(
                 batch.requests[i].id,
-                None,
+                None, # TODO: Prefill tokens
                 next_token_id,
-                0,
+                logprobs[i],
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
@@ -448,3 +448,4 @@
             generations.append(generation)
 
         return generations, next_batch
+
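
Note for reviewers: the hunks above construct and call `VectorizedNextTokenChooser`, but the class itself is not part of this diff. Below is a minimal sketch of the interface the call sites assume, namely `from_pb(parameters, device)` and `__call__(input_ids, last_logits)` returning `(next_token_ids, logprobs)`. It is greedy-only, with temperature as the one example warper; everything beyond those two entry points is invented here for illustration.

import torch


class VectorizedNextTokenChooser:
    """Sketch: one chooser for the whole batch instead of one per request.

    Per-request sampling parameters are packed into tensors so a single
    batched pass can warp the logits and pick all next tokens at once.
    """

    def __init__(self, temperature: torch.Tensor):
        # Shape [batch_size, 1], one temperature per request.
        self.temperature = temperature

    @classmethod
    def from_pb(cls, parameters, device):
        # `parameters` is the list of per-request protobuf parameters
        # ([r.parameters for r in pb.requests] at the call site); only
        # `temperature` is handled in this sketch, defaulting 0 to 1.0.
        temperature = torch.tensor(
            [p.temperature or 1.0 for p in parameters],
            dtype=torch.float32,
            device=device,
        ).unsqueeze(1)
        return cls(temperature)

    def __call__(self, input_ids, scores):
        # scores: [batch_size, vocab_size] logits for the last position.
        scores = scores / self.temperature
        logprobs = torch.log_softmax(scores, dim=-1)
        # Greedy pick; a real implementation would branch on do_sample,
        # apply top-k/top-p warpers, and track per-request seeds.
        next_token_ids = torch.argmax(scores, dim=-1)
        chosen_logprobs = logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
        return next_token_ids, chosen_logprobs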
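
A quick shape check against the Generation loop above (one token id and one log-prob per request), using a hypothetical stand-in for the protobuf parameters object:

from types import SimpleNamespace

import torch

params = [SimpleNamespace(temperature=1.0), SimpleNamespace(temperature=0.7)]
chooser = VectorizedNextTokenChooser.from_pb(params, device="cpu")

logits = torch.randn(2, 50257)                   # [batch_size, vocab_size]
input_ids = torch.zeros(2, 1, dtype=torch.long)  # unused by the sketch
next_token_ids, logprobs = chooser(input_ids, logits)
assert next_token_ids.shape == logprobs.shape == (2,)
# After .cpu().tolist(), logprobs[i] is the float handed to Generation.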