diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index af09f70f..c10910aa 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -489,10 +489,13 @@ class Mamba(Model):
         generations: List[Generation] = []
         stopped = True
 
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
             torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
         )
 
         start_decode = time.time_ns()
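
Note on the added tensor (illustrative only, not part of the patch): since speculative decoding is not active here, each request in the batch accepts exactly one token per decode step, so `accepted_ids` reduces to a 1-D tensor of ones with one entry per sequence. A minimal sketch of what the expression evaluates to, using a hypothetical stand-in for `batch.input_ids`:

    import torch

    # Hypothetical stand-in for batch.input_ids: shape [batch_size, seq_len]
    input_ids = torch.tensor([[5, 7, 9], [2, 4, 6]])

    # ones_like gives a tensor of ones with the same shape; taking column 0
    # leaves one "accepted token" count per sequence in the batch.
    accepted_ids = torch.ones_like(input_ids)[:, 0]
    print(accepted_ids)        # tensor([1, 1])
    print(accepted_ids.shape)  # torch.Size([2])

This mirrors the signature change in `batch_top_tokens`, which now expects a per-request count of accepted tokens; passing all ones keeps the non-speculative path behaviorally unchanged.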