diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index 44118347..5a11ef2b 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -216,9 +216,13 @@ class VectorizedNextTokenChooser: else: self.typical_p=None - self.do_sample = any(do_sample) - if self.do_sample and not all(do_sample): - raise NotImplementedError("Mixed greedy and probabilistic sampling not supported") + num_do_sample=sum(do_sample) + self.do_sample = num_do_sample>0 + if self.do_sample and num_do_sample