refactor(neuron): remove obsolete code paths

David Corvoysier 2025-05-23 08:27:27 +00:00
parent 11184f804a
commit e586f8bdd6


@@ -211,19 +211,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector

-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.

         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE

     def resume(self):
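With the static-batching path removed, pausing a slot no longer rewinds `_generated_tokens` or recomputes `max_new_tokens`: the KV cache survives the pause, so only the state flag changes. A minimal sketch of the remaining state machine, using a simplified stand-in `Slot` rather than the server's real class:

```python
from enum import Enum


class Slot:
    """Simplified stand-in keeping only the pause/resume state machine."""

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self):
        self._state = Slot.State.EMPTY

    def pause(self):
        # The slot's KV cache is preserved across the pause, so no token
        # bookkeeping is required: just flip the state.
        self._state = Slot.State.PAUSE

    def resume(self):
        self._state = Slot.State.READY


slot = Slot()
slot.pause()
assert slot._state is Slot.State.PAUSE
slot.resume()
assert slot._state is Slot.State.READY
```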
@@ -340,7 +332,12 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
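The constructor now fails fast on unsupported models instead of silently falling back to rebuilding the cache on every prefill. A sketch of the same guard as a standalone helper, assuming `NeuronModelForCausalLM` is importable from `optimum.neuron`:

```python
from optimum.neuron import NeuronModelForCausalLM


def validate_model(model) -> None:
    """Mirror the fail-fast checks added to NeuronGenerator.__init__."""
    if not isinstance(model, NeuronModelForCausalLM):
        raise ValueError("The model must be a NeuronModelForCausalLM.")
    if not model.neuron_config.continuous_batching:
        raise ValueError(
            "The neuron model must be compiled with continuous_batching=True."
        )
```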
@@ -412,14 +409,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
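Prefill now always targets only the newly assigned slots and passes their ids as `seq_ids`, so the model leaves every other slot's KV cache untouched. A small sketch of the id-tensor construction, with hypothetical stand-in slots:

```python
import torch


class _Slot:
    """Hypothetical stand-in exposing only the slot id."""

    def __init__(self, slot_id: int):
        self.id = slot_id


# Two newly assigned slots, ids 2 and 5; only these are prefilled.
prefill_slots = [_Slot(2), _Slot(5)]
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
print(seq_ids)  # tensor([2, 5])
```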
@@ -445,12 +436,8 @@ class NeuronGenerator(Generator):
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -484,9 +471,6 @@ class NeuronGenerator(Generator):
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
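The last two hunks remove the two halves of the same bookkeeping: storing each paused slot's next token before prefill, then appending it back on resume. With continuous batching the pending token survives the pause, so the cycle collapses to plain `pause()`/`resume()` calls. A hedged sketch of the simplified flow, with `model_forward` and the slot objects as assumed placeholders:

```python
def run_prefill(model_forward, active_slots, prefill_inputs):
    # Pause previously active slots; their KV cache and pending token
    # are kept, so nothing needs to be saved or restored around the call.
    for slot in active_slots:
        slot.pause()
    outputs = model_forward(**prefill_inputs)  # fills only the new slots
    for slot in active_slots:
        slot.resume()
    return outputs
```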
@@ -530,12 +514,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
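The decode path gets the same treatment: it always decodes the active slots with explicit `seq_ids`, and the fallback that decoded every slot with `seq_ids=None` is gone. The `torch.full(` call is truncated in the hunk above; a sketch of how such a per-slot reconstruction could look, with made-up pad and token values:

```python
import torch

pad_token_id = 0           # made-up tokenizer pad id
next_tokens = [17, 42, 9]  # made-up pending token per active slot

n_slots = len(next_tokens)
# One single-token column per decode slot, pre-filled with the pad id
# and then overwritten with each slot's pending token.
input_ids = torch.full((n_slots, 1), pad_token_id, dtype=torch.int64)
for i, token in enumerate(next_tokens):
    input_ids[i, 0] = token
print(input_ids.squeeze(1))  # tensor([17, 42,  9])
```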