fix: add missing accepted_ids to batch_top_tokens

This commit is contained in:
drbh 2024-02-07 03:57:35 +00:00
parent 48624fee25
commit 2c6ef7c93a

View File

@@ -489,10 +489,13 @@ class Mamba(Model):
generations: List[Generation] = []
stopped = True
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()