Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Fix no speculation.

commit fdef00c27e
parent 9bf31fe388
```diff
@@ -427,7 +427,7 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
-        speculative_ids = self.speculative_ids[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

```
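Without the guard, filtering a batch crashes when speculation is disabled, because `speculative_ids` is `None` and indexing `None` raises a `TypeError`. A minimal sketch of the pattern, assuming only that `speculative_ids` is an optional `(batch, spec_len)` tensor and `indices` selects the kept requests:

```python
import torch

speculative_ids = None          # no speculation configured
indices = torch.tensor([0, 2])  # requests kept after filtering

# Guarded form from the diff: index only when a tensor is present.
filtered = speculative_ids[indices] if speculative_ids is not None else None
assert filtered is None  # old code raised: 'NoneType' object is not subscriptable
```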
```diff
@@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
```
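The same pattern applies to the length bookkeeping: with speculation off there are zero extra speculative tokens per step, so the fallback is `0` rather than reading `.shape` off `None`. A small sketch (the helper name is ours, not TGI's):

```python
from typing import Optional

import torch

def speculative_len(speculative_ids: Optional[torch.Tensor]) -> int:
    # (batch, spec_len) tensor when speculating, None otherwise.
    return speculative_ids.shape[1] if speculative_ids is not None else 0

assert speculative_len(torch.zeros((4, 5), dtype=torch.long)) == 5
assert speculative_len(None) == 0  # speculation disabled: no extra length
```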
```diff
@@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )

-        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None

         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
```
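Concatenation gets the same treatment. Probing only `batches[0]` is enough because speculation is configured once for the whole server, so either every batch carries a tensor or none does. A sketch of that invariant (the helper name is ours):

```python
from typing import List, Optional

import torch

def cat_speculative_ids(
    ids: List[Optional[torch.Tensor]],
) -> Optional[torch.Tensor]:
    # All entries are tensors or all are None, so the first one
    # decides for the whole list.
    return torch.cat(ids, dim=0) if ids[0] is not None else None

assert cat_speculative_ids([None, None]) is None
merged = cat_speculative_ids([torch.zeros((2, 3)), torch.zeros((1, 3))])
assert merged.shape == (3, 3)
```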
```diff
@@ -980,7 +980,7 @@ class FlashCausalLM(Model):

                 if stop:
                     stopped = True
-                    left = n_accepted_ids - 1 - j
+                    left = index + n_accepted_ids - j - 1
                     break
                 else:
                     stopped = False
```
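The last hunk fixes the count of accepted-but-discarded tokens after a stop. Judging from the new formula, `j` iterates over absolute positions `index .. index + n_accepted_ids - 1` rather than starting at zero, so the tokens remaining after a stop at `j` number `(index + n_accepted_ids - 1) - j`. A toy check of that arithmetic (the loop bounds are our inference from the formula, not quoted from the file):

```python
index, n_accepted_ids = 10, 3  # this request's tokens sit at positions 10..12

for j in range(index, index + n_accepted_ids):
    old_left = n_accepted_ids - 1 - j        # goes negative: j is absolute
    new_left = index + n_accepted_ids - j - 1
    print(j, old_left, new_left)

# j=10: old=-8, new=2; j=11: old=-9, new=1; j=12: old=-10, new=0
```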