feat(neuron): support on-device sampling

This commit is contained in:
David Corvoysier 2025-05-23 13:26:05 +00:00
parent c8c5dcf352
commit 1bccd1a3e4

View File

@ -356,6 +356,10 @@ class NeuronGenerator(Generator):
]
self.batch_id = 0
@property
def on_device_sampling(self) -> bool:
    """Whether token sampling happens on the Neuron device instead of host-side.

    Reads the ``on_device_sampling`` flag from ``self.model.neuron_config``,
    defaulting to ``False`` when the attribute is absent (older configs).
    """
    return getattr(self.model.neuron_config, "on_device_sampling", False)
@property
def info(self) -> InfoResponse:
"""Returns the expected InfoResponse."""
@ -454,6 +458,9 @@ class NeuronGenerator(Generator):
)
input_ids = padded_inputs.input_ids
attention_mask = padded_inputs.attention_mask
sampling_params = (
torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
)
# Pause previously active slots during generation
for slot in active_slots:
slot.pause()
@ -477,14 +484,21 @@ class NeuronGenerator(Generator):
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(
input_ids, attention_mask=attention_mask, seq_ids=seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
logits = self.model(**model_inputs)[0]
tokens_or_logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(
prefill_slots, self.batch_id, logits, input_ids
prefill_slots, self.batch_id, tokens_or_logits, input_ids
)
self.batch_id += 1
# Reactivate previously active slots for the next decode
@ -544,22 +558,32 @@ class NeuronGenerator(Generator):
for slot in decode_slots:
max_length = max(max_length, slot.attention_mask.size(-1))
attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
for i, slot in enumerate(decode_slots):
if slot.state != Slot.State.EMPTY:
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
model_inputs = self.model.prepare_inputs_for_decode(
input_ids, attention_mask=attention_mask, seq_ids=seq_ids
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
tokens_or_logits = self.model(**model_inputs)[0]
return self._generate_token(
decode_slots, next_batch_id, tokens_or_logits, input_ids
)
logits = self.model(**model_inputs)[0]
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
def _generate_token(
self,
slots: List[Slot],
next_batch_id: int,
logits: torch.Tensor,
tokens_or_logits: torch.Tensor,
input_ids: torch.LongTensor,
) -> Tuple[List[Generation], CachedBatch]:
generations = []
@ -568,8 +592,11 @@ class NeuronGenerator(Generator):
if slot.state != Slot.State.READY:
continue
request_id = slot.request_id
next_token_logits = logits[i : i + 1, -1, :]
slot_input_ids = input_ids[i : i + 1, :]
if self.on_device_sampling:
next_token = tokens_or_logits[i]
else:
next_token_logits = tokens_or_logits[i : i + 1, -1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
next_token_text = slot.append(next_token)
generated_text = None