refactor(neuron): remove obsolete code paths

David Corvoysier 2025-05-23 08:27:27 +00:00
parent 11184f804a
commit e586f8bdd6


@@ -211,19 +211,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector

-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.

         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE

     def resume(self):
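For illustration, a stand-alone toy reduction of the slot state machine after this change (MiniSlot and its State enum are hypothetical stand-ins, not the server's classes): pause() is now pure state bookkeeping, since prefill no longer rewinds generated tokens or max_new_tokens.

from enum import Enum

class State(Enum):
    EMPTY = 0
    PAUSE = 1
    READY = 2

class MiniSlot:
    """Hypothetical reduction of Slot: pausing keeps the KV cache intact."""

    def __init__(self):
        self._state = State.READY

    def pause(self):
        # No reset_on_pause branch anymore: nothing to rewind, just flip state.
        self._state = State.PAUSE

    def resume(self):
        self._state = State.READY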
@ -340,7 +332,12 @@ class NeuronGenerator(Generator):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
): ):
self.model = model self.model = model
self.rebuild_cache_on_prefill = not self.model.continuous_batching if not isinstance(self.model, NeuronModelForCausalLM):
raise ValueError("The model must be a NeuronModelForCausalLM.")
if not model.neuron_config.continuous_batching:
raise ValueError(
"The neuron model must be compiled with continuous_batching=True."
)
# Specify padding and truncation options for decoder-only architecture # Specify padding and truncation options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
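For context, the constructor now fails fast instead of falling back to a cache-rebuilding path. A minimal sketch of obtaining a model that passes the new checks, using optimum-neuron's documented export API (the checkpoint and export arguments are illustrative placeholders; whether a given export enables continuous batching depends on the compilation configuration):

from optimum.neuron import NeuronModelForCausalLM

# Illustrative export arguments; adjust to the target model and hardware.
model = NeuronModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder checkpoint
    export=True,          # compile for Neuron on the fly
    batch_size=4,
    sequence_length=1024,
    num_cores=2,
    auto_cast_type="fp16",
)
# The generator now enforces this invariant explicitly at construction:
assert model.neuron_config.continuous_batching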
@@ -412,14 +409,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
@@ -445,12 +436,8 @@ class NeuronGenerator(Generator):
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -484,9 +471,6 @@ class NeuronGenerator(Generator):
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
@@ -530,12 +514,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
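Taken together with the prefill hunk above, both paths now derive seq_ids from the slots actually being served. A self-contained sketch of the resulting decode-side gathering (_Slot is a hypothetical stand-in; the real Slot exposes id and next_token, as seen in the code above):

import torch

class _Slot:
    """Hypothetical stand-in for the server's Slot."""
    def __init__(self, slot_id, next_token):
        self.id = slot_id
        self.next_token = next_token

decode_slots = [_Slot(0, 42), _Slot(2, 7)]  # only the active slots
seq_ids = torch.tensor([slot.id for slot in decode_slots])
# One token per active sequence for the next decode forward pass.
input_ids = torch.full((len(decode_slots), 1), 0, dtype=torch.int64)
for i, slot in enumerate(decode_slots):
    input_ids[i, 0] = slot.next_token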