Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-08 10:22:05 +00:00
refactor(neuron): remove obsolete code paths
commit e586f8bdd6
parent 11184f804a
@@ -211,19 +211,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector
 
-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.
 
         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE
 
     def resume(self):
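With this refactor, pausing a slot is unconditional: the Neuron model keeps the slot's KV cache while it is paused, so there is no last-token rollback and no max_new_tokens recomputation left to do. Below is a minimal, self-contained sketch of the resulting behaviour; the State values and the stub attributes are assumptions for illustration, and only the pause/resume bodies mirror the diff.

from enum import Enum


class Slot:
    # Reduced stand-in for the generator slot; the real class tracks many more fields.
    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self):
        self._state = Slot.State.READY

    def pause(self):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.
        """
        self._state = Slot.State.PAUSE

    def resume(self):
        # Assumed counterpart: flip the slot back to a generating state.
        self._state = Slot.State.READY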
@@ -340,7 +332,12 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
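Instead of silently falling back to a cache-rebuilding mode, the constructor now rejects unsupported models up front. A rough, hedged sketch of that fail-fast validation follows; the optimum.neuron import path is an assumption based on how the class is usually obtained, and the helper name validate_neuron_model is hypothetical.

from optimum.neuron import NeuronModelForCausalLM  # assumed import path
from transformers import PreTrainedTokenizerBase


def validate_neuron_model(model, tokenizer: PreTrainedTokenizerBase) -> None:
    # Hypothetical helper mirroring the checks added to NeuronGenerator.__init__.
    if not isinstance(model, NeuronModelForCausalLM):
        raise ValueError("The model must be a NeuronModelForCausalLM.")
    if not model.neuron_config.continuous_batching:
        raise ValueError(
            "The neuron model must be compiled with continuous_batching=True."
        )
    # Decoder-only models have no pad token by default: reuse EOS and pad on the left.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"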
@@ -412,14 +409,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
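Prefill no longer has to re-run every slot: with continuous batching guaranteed, only the newly assigned slots are prefilled, and their slot ids are forwarded as seq_ids so the model fills the matching KV-cache rows. A small runnable sketch of that selection, using a bare-bones slot stub in place of the real Slot class:

from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class SlotStub:
    # Only the id matters for this sketch; the real Slot carries the request state.
    id: int


def select_prefill_targets(new_slots: List[SlotStub]) -> Tuple[List[SlotStub], torch.Tensor]:
    # Mirrors the simplified prefill path: new requests only, their ids as seq_ids.
    prefill_slots = new_slots
    seq_ids = torch.tensor([slot.id for slot in prefill_slots])
    return prefill_slots, seq_ids


# Example: two new requests were assigned to slots 1 and 3.
slots, seq_ids = select_prefill_targets([SlotStub(1), SlotStub(3)])
print(seq_ids)  # tensor([1, 3])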
@@ -445,12 +436,8 @@ class NeuronGenerator(Generator):
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -484,9 +471,6 @@ class NeuronGenerator(Generator):
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
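Because prefill never resets the already-active slots anymore, there is nothing to stash before pausing them and nothing to append back after resuming: the next_tokens bookkeeping disappears from both loops. A condensed sketch of the remaining flow, assuming active_slots is a list of Slot objects as in the diff:

def pause_active_slots(active_slots):
    # Pausing is now stateless: the Neuron KV cache for each slot is preserved.
    for slot in active_slots:
        slot.pause()


def resume_active_slots(active_slots):
    # Nothing was dropped on pause, so resuming is a plain state flip.
    for slot in active_slots:
        slot.resume()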
@@ -530,12 +514,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
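Decoding now follows the same pattern as prefill: the active slots are always the decode targets and their ids go into seq_ids, with no whole-cache fallback. The diff is truncated right after input_ids = torch.full(, so the tensor reconstruction below is only a hedged guess at the usual shape (one next token per decode slot); the fill value and token values are invented for illustration.

from typing import List

import torch


def select_decode_targets(active_slots: List) -> tuple:
    # Mirrors the simplified decode path: always the active slots, their ids as seq_ids.
    decode_slots = active_slots
    seq_ids = torch.tensor([slot.id for slot in decode_slots])
    return decode_slots, seq_ids


# Illustrative reconstruction of a (n_slots, 1) decode input; all values are made up.
n_slots = 2
pad_token_id = 0  # assumed padding id
next_tokens = [42, 7]  # assumed per-slot next tokens
input_ids = torch.full((n_slots, 1), fill_value=pad_token_id, dtype=torch.int64)
for i, token in enumerate(next_tokens):
    input_ids[i, 0] = token
print(input_ids)  # shape (2, 1), holding 42 and 7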