diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
index c09564ac..1e476819 100644
--- a/backends/neuron/server/text_generation_server/generator.py
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -211,19 +211,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector
 
-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.
 
         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE
 
     def resume(self):
@@ -340,7 +332,12 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
@@ -412,14 +409,8 @@ class NeuronGenerator(Generator):
             logger.debug(
                 f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
             )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
@@ -445,12 +436,8 @@ class NeuronGenerator(Generator):
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -484,9 +471,6 @@ class NeuronGenerator(Generator):
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
            slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
@@ -530,12 +514,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
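
For context, a minimal illustrative sketch (not part of the patch) of the slot-selection behaviour the patched generator settles on: with continuous batching now mandatory, prefill and decode always pass explicit per-slot sequence ids instead of rebuilding the whole KV cache. The Slot class below is a hypothetical stand-in reduced to the single `id` attribute needed here; only torch is assumed.

    import torch

    class Slot:
        # Hypothetical stand-in for the real Slot class, kept only to show the
        # shape of the seq_ids tensors built in the patched code paths.
        def __init__(self, slot_id: int):
            self.id = slot_id

    def prefill_seq_ids(new_slots):
        # Only newly assigned slots are prefilled and their ids are passed
        # explicitly, so paused slots keep their KV cache untouched.
        return torch.tensor([slot.id for slot in new_slots])

    def decode_seq_ids(active_slots):
        # Decoding likewise targets only the currently active slots.
        return torch.tensor([slot.id for slot in active_slots])

    print(prefill_seq_ids([Slot(2), Slot(5)]))           # tensor([2, 5])
    print(decode_seq_ids([Slot(0), Slot(2), Slot(5)]))   # tensor([0, 2, 5])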