diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index f0b568b9..ca9b7ebe 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -175,23 +175,31 @@ class VectorizedNextTokenChooser: repetition_penalty=self._standardize(repetition_penalty, 1.0) if any([x!=1.0 for x in repetition_penalty]): - self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1) + self.repetition_penalty=torch.tensor(repetition_penalty, dtype=torch.float32, device=device).unsqueeze(1) else: self.repetition_penalty=None temperature=self._standardize(temperature, 1.0) if any([x!=1.0 for x in temperature]): do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)] - self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1) + self.temperature=torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1) else: self.temperature=None top_k=self._standardize(top_k, 0) - if any([x!=0 for x in top_k]): + n_top_k=sum([x!=0 for x in top_k]) + if n_top_k>0: do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)] - self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1) + self.max_top_k=max(top_k) + self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1) + if n_top_kself.max_top_k: # Safety check + max_top_k=scores.size(-1) + top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed. + else: + max_top_k=self.max_top_k + top_k=self.top_k + kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) + if self.top_k_mask is not None: + kth_scores.masked_fill_(self.top_k_mask, self.filter_value) # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] + indices_to_remove = scores < kth_scores scores = scores.masked_fill(indices_to_remove, self.filter_value) # Compute logprobs