refactor(neuron): use named parameters in inputs helpers

This makes it possible to hide the differences between the two backends in
terms of input parameters.
This commit is contained in:
David Corvoysier 2025-05-22 14:53:25 +00:00
parent b094f026c1
commit 2eb223613e

View File

@ -474,7 +474,7 @@ class NeuronGenerator(Generator):
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(
input_ids, attention_mask, seq_ids
input_ids, attention_mask=attention_mask, seq_ids=seq_ids
)
logits = self.model(**model_inputs)[0]
generation, next_batch = self._generate_token(
@ -551,7 +551,7 @@ class NeuronGenerator(Generator):
input_ids[i, 0] = slot.next_token
attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
model_inputs = self.model.prepare_inputs_for_decode(
input_ids, attention_mask, seq_ids
input_ids, attention_mask=attention_mask, seq_ids=seq_ids
)
logits = self.model(**model_inputs)[0]
return self._generate_token(decode_slots, next_batch_id, logits, input_ids)