diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index ca9b7ebe..44118347 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -166,6 +166,7 @@ class VectorizedNextTokenChooser:
         device="cpu",
     ):
         self.batch_size=batch_size
+        self.filter_value = -float("Inf")
 
         do_sample=self._standardize(do_sample, False)
 
@@ -191,7 +192,7 @@ class VectorizedNextTokenChooser:
         if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
             self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
             if n_top_k<…
@@ … @@
+            sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)
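
For context, a minimal runnable sketch of the two patterns this diff touches, assuming `[batch, vocab]` score tensors (the k values, the topk size, and the 0.1 tail threshold are illustrative, not taken from the file). The `top_k` tensor feeds `torch.gather`, which only accepts integer index tensors, hence the `float32` to `int64` dtype fix; the added lines build a removal mask in sorted order, scatter it back to vocabulary order, and fill the removed positions with `filter_value` (`-inf`) so they get exactly zero probability after the softmax.

import torch

torch.manual_seed(0)

# Illustrative shapes and thresholds, not values from the real file.
scores = torch.randn(2, 8)  # [batch_size, vocab_size]
filter_value = -float("Inf")

# 1) Per-request top-k cutoff. Each row stores k-1 as an index into the
#    descending topk values; it must be int64 because torch.gather only
#    accepts integer index tensors (the dtype fix above).
top_k = torch.tensor([[2], [4]], dtype=torch.int64)
kth_scores = torch.gather(torch.topk(scores, 5).values, 1, top_k)
scores = scores.masked_fill(scores < kth_scores, filter_value)

# 2) Sorted-space filtering (the top-p/typical-p pattern in the added
#    lines): decide what to drop in sorted order, scatter the boolean mask
#    back to vocabulary order, then mask filtered logits to -inf.
sorted_scores, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= 0.1  # drop the low-probability tail
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)

# Masked positions come out as -inf log-probability, i.e. zero probability.
logprobs = torch.log_softmax(scores, dim=-1)
print(logprobs)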