fix(neuron): use neuron_config whenever possible

David Corvoysier 2025-05-23 08:33:12 +00:00
parent e586f8bdd6
commit d4523f290a
4 changed files with 28 additions and 16 deletions
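In short: static model dimensions (batch size, sequence length) are now read from the compiled model's neuron_config instead of attributes duplicated on the model wrapper (model.batch_size, model.max_length). A minimal sketch of the access pattern this commit converges on, using a hypothetical stub in place of a real compiled model:

# Illustrative stub only: "model" stands in for the generator's compiled Neuron model.
from types import SimpleNamespace

model = SimpleNamespace(neuron_config=SimpleNamespace(batch_size=4, sequence_length=2048))

# Previously read as model.batch_size and model.max_length; now read via neuron_config:
batch_size = model.neuron_config.batch_size
max_tokens = model.neuron_config.batch_size * model.neuron_config.sequence_length
assert (batch_size, max_tokens) == (4, 8192)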

View File

@@ -344,7 +344,9 @@ class NeuronGenerator(Generator):
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.slots = [
+            Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+        ]
         self.batch_id = 0

     @property
@@ -368,14 +370,22 @@ class NeuronGenerator(Generator):
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.batch_size
+        batch_size = self.model.neuron_config.batch_size
         if len(batch.requests) > batch_size:
             raise ValueError(
-                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
             )
         self.prefill(batch)
         self.clear()
-        return self.model.batch_size * self.model.max_length
+        return (
+            self.model.neuron_config.batch_size
+            * self.model.neuron_config.sequence_length
+        )
+
+    def max_prefill_length(self) -> int:
+        if hasattr(self.model.neuron_config, "max_context_length"):
+            return self.model.neuron_config.max_context_length
+        return self.model.neuron_config.sequence_length

     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -395,7 +405,7 @@ class NeuronGenerator(Generator):
         if len(empty_slots) < len(batch.requests):
             raise ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+                f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
             )
         # Assign each request to an empty slot
         logger.debug(
@@ -422,8 +432,10 @@ class NeuronGenerator(Generator):
             inputs.append(slot.cached_text)
             # Apply truncation, making sure we fit into static dimensions
             if slot.truncate == 0:
-                max_length = self.model.max_length
-            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+                max_length = self.max_prefill_length()
+            elif (
+                slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+            ):
                 max_length = slot.truncate
         # Tokenize with padding and truncation
         padded_inputs = self.tokenizer(
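For reference, the branch above picks the padding length for the prefill batch: a request without a truncate value forces the full prefill length, while explicit truncate values only raise the running maximum when they stay below it. A minimal distillation of that selection, with select_max_length and its arguments as hypothetical names:

def select_max_length(truncate_values, prefill_length):
    # Hypothetical distillation of the loop above: truncate == 0 means
    # "no truncation requested", so the full static prefill length is used.
    max_length = 0
    for truncate in truncate_values:
        if truncate == 0:
            max_length = prefill_length
        elif truncate > max_length and truncate < prefill_length:
            max_length = truncate
    return max_length

# Example: one untruncated request forces padding to the full prefill length.
assert select_max_length([128, 0, 256], prefill_length=512) == 512
assert select_max_length([128, 256], prefill_length=512) == 256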
@@ -451,7 +463,7 @@ class NeuronGenerator(Generator):
                 slot_input_ids,
                 slot.generation_config,
                 self.model,
-                self.model.max_length,
+                self.model.neuron_config.sequence_length,
                 tokenizer=self.tokenizer,
                 seed=slot.seed,
             )
@@ -602,7 +614,7 @@ class NeuronGenerator(Generator):
     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.max_length
+        max_tokens = size * self.model.neuron_config.sequence_length
         return CachedBatch(
             id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
         )
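The cached batch budget stays worst-case after this change: each request is assumed to consume the full static sequence length from neuron_config. A quick worked example with arbitrary illustrative numbers:

# Arbitrary numbers, not taken from a real compiled model.
size = 3                 # requests kept in the cached batch
sequence_length = 2048   # static sequence length from neuron_config
max_tokens = size * sequence_length
assert max_tokens == 6144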

View File

@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
     """
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    assert generator.model.batch_size > 1
+    assert generator.model.neuron_config.batch_size > 1
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
     tokens = {0: [], 1: []}
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     assert next_batch.size == 1

View File

@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
     request = create_request(
         id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
     )
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times

View File

@@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
     max_batch_size = 4
-    assert generator.model.batch_size >= max_batch_size
+    assert generator.model.neuron_config.batch_size >= max_batch_size
     for num_requests in [1, max_batch_size]:
         for do_sample in [True, False]:
             mode = "sample" if do_sample else "greedy"
@@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
             )
         )
     # Let's be pessimistic when estimating max_tokens
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -70,7 +70,7 @@ def test_prefill_truncate(neuron_model_config):
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    batch_size = generator.model.batch_size
+    batch_size = generator.model.neuron_config.batch_size
     # We apply truncation to all requests but the first one
     truncate = [
         None,
@@ -83,7 +83,7 @@ def test_prefill_truncate(neuron_model_config):
     requests = []
    for i in range(batch_size):
         requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )