mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Fix no speculation.
This commit is contained in:
parent
9bf31fe388
commit
fdef00c27e
@ -427,7 +427,7 @@ class FlashCausalLMBatch(Batch):
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
||||
speculative_ids = self.speculative_ids[indices]
|
||||
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
|
||||
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
|
||||
total_batch_size += len(b)
|
||||
total_slots += len(b.slots)
|
||||
blocks += b.blocks
|
||||
speculative_length = b.speculative_ids.shape[1]
|
||||
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
|
||||
max_blocks = max(max_blocks, b.max_blocks)
|
||||
max_seqlen = max(max_seqlen, b.max_seqlen)
|
||||
max_length = max(
|
||||
@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
|
||||
device=batches[0].next_token_chooser.device,
|
||||
)
|
||||
|
||||
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
|
||||
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
|
||||
|
||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||
for b in batches:
|
||||
@ -980,7 +980,7 @@ class FlashCausalLM(Model):
|
||||
|
||||
if stop:
|
||||
stopped = True
|
||||
left = n_accepted_ids - 1 - j
|
||||
left = index + n_accepted_ids - j - 1
|
||||
break
|
||||
else:
|
||||
stopped = False
|
||||
|
Loading…
Reference in New Issue
Block a user