diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py index 10a4d7a2..bd8191ba 100644 --- a/backends/neuron/server/text_generation_server/generator.py +++ b/backends/neuron/server/text_generation_server/generator.py @@ -341,7 +341,10 @@ class NeuronGenerator(Generator): self.model = model if not isinstance(self.model, NeuronModelForCausalLM): raise ValueError("The model must be a NeuronModelForCausalLM.") - if not model.neuron_config.continuous_batching: + if ( + model.neuron_config.batch_size > 1 + and not model.neuron_config.continuous_batching + ): raise ValueError( "The neuron model must be compiled with continuous_batching=True." )