diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index e0ed76b4..9fae8ee1 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
         do_sample=False,
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 4dc834b8..994c57d5 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -336,7 +336,7 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
             next_token_id = tokens[-1].view(1, 1)
 
             # Append next token to all tokens
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 29492dd7..245dca12 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -418,7 +418,7 @@ class Seq2SeqLM(Model):
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits)
 
             # Append next token to decoder tokens
             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 3433db43..8072a9f1 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -78,12 +78,8 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
-        input_ids = input_ids.unsqueeze(0)
-        scores = scores.unsqueeze(-1)
-
         # Warp logits
         scores = self.warpers(input_ids, scores)
-        scores = scores.squeeze(-1)
 
         # Compute logprobs
         logprobs = torch.log_softmax(scores, -1)