From 1bccd1a3e49e0c8074b0aa03c010c208f5160eaa Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Fri, 23 May 2025 13:26:05 +0000
Subject: [PATCH] feat(neuron): support on-device sampling

---
 .../text_generation_server/generator.py      | 45 +++++++++++++++----
 1 file changed, 36 insertions(+), 9 deletions(-)

diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
index 9e9ed987..10a4d7a2 100644
--- a/backends/neuron/server/text_generation_server/generator.py
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -356,6 +356,10 @@ class NeuronGenerator(Generator):
         ]
         self.batch_id = 0
 
+    @property
+    def on_device_sampling(self) -> bool:
+        return getattr(self.model.neuron_config, "on_device_sampling", False)
+
     @property
     def info(self) -> InfoResponse:
         """Returns the expected InfoResponse."""
@@ -454,6 +458,9 @@ class NeuronGenerator(Generator):
         )
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
+        sampling_params = (
+            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+        )
         # Pause previously active slots during generation
         for slot in active_slots:
             slot.pause()
@@ -477,14 +484,21 @@ class NeuronGenerator(Generator):
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             slot_attention_mask = attention_mask[i]
             slot.reset(slot_input_ids, slot_attention_mask, selector)
+            if sampling_params is not None:
+                sampling_params[i, 0] = slot.generation_config.top_k
+                sampling_params[i, 1] = slot.generation_config.top_p
+                sampling_params[i, 2] = slot.generation_config.temperature
         # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(
-            input_ids, attention_mask=attention_mask, seq_ids=seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
         )
-        logits = self.model(**model_inputs)[0]
+        tokens_or_logits = self.model(**model_inputs)[0]
         generation, next_batch = self._generate_token(
-            prefill_slots, self.batch_id, logits, input_ids
+            prefill_slots, self.batch_id, tokens_or_logits, input_ids
         )
         self.batch_id += 1
         # Reactivate previously active slots for the next decode
@@ -544,22 +558,32 @@ class NeuronGenerator(Generator):
         for slot in decode_slots:
             max_length = max(max_length, slot.attention_mask.size(-1))
         attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+        sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
         for i, slot in enumerate(decode_slots):
             if slot.state != Slot.State.EMPTY:
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
                 attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         model_inputs = self.model.prepare_inputs_for_decode(
-            input_ids, attention_mask=attention_mask, seq_ids=seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
+        )
+        tokens_or_logits = self.model(**model_inputs)[0]
+        return self._generate_token(
+            decode_slots, next_batch_id, tokens_or_logits, input_ids
         )
-        logits = self.model(**model_inputs)[0]
-        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
 
     def _generate_token(
         self,
         slots: List[Slot],
         next_batch_id: int,
-        logits: torch.Tensor,
+        tokens_or_logits: torch.Tensor,
         input_ids: torch.LongTensor,
     ) -> Tuple[List[Generation], CachedBatch]:
         generations = []
@@ -568,9 +592,12 @@ class NeuronGenerator(Generator):
             if slot.state != Slot.State.READY:
                 continue
             request_id = slot.request_id
-            next_token_logits = logits[i : i + 1, -1, :]
             slot_input_ids = input_ids[i : i + 1, :]
-            next_token = slot.select(slot_input_ids, next_token_logits)
+            if self.on_device_sampling:
+                next_token = tokens_or_logits[i]
+            else:
+                next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+                next_token = slot.select(slot_input_ids, next_token_logits)
             next_token_text = slot.append(next_token)
             generated_text = None
             finish_reason = None
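Context note (not part of the patch): when neuron_config.on_device_sampling is enabled, the model call returns already-sampled token ids instead of logits, so the server only needs to ship a (batch_size, 3) tensor of per-slot settings laid out as [top_k, top_p, temperature]. Below is a minimal sketch of that packing as a standalone function; the name build_sampling_params is illustrative only and does not exist in the codebase.

    # Illustration only: packs per-slot sampling settings the same way the
    # prefill/decode hunks above fill sampling_params, one row per slot.
    from typing import Optional, Sequence

    import torch
    from transformers import GenerationConfig


    def build_sampling_params(
        configs: Sequence[GenerationConfig], on_device_sampling: bool
    ) -> Optional[torch.Tensor]:
        """Return a (n_slots, 3) tensor of [top_k, top_p, temperature], or None."""
        if not on_device_sampling:
            # Sampling stays on host: no sampling_params are passed and the model returns logits.
            return None
        params = torch.zeros(len(configs), 3)
        for i, cfg in enumerate(configs):
            params[i, 0] = cfg.top_k
            params[i, 1] = cfg.top_p
            params[i, 2] = cfg.temperature
        return params


    # Example: one near-greedy slot and one sampling slot.
    params = build_sampling_params(
        [
            GenerationConfig(top_k=1),
            GenerationConfig(top_k=50, top_p=0.9, temperature=0.7),
        ],
        on_device_sampling=True,
    )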