fix(neuron): use neuron_config whenever possible

David Corvoysier 2025-05-23 08:33:12 +00:00
parent e586f8bdd6
commit d4523f290a
4 changed files with 28 additions and 16 deletions
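The change is mechanical: every place that read a static dimension from the model now reads it from model.neuron_config. As a rough sketch of the resulting access pattern (illustrative only, not code from this commit; the helper name is made up, the attribute names come from the diff below):

    def static_shapes(model):
        # Static dimensions now live on the neuron config rather than on the model.
        config = model.neuron_config
        # max_context_length is optional on some configs, hence the fallback.
        max_prefill = getattr(config, "max_context_length", config.sequence_length)
        return config.batch_size, config.sequence_length, max_prefill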

View File

@@ -344,7 +344,9 @@ class NeuronGenerator(Generator):
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
-self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+self.slots = [
+    Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+]
self.batch_id = 0
@property
@@ -368,14 +370,22 @@ class NeuronGenerator(Generator):
The maximum number of tokens the model supports.
"""
# Just check that the warmup request parameters match the model capacity
-batch_size = self.model.batch_size
+batch_size = self.model.neuron_config.batch_size
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
)
self.prefill(batch)
self.clear()
-return self.model.batch_size * self.model.max_length
+return (
+    self.model.neuron_config.batch_size
+    * self.model.neuron_config.sequence_length
+)
+
+def max_prefill_length(self) -> int:
+    if hasattr(self.model.neuron_config, "max_context_length"):
+        return self.model.neuron_config.max_context_length
+    return self.model.neuron_config.sequence_length
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
@@ -395,7 +405,7 @@ class NeuronGenerator(Generator):
if len(empty_slots) < len(batch.requests):
raise ValueError(
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
)
# Assign each request to an empty slot
logger.debug(
@@ -422,8 +432,10 @@ class NeuronGenerator(Generator):
inputs.append(slot.cached_text)
# Apply truncation, making sure we fit into static dimensions
if slot.truncate == 0:
-    max_length = self.model.max_length
-elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+    max_length = self.max_prefill_length()
+elif (
+    slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+):
max_length = slot.truncate
# Tokenize with padding and truncation
padded_inputs = self.tokenizer(
@@ -451,7 +463,7 @@ class NeuronGenerator(Generator):
slot_input_ids,
slot.generation_config,
self.model,
-self.model.max_length,
+self.model.neuron_config.sequence_length,
tokenizer=self.tokenizer,
seed=slot.seed,
)
@@ -602,7 +614,7 @@ class NeuronGenerator(Generator):
def _cached_batch(self, batch_id: int, request_ids: List):
size = len(request_ids)
-max_tokens = size * self.model.max_length
+max_tokens = size * self.model.neuron_config.sequence_length
return CachedBatch(
id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
)
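The new max_prefill_length() helper caps prompt-side truncation at neuron_config.max_context_length when the compiled model exposes it, and falls back to the full sequence_length otherwise, while warmup() still reports capacity as batch_size * sequence_length. A hedged usage sketch (the generator object is assumed to be a NeuronGenerator as above; not code from this commit):

    # Hypothetical usage, not part of the diff.
    prefill_limit = generator.max_prefill_length()  # max_context_length if present, else sequence_length
    warmup_tokens = (
        generator.model.neuron_config.batch_size
        * generator.model.neuron_config.sequence_length
    )  # the token budget warmup() now returns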

View File

@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
"""
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
-assert generator.model.batch_size > 1
+assert generator.model.neuron_config.batch_size > 1
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token
tokens = {0: [], 1: []}
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-max_length = generator.model.max_length
+max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
assert next_batch.size == 1

View File

@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
request = create_request(
id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
)
-max_length = generator.model.max_length
+max_length = generator.model.neuron_config.sequence_length
batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times

View File

@@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
max_batch_size = 4
-assert generator.model.batch_size >= max_batch_size
+assert generator.model.neuron_config.batch_size >= max_batch_size
for num_requests in [1, max_batch_size]:
for do_sample in [True, False]:
mode = "sample" if do_sample else "greedy"
@@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
)
)
# Let's be pessimistic when estimating max_tokens
-max_length = generator.model.max_length
+max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
@@ -70,7 +70,7 @@ def test_prefill_truncate(neuron_model_config):
config_name = neuron_model_config["name"]
neuron_model_path = neuron_model_config["neuron_model_path"]
generator = NeuronGenerator.from_pretrained(neuron_model_path)
-batch_size = generator.model.batch_size
+batch_size = generator.model.neuron_config.batch_size
# We apply truncation to all requests but the first one
truncate = [
None,
@@ -83,7 +83,7 @@ def test_prefill_truncate(neuron_model_config):
requests = []
for i in range(batch_size):
requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-max_length = generator.model.max_length
+max_length = generator.max_prefill_length()
batch = Batch(
id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
)
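The updated tests only exercise max_prefill_length() indirectly, through the pessimistic max_tokens estimate. A hypothetical additional test (not part of this commit) could check the fallback directly, reusing the fixtures and imports of the test files above:

    def test_max_prefill_length(neuron_model_config):
        # Hypothetical test sketch; follows the conventions of the tests above.
        neuron_model_path = neuron_model_config["neuron_model_path"]
        generator = NeuronGenerator.from_pretrained(neuron_model_path)
        neuron_config = generator.model.neuron_config
        # Mirror the fallback implemented in NeuronGenerator.max_prefill_length().
        expected = getattr(
            neuron_config, "max_context_length", neuron_config.sequence_length
        )
        assert generator.max_prefill_length() == expected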