Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-09 02:42:05 +00:00)
refactor(neuron): remove obsolete code paths
parent 11184f804a
commit e586f8bdd6
@@ -211,19 +211,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector

-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.

         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE

     def resume(self):
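With the non-continuous-batching path gone, `pause()` no longer rewinds generated tokens or recomputes `max_new_tokens`; it only flips the slot state while the KV cache stays filled. A minimal sketch of the resulting transitions, using a stand-in class (the `State` values and the body of `resume()` here are assumptions for illustration, not the library code):

```python
from enum import Enum


class SlotSketch:
    """Stand-in for the library's Slot, reduced to its state machine."""

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self):
        self._state = SlotSketch.State.EMPTY

    def pause(self):
        # After the refactor: a pure state transition, the KV cache stays filled.
        self._state = SlotSketch.State.PAUSE

    def resume(self):
        # Assumed counterpart: mark the slot ready for the next decode.
        self._state = SlotSketch.State.READY


slot = SlotSketch()
slot.pause()
assert slot._state is SlotSketch.State.PAUSE
slot.resume()
assert slot._state is SlotSketch.State.READY
```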
@@ -340,7 +332,12 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
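Instead of switching behaviour based on `continuous_batching`, the constructor now fails fast: the model must be a `NeuronModelForCausalLM` compiled with `continuous_batching=True`. A hedged sketch of those two guards using stand-in dataclasses (the `Fake*` names are illustrative; the real check targets optimum-neuron's `NeuronModelForCausalLM`):

```python
from dataclasses import dataclass


@dataclass
class FakeNeuronConfig:
    continuous_batching: bool


@dataclass
class FakeNeuronModelForCausalLM:
    neuron_config: FakeNeuronConfig


def check_model(model: object) -> None:
    # Mirrors the two ValueError guards added in this hunk.
    if not isinstance(model, FakeNeuronModelForCausalLM):
        raise ValueError("The model must be a NeuronModelForCausalLM.")
    if not model.neuron_config.continuous_batching:
        raise ValueError(
            "The neuron model must be compiled with continuous_batching=True."
        )


# Passes the guards; a model without continuous batching would raise instead.
check_model(FakeNeuronModelForCausalLM(FakeNeuronConfig(continuous_batching=True)))
```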
@@ -412,14 +409,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
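Prefill now has a single path: only the newly assigned slots are prefilled, and their ids are always passed to the model as `seq_ids`. A small sketch of that selection, with a stand-in slot class (`_Slot` is illustrative only):

```python
import torch


class _Slot:
    def __init__(self, slot_id: int):
        self.id = slot_id


# The only remaining branch: prefill the new slots and pass their ids explicitly.
new_slots = [_Slot(2), _Slot(5)]
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
print(seq_ids)  # tensor([2, 5])
```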
@@ -445,12 +436,8 @@ class NeuronGenerator(Generator):
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -484,9 +471,6 @@ class NeuronGenerator(Generator):
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
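Together with the previous hunk, the pause/resume bracket around a prefill becomes a pure state toggle: there is no `next_tokens` buffer to fill on pause or to append back on resume. A sketch of the simplified flow with a stand-in slot (the `paused` flag is illustrative only, not the library's representation):

```python
class _ActiveSlot:
    def __init__(self):
        self.paused = False

    def pause(self):
        self.paused = True

    def resume(self):
        self.paused = False


active_slots = [_ActiveSlot(), _ActiveSlot()]
for slot in active_slots:
    slot.pause()   # previously: slot.pause(reset_on_pause=...) plus next-token capture
# ... prefill the new slots here ...
for slot in active_slots:
    slot.resume()  # previously followed by slot.append(next_tokens[i])
assert not any(slot.paused for slot in active_slots)
```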
@@ -530,12 +514,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
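Decode likewise keeps only the continuous-batching branch: it always operates on the active slots and always passes their ids. A sketch of the simplified setup, including a one-token-per-slot `input_ids` rebuild hinted at by the trailing `torch.full(` context line (the shape and fill value below are assumptions, not the exact library values):

```python
import torch


class _Slot:
    def __init__(self, slot_id: int):
        self.id = slot_id


active_slots = [_Slot(0), _Slot(3)]
decode_slots = active_slots                 # the only remaining branch
seq_ids = torch.tensor([slot.id for slot in decode_slots])
n_slots = len(decode_slots)
# One column per decoding slot; fill value is a placeholder, not the library's choice.
input_ids = torch.full((n_slots, 1), fill_value=0, dtype=torch.int64)
print(seq_ids.tolist(), input_ids.shape)    # [0, 3] torch.Size([2, 1])
```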