diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json new file mode 100644 index 00000000..e9b1c57a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json @@ -0,0 +1,58 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "stop_sequence", + "generated_tokens": 5, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -10.0625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.28125, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -1.7587891, + "special": false, + "text": " failed" + }, + { + "id": 363, + "logprob": -0.5175781, + "special": false, + "text": " for" + }, + { + "id": 1404, + "logprob": 0.0, + "special": false, + "text": " user" + }, + { + "id": 376, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 1688, + "logprob": -0.20422363, + "special": false, + "text": "test" + } + ] + }, + "generated_text": "Test request failed for user \"test" +} diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json new file mode 100644 index 00000000..eb449de3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json @@ -0,0 +1,88 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -10.0625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.28125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -2.0878906, + "special": false, + "text": " for" + }, + { + "id": 278, + "logprob": -3.4121094, + "special": false, + "text": " the" + }, + { + "id": 376, + "logprob": -3.8457031, + "special": false, + "text": " \"" + }, + { + "id": 2577, + "logprob": -3.5566406, + "special": false, + "text": "Get" + }, + { + "id": 599, + "logprob": -3.4746094, + "special": false, + "text": " all" + }, + { + "id": 4160, + "logprob": -3.2363281, + "special": false, + "text": " users" + }, + { + "id": 29908, + "logprob": -0.49023438, + "special": false, + "text": "\"" + }, + { + "id": 16248, + "logprob": -1.2402344, + "special": false, + "text": " endpoint" + }, + { + "id": 29889, + "logprob": -0.88134766, + "special": false, + "text": "." 
+ }, + { + "id": 13, + "logprob": -0.41870117, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": " for the \"Get all users\" endpoint.\n" +} diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py new file mode 100644 index 00000000..64180c0a --- /dev/null +++ b/integration-tests/models/test_flash_medusa.py @@ -0,0 +1,58 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_medusa_handle(launcher): + with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_medusa(flash_medusa_handle): + await flash_medusa_handle.health(300) + return flash_medusa_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_simple(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_all_params(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 5 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): + responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 06dd3f5c..4aeb447d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, - speculative_ids: Optional[torch.Tensor] ) -> torch.Tensor: - if speculative_ids is not None: - speculative_length = speculative_ids.shape[1] - new_length = speculative_length + 1 - new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0) - new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1) - - # Add an extra block just in case - block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1) - # Add Copy the block tables for all members - block_tables = block_tables.expand(new_length, -1).contiguous() - slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device) - input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device) - max_s = max_s + speculative_length - - - input_ids = new_input_ids - position_ids = new_position_ids - hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -520,7 +501,6 @@ class 
FlashLlamaForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - speculative_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: hidden_states = self.model( input_ids, @@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module): slots, input_lengths, max_s, - speculative_ids, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 52679be8..539bdf52 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -3,6 +3,7 @@ import itertools from text_generation_server.utils.tokens import batch_top_tokens import torch import torch.distributed +from loguru import logger import numpy as np @@ -46,7 +47,6 @@ class FlashCausalLMBatch(Batch): # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] - cu_seqlen_speculative: Optional[torch.Tensor] # Paged Attention values @@ -123,7 +123,6 @@ class FlashCausalLMBatch(Batch): position_ids = [] speculative_ids = [] cu_seqlen_prefill = [0] - cu_seqlen_speculative = [0] needed_blocks_slots = [] start_slots = [] slot_indices = [] @@ -163,18 +162,9 @@ class FlashCausalLMBatch(Batch): tokenized_input = tokenized_input[-r.truncate :] - # # TODO remove this - # # Scaffolding to speculate some ids - # speculate_ids = [1, 2] - # tokenized_input.extend([1, 2]) - speculate_ids = [] - - input_length = len(tokenized_input) input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) read_offsets.append(input_length) @@ -186,7 +176,6 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs cu_seqlen_prefill.append(cumulative_length + input_length) - cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids)) next_token_chooser_parameters.append(r.parameters) @@ -199,7 +188,8 @@ class FlashCausalLMBatch(Batch): # Paged attention # Remove one as the first token des not have a past - total_tokens = input_length + max_new_tokens - 1 + speculative_length = 2 + total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) @@ -268,9 +258,6 @@ class FlashCausalLMBatch(Batch): cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int32 ) - cu_seqlen_speculative = torch.tensor( - cu_seqlen_speculative, device=device, dtype=torch.int32 - ) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) @@ -296,6 +283,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens, device=device, dtype=torch.int64 ) + logger.info("FROM PB") return cls( batch_id=pb.id, requests=pb.requests, @@ -303,7 +291,6 @@ class FlashCausalLMBatch(Batch): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, - cu_seqlen_speculative=cu_seqlen_speculative, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=needed_blocks_slots, @@ -331,6 +318,7 @@ class FlashCausalLMBatch(Batch): @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": + logger.info("FILTER") if len(request_ids) == 0: raise ValueError("Batch must have at least one request") # We assume that if 
len(requests) == len(self) then the requests are the same @@ -344,6 +332,7 @@ class FlashCausalLMBatch(Batch): # Used to index into tensors indices = [] + batch_indices = [] # slots to keep after filtering slot_filtering_indices = torch.zeros( @@ -371,14 +360,18 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_max_length = 0 + logger.info(f"Request ids {request_ids} {len(self.requests)}") for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] - indices.append(idx) + batch_indices.append(idx) + S = 1 if self.speculative_ids is None else self.speculative_ids.shape[1] + 1 + indices.extend(range(idx * S, (idx + 1) * S)) requests_idx_mapping[request_id] = i requests.append(self.requests[idx]) # Get length + logger.info(f"Input lengths {self.input_lengths} {idx} {S}") request_input_length = self.input_lengths[idx] max_seqlen = max(max_seqlen, request_input_length) @@ -429,14 +422,19 @@ class FlashCausalLMBatch(Batch): self.block_tables = None # Index into tensors + logger.info(f"INPUT IDS {indices} {self.input_ids}") input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] - next_token_chooser = self.next_token_chooser.filter(indices) - top_n_tokens_tensor = self.top_n_tokens_tensor[indices] + next_token_chooser = self.next_token_chooser.filter(batch_indices) + top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices] + + logger.info(f"{indices} {self.speculative_ids}") + speculative_ids = self.speculative_ids[batch_indices] + logger.info(f"SPEC IDS {speculative_ids}") start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -472,12 +470,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + speculative_ids=speculative_ids, ) @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": # Batch attributes + logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}") requests = [] requests_idx_mapping = {} @@ -488,7 +488,7 @@ class FlashCausalLMBatch(Batch): max_length = 0 max_seqlen = 0 for b in batches: - total_batch_size += len(b) + total_batch_size += len(b.input_ids) total_slots += len(b.slots) blocks += b.blocks max_blocks = max(max_blocks, b.max_blocks) @@ -536,6 +536,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 + cumulative_1 = 0 cumulative_slots = 0 for i, batch in enumerate(batches): @@ -546,23 +547,28 @@ class FlashCausalLMBatch(Batch): else: # We need to offset the mapping for each batch by the cumulative batch size for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + cumulative_batch_size + requests_idx_mapping[k] = v + cumulative_1 start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) + end_index = cumulative_batch_size + len(batch.input_ids) slots_start_index = cumulative_slots slots_end_index = cumulative_slots + len(batch.slots) + start_index1 = cumulative_1 + end_index1 = cumulative_1 + len(batch) + # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids + logger.info(f"IN concat {batch.slot_indices}") slot_indices[start_index:end_index] = 
batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + top_n_tokens_tensor[start_index1:end_index1] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] + start_index1:end_index1, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] block_tables_tensor[ @@ -584,11 +590,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens.extend(batch.top_n_tokens) # Update - cumulative_batch_size += len(batch) + cumulative_batch_size += len(batch.input_ids) cumulative_slots += len(batch.slots) + cumulative_1 += len(batch) start_slots = torch.concat(start_slots) + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, @@ -629,6 +638,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + speculative_ids=speculative_ids, ) def __del__(self): @@ -731,18 +741,28 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE) def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: - # Model Forward + + input_ids=batch.input_ids + position_ids=batch.position_ids + cu_seqlen_prefill=batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths=batch.input_lengths_tensor + max_s=batch.max_seqlen + lm_head_indices=batch.prefill_head_indices + return self.model.forward( - input_ids=batch.input_ids, - position_ids=batch.position_ids, - cu_seqlen_prefill=batch.cu_seqlen_prefill, - kv_cache=get_cache_manager().kv_cache, - block_tables=batch.block_tables_tensor, - slots=batch.slots[batch.slot_indices], - input_lengths=batch.input_lengths_tensor, - max_s=batch.max_seqlen, - lm_head_indices=batch.prefill_head_indices, - speculative_ids =batch.speculative_ids + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=lm_head_indices, + # speculative_ids =batch.speculative_ids ) @tracer.start_as_current_span("generate_token") @@ -790,8 +810,6 @@ class FlashCausalLM(Model): next_token_logits = out - # if next_token_logits.shape[0] == 3: - # import ipdb;ipdb.set_trace() next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits ) @@ -799,20 +817,15 @@ class FlashCausalLM(Model): batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs ) + logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}") - speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1] if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - if speculative_ids is not None: - # length = len(batch) * (1 + 
speculative_length) - length = len(batch) - else: - length = len(batch) - # import ipdb;ipdb.set_trace() + length = len(batch) next_position_ids = batch.position_ids.new_empty(length) # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] @@ -820,7 +833,6 @@ class FlashCausalLM(Model): batch.cu_seqlen_prefill = None else: prefill_logprobs = None - # import ipdb;ipdb.set_trace() next_position_ids = batch.position_ids # Cumulative length @@ -875,29 +887,32 @@ class FlashCausalLM(Model): start_index + 1 : start_index + out_length ] + # logger.info(f"Request ids {request_ids} {len(self.requests)}") for j in range(n_accepted_ids): batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] index += 1 cumulative_length += input_length - - # if accepted_ids[0] > 1: - # import ipdb;ipdb.set_trace() - - if len(accepted_ids) > 1: - raise Exception("Implemtent the batched behavior") + # if len(accepted_ids) > 1: + # raise Exception("Implement the batched behavior") # Set values in batch # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1) - for n_accepted_ids in accepted_ids: - # TODO Make this batched - batch.input_ids = next_input_ids[-1:] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + n_accepted_ids - batch.input_lengths_tensor += n_accepted_ids - batch.slot_indices += n_accepted_ids + accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype) + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + + logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}") + if accepted_ids.shape != batch.slot_indices.shape: + # This can happen after a concatenation + # The slot indices is already modified for some speculative_ids + B = batch.slot_indices.shape[0] // accepted_ids.shape[0] + accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1) + batch.slot_indices += accepted_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids if prefill and prefill_logprobs: # Get prefill logprobs @@ -955,26 +970,22 @@ class FlashCausalLM(Model): next_token_texts = [] for j in range(index, index + n_accepted_ids): # Generated token - all_input_ids.append(next_token_ids[j]) + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids, prefix_offset, read_offset, ) next_token_texts.append(next_token_text) - - # Evaluate stopping criteria - - for next_token_id in _next_token_ids: stop, reason = stopping_criteria( next_token_id, next_token_text, ) - if stop: stopped = True break - if not stop: + else: stopped = False # Shard generations @@ -1068,8 +1079,27 @@ class FlashCausalLM(Model): batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None - if prefill: - batch.max_seqlen += speculative_length batch.max_seqlen = batch.max_seqlen + 1 + # Model Forward + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + batch.input_ids = torch.cat([batch.input_ids.unsqueeze(-1), batch.speculative_ids], dim=1).view(-1) + if batch.position_ids.shape[0] != B * new_length: + arange = torch.arange(new_length).unsqueeze(0).to(device=batch.position_ids.device) + batch.position_ids =
(batch.position_ids.view((-1, 1)).expand(B,new_length) + arange).view(-1) + batch.slot_indices = (batch.slot_indices.view((-1, 1)).expand(B,new_length) + arange.to(dtype=batch.slot_indices.dtype)).view(-1) + batch.input_lengths_tensor = (batch.input_lengths_tensor.view((-1, 1)).expand(B,new_length) + arange.to(dtype=batch.input_lengths_tensor.dtype)).view(-1) + batch.max_seqlen = batch.max_seqlen + speculative_length + # Add an extra block just in case + block_tables = torch.cat([batch.block_tables_tensor, batch.block_tables_tensor[:, -1:] + 1], dim=1) + # Add Copy the block tables for all members + # Contiguous because paged assumes contiguity + batch.block_tables_tensor = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous() + + batch.lm_head_indices=batch.prefill_head_indices + cu_seqlen_prefill=batch.cu_seqlen_prefill + + return generations, batch diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index c5e07cca..2f4b4f7a 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -225,21 +225,35 @@ class HeterogeneousNextTokenChooser: scores = warper(input_ids, scores) - accepted_ids = [] next_ids = self.choice(scores) + from loguru import logger if speculated_ids is not None: - validate_speculative = next_ids[:-1] == speculated_ids[0] - index = 1 - for valid in validate_speculative.tolist(): - if valid: - index += 1 - # print(f"Validated {index - 1}") - next_ids = next_ids[:index] - scores = scores[:index] - speculative_scores = speculative_scores[index - 1:index] - accepted_ids.append(index) + logger.info(f"CHOOSER {next_ids} {speculated_ids}") + accepted_ids = [] + B = next_ids.shape[0] // (speculated_ids.shape[1] + 1) + S = speculated_ids.shape[1] + 1 + indices = [] + for i in range(B): + _next_ids = next_ids[i*S: (i + 1)*S] + _speculated_ids = speculated_ids[i] + validate_speculative = _next_ids[:-1] == _speculated_ids + index = i * S + accepted = 1 + # First is always valid + indices.append(index) + for valid in validate_speculative.tolist(): + if valid: + index += 1 + accepted += 1 + indices.append(index) + # print(f"Validated {accepted}") + accepted_ids.append(accepted) + accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) + next_ids = next_ids[indices] + scores = scores[indices] + speculative_scores = speculative_scores[accepted_ids.cumsum(dim=-1) - 1] else: - accepted_ids.append(1) + accepted_ids = torch.ones_like(next_ids) logprobs = torch.log_softmax(scores, -1)
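
Note on the acceptance rule in the tokens.py hunk above: for each sequence the first sampled token is always kept, and a speculated token is kept only while it matches what the model itself would have sampled at that position. Below is a minimal, self-contained sketch of that rule, not the repository's API: the function name, the toy tensors, and the explicit break at the first mismatch are assumptions for illustration (the diff's loop simply skips non-matching positions), and the real code also gathers scores and speculative_scores with the same indices.

import torch

def accept_speculated(next_ids: torch.Tensor, speculated_ids: torch.Tensor):
    # next_ids: (B * S,) tokens the model actually chose, with S = speculative_length + 1
    # speculated_ids: (B, S - 1) tokens proposed by the Medusa heads on the previous step
    # Returns flat indices of accepted positions and the accepted count per sequence.
    B, s_minus_1 = speculated_ids.shape
    S = s_minus_1 + 1
    indices, accepted_ids = [], []
    for i in range(B):
        _next_ids = next_ids[i * S : (i + 1) * S]
        valid = (_next_ids[:-1] == speculated_ids[i]).tolist()
        index = i * S
        accepted = 1
        indices.append(index)  # the first token is always accepted
        for ok in valid:
            if not ok:
                break  # assumption: stop at the first mismatch
            index += 1
            accepted += 1
            indices.append(index)
        accepted_ids.append(accepted)
    return torch.tensor(indices), torch.tensor(accepted_ids)

# Toy example: batch of 2, two speculated tokens per sequence.
next_ids = torch.tensor([5, 7, 9, 4, 8, 1])
speculated = torch.tensor([[7, 2], [8, 1]])
print(accept_speculated(next_ids, speculated))
# indices: 0, 1, 3, 4, 5 ; accepted per sequence: 2, 3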
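
Similarly, a shape-only sketch of the unrolling done at the end of generate_token (and previously inside flash_llama_modeling.py): when speculative ids are present, each sequence becomes speculative_length + 1 rows, with positions, slot indices and input lengths offset by 0..speculative_length so the flat attention kernels can score the speculated tokens in a single forward pass. All values below are illustrative toy numbers, not taken from the diff.

import torch

B, speculative_length = 2, 2
S = speculative_length + 1

input_ids = torch.tensor([11, 21])                    # last accepted token per sequence, (B,)
speculative_ids = torch.tensor([[12, 13], [22, 23]])  # (B, speculative_length)
position_ids = torch.tensor([7, 4])                   # next position per sequence, (B,)
slot_indices = torch.tensor([70, 40])
input_lengths = torch.tensor([8, 5])

arange = torch.arange(S).unsqueeze(0)                 # (1, S), broadcast over the batch

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
new_position_ids = (position_ids.view(-1, 1).expand(B, S) + arange).view(-1)
new_slot_indices = (slot_indices.view(-1, 1).expand(B, S) + arange).view(-1)
new_input_lengths = (input_lengths.view(-1, 1).expand(B, S) + arange).view(-1)

print(new_input_ids)      # 11, 12, 13, 21, 22, 23
print(new_position_ids)   # 7, 8, 9, 4, 5, 6
print(new_slot_indices)   # 70, 71, 72, 40, 41, 42
print(new_input_lengths)  # 8, 9, 10, 5, 6, 7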