diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index fc6d4760..f9df7855 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -441,7 +441,7 @@ class VectorizedNextTokenChooser:
         if return_logprobs:
             # Compute logprobs
             if scores.size(1)==1:
-                scores=scores.unsqueeze(1)
+                scores=last_token_scores.unsqueeze(1)
             else:
                 # TODO: Post-process all the tokens?
                 scores[:, -1, :]=last_token_scores
@@ -560,6 +560,9 @@ class VectorizedCausalLM(Model):
         self, batch: VectorizedCausalLMBatch
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
         key_length=batch.max_input_length
+        if key_length>batch.input_ids.size(1):
+            raise RuntimeError("Cannot generate more than `max_tokens`.")
+
         query_length=key_length if batch.past_key_values is None else 1
         input_ids=batch.input_ids[:, key_length-query_length: key_length]

@@ -572,16 +575,15 @@ class VectorizedCausalLM(Model):
         # TODO: Post-processing
         next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
-        next_token_ids=next_token_ids.cpu().tolist()
         if batch.generate_stream:
             # TODO: self.decode_token, offsets?
-            next_token_texts=self.tokenizer.batch_decode(next_token_ids)
+            next_token_texts=self.tokenizer.batch_decode(next_token_ids.tolist())
         if batch.details:
             token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
             if query_length>1:
                 prefill_token_ids=batch.input_ids[:, :key_length].tolist()
-                prefill_logprobs=logprobs.gather(1, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
+                prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
                 prefill_tokens=[]
                 for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
                     prefill_token_ids_=prefill_token_ids_[-input_length:]
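
For reference, a minimal sketch of why the prefill gather needs dim=2 (not part of the patch; tensor shapes are assumed for illustration): logprobs is [batch, seq, vocab], so the token ids must index the vocabulary axis, whereas gather(1, ...) would index the sequence axis.

    import torch

    # Assumed shapes for illustration only: logprobs [batch, seq-1, vocab], input_ids [batch, seq].
    batch_size, seq_len, vocab_size = 2, 5, 11
    logprobs = torch.log_softmax(torch.randn(batch_size, seq_len - 1, vocab_size), dim=-1)
    input_ids = torch.randint(vocab_size, (batch_size, seq_len))

    # Logprob of each observed next token: select along the vocabulary axis (dim=2).
    prefill_logprobs = logprobs.gather(2, input_ids[:, 1:, None]).squeeze(2)
    assert prefill_logprobs.shape == (batch_size, seq_len - 1)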