diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py index b3887e14..c09564ac 100644 --- a/backends/neuron/server/text_generation_server/generator.py +++ b/backends/neuron/server/text_generation_server/generator.py @@ -474,7 +474,7 @@ class NeuronGenerator(Generator): # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, # as they have already been generated and sent back in the last decode. model_inputs = self.model.prepare_inputs_for_prefill( - input_ids, attention_mask, seq_ids + input_ids, attention_mask=attention_mask, seq_ids=seq_ids ) logits = self.model(**model_inputs)[0] generation, next_batch = self._generate_token( @@ -551,7 +551,7 @@ class NeuronGenerator(Generator): input_ids[i, 0] = slot.next_token attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask model_inputs = self.model.prepare_inputs_for_decode( - input_ids, attention_mask, seq_ids + input_ids, attention_mask=attention_mask, seq_ids=seq_ids ) logits = self.model(**model_inputs)[0] return self._generate_token(decode_slots, next_batch_id, logits, input_ids)