Fix no speculation.

This commit is contained in:
Nicolas Patry 2023-12-05 22:02:22 +00:00
parent 9bf31fe388
commit fdef00c27e

View File

@ -427,7 +427,7 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices] top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices] speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b) total_batch_size += len(b)
total_slots += len(b.slots) total_slots += len(b.slots)
blocks += b.blocks blocks += b.blocks
speculative_length = b.speculative_ids.shape[1] speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen) max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max( max_length = max(
@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device, device=batches[0].next_token_chooser.device,
) )
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
# Needed to avoid dropping blocks when the batches will go out of scope # Needed to avoid dropping blocks when the batches will go out of scope
for b in batches: for b in batches:
@ -980,7 +980,7 @@ class FlashCausalLM(Model):
if stop: if stop:
stopped = True stopped = True
left = n_accepted_ids - 1 - j left = index + n_accepted_ids - j - 1
break break
else: else:
stopped = False stopped = False