diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py
index 9c44ba15..9f5ffb87 100644
--- a/backends/gaudi/server/text_generation_server/utils/tokens.py
+++ b/backends/gaudi/server/text_generation_server/utils/tokens.py
@@ -552,8 +552,13 @@ def pad_next_token_chooser_parameters(
 
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator("cpu")
-        self.generator.manual_seed(seed)
+        if device in ["hpu", torch.device("hpu")]:
+            import habana_frameworks.torch.hpu.random as htrandom
+
+            self.generator = htrandom.default_generators[0].manual_seed(seed)
+        else:
+            self.generator = torch.Generator("cpu")
+            self.generator.manual_seed(seed)
         self.seed = seed
 
     def __call__(self, logits):
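
A minimal standalone sketch (not the TGI implementation) of the device-keyed generator selection this diff introduces: seed an HPU generator through `habana_frameworks` when running on Gaudi, otherwise fall back to a seeded CPU `torch.Generator`. The `make_generator` helper and the sampling call at the bottom are illustrative only; the `htrandom.default_generators[0]` access mirrors the diff and only resolves on a host with the Habana PyTorch bridge installed.

```python
import torch


def make_generator(seed: int, device: str = "cpu"):
    """Illustrative helper mirroring the branch added in Sampling.__init__."""
    if device in ["hpu", torch.device("hpu")]:
        # Only importable on Gaudi hosts with habana_frameworks installed.
        import habana_frameworks.torch.hpu.random as htrandom

        return htrandom.default_generators[0].manual_seed(seed)
    # CPU fallback: a plain torch.Generator seeded deterministically,
    # matching the pre-existing behaviour.
    generator = torch.Generator("cpu")
    generator.manual_seed(seed)
    return generator


if __name__ == "__main__":
    gen = make_generator(seed=42, device="cpu")
    probs = torch.tensor([[0.1, 0.2, 0.7]])
    # Thread the generator into a multinomial draw, similar to how
    # Sampling.__call__ consumes self.generator when sampling from logits.
    next_token = torch.multinomial(probs, num_samples=1, generator=gen)
    print(next_token)
```

Keying the branch on the device rather than always constructing a CPU generator lets the seeded sampling stay on the HPU, while non-Gaudi runs keep the original CPU path untouched.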