From 2c6ef7c93aeb2c2538e3ed77a1ef88c0bb0261c6 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Feb 2024 03:57:35 +0000 Subject: [PATCH] fix: add missing accepted_ids to batch_top_tokens --- server/text_generation_server/models/mamba.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index af09f70f..c10910aa 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -489,10 +489,13 @@ class Mamba(Model): generations: List[Generation] = [] stopped = True + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch.top_n_tokens, batch.top_n_tokens_tensor, torch.log_softmax(logits[:, -1], -1), + accepted_ids, ) start_decode = time.time_ns()