From 4554a69b222d1af6831b7091a7354cbaf0f18fb8 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Wed, 3 May 2023 14:55:08 -0400
Subject: [PATCH] Top p and typical p

---
 .../models/vectorized_causal_lm.py           | 41 +++++++++++++++++--
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index ca9b7ebe..44118347 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -166,6 +166,7 @@ class VectorizedNextTokenChooser:
         device="cpu",
     ):
         self.batch_size=batch_size
+        self.filter_value = -float("Inf")
 
         do_sample=self._standardize(do_sample, False)
@@ -191,7 +192,7 @@ class VectorizedNextTokenChooser:
         if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
             self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
             if n_top_k<self.batch_size:
[...]
+            sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)
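
Note: the body of the diff above is truncated; the top_p/typical_p setup in
__init__ and most of the new warping logic in __call__ are missing. For the
top-p part, the following is a minimal sketch of vectorized top-p (nucleus)
filtering, following the standard algorithm (as in transformers'
TopPLogitsWarper). The name filter_top_p and the (batch_size, 1) shape of
top_p are assumptions for illustration, not the patch's actual code.

    import torch

    # Illustrative sketch only; assumes top_p is a (batch_size, 1) tensor of
    # per-sequence thresholds, matching how the patch stores top_k above.
    def filter_top_p(scores, top_p, filter_value=-float("Inf")):
        # Sort descending so probability mass accumulates from the largest logits.
        sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
        cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
        # Remove tokens once the cumulative mass exceeds top_p.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept;
        # this also guarantees at least one token always survives.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return scores.masked_fill(indices_to_remove, filter_value)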
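
The three surviving "+" lines at the end of the diff are the tail of typical-p
filtering (locally typical sampling). Below is a sketch of the full
computation they plausibly belong to, mirroring transformers'
TypicalLogitsWarper; the name filter_typical_p and the tensor shape of
typical_p are illustrative assumptions.

    import torch

    # Illustrative sketch only; assumes typical_p is a (batch_size, 1) tensor.
    def filter_typical_p(scores, typical_p, filter_value=-float("Inf")):
        # Entropy of each sequence's predictive distribution.
        normalized = torch.log_softmax(scores, dim=-1)
        p = normalized.exp()
        ent = -(normalized * p).nansum(-1, keepdim=True)
        # Tokens whose surprisal is closest to the entropy are the most
        # "typical", so sort ascending by |surprisal - entropy|.
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Keep the smallest set of typical tokens whose mass reaches typical_p.
        last_ind = (cumulative_probs < typical_p).sum(dim=1)
        last_ind.clamp_(min=0, max=scores.shape[-1] - 1)
        # These three lines match the ones that survive at the end of the patch.
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return scores.masked_fill(indices_to_remove, filter_value)

Both filters run once per decoding step on the full (batch_size, vocab_size)
logits tensor, which is what lets the vectorized chooser avoid per-sequence
Python loops.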