Mirror of https://github.com/huggingface/text-generation-inference.git
fix(neuron): use neuron_config whenever possible
parent e586f8bdd6
commit d4523f290a
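
The commit replaces reads of model-level attributes such as model.batch_size and model.max_length with the static shapes stored on the compiled model's neuron_config (batch_size, sequence_length, and, on some exports, max_context_length), and adds a max_prefill_length() helper. The sketch below is not part of the commit: it is a minimal stand-alone illustration of that access pattern, where NeuronConfig is a hypothetical stand-in for the config object attached to a compiled optimum-neuron model.

from dataclasses import dataclass
from typing import Optional


@dataclass
class NeuronConfig:
    # Hypothetical stand-in for the compiled model's neuron_config.
    batch_size: int
    sequence_length: int
    max_context_length: Optional[int] = None  # only present on some exports


def max_prefill_length(neuron_config: NeuronConfig) -> int:
    # Mirrors the new NeuronGenerator.max_prefill_length(): prefer the dedicated
    # context-length field when the export provides one, otherwise fall back to
    # the full static sequence length. (The diff uses hasattr() because the real
    # attribute may be absent entirely; the dataclass stand-in always defines it.)
    if neuron_config.max_context_length is not None:
        return neuron_config.max_context_length
    return neuron_config.sequence_length


config = NeuronConfig(batch_size=4, sequence_length=4096)
print(config.batch_size * config.sequence_length)  # warmup token budget: 16384
print(max_prefill_length(config))                  # 4096 (no max_context_length set)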
@@ -344,7 +344,9 @@ class NeuronGenerator(Generator):
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.slots = [
+            Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+        ]
         self.batch_id = 0

     @property
@@ -368,14 +370,22 @@ class NeuronGenerator(Generator):
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.batch_size
+        batch_size = self.model.neuron_config.batch_size
         if len(batch.requests) > batch_size:
             raise ValueError(
-                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
             )
         self.prefill(batch)
         self.clear()
-        return self.model.batch_size * self.model.max_length
+        return (
+            self.model.neuron_config.batch_size
+            * self.model.neuron_config.sequence_length
+        )
+
+    def max_prefill_length(self) -> int:
+        if hasattr(self.model.neuron_config, "max_context_length"):
+            return self.model.neuron_config.max_context_length
+        return self.model.neuron_config.sequence_length

     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -395,7 +405,7 @@ class NeuronGenerator(Generator):
         if len(empty_slots) < len(batch.requests):
             raise ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+                f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
             )
         # Assign each request to an empty slot
         logger.debug(
@@ -422,8 +432,10 @@ class NeuronGenerator(Generator):
             inputs.append(slot.cached_text)
             # Apply truncation, making sure we fit into static dimensions
             if slot.truncate == 0:
-                max_length = self.model.max_length
-            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+                max_length = self.max_prefill_length()
+            elif (
+                slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+            ):
                 max_length = slot.truncate
         # Tokenize with padding and truncation
         padded_inputs = self.tokenizer(
@@ -451,7 +463,7 @@ class NeuronGenerator(Generator):
                 slot_input_ids,
                 slot.generation_config,
                 self.model,
-                self.model.max_length,
+                self.model.neuron_config.sequence_length,
                 tokenizer=self.tokenizer,
                 seed=slot.seed,
             )
@@ -602,7 +614,7 @@ class NeuronGenerator(Generator):

     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.max_length
+        max_tokens = size * self.model.neuron_config.sequence_length
         return CachedBatch(
             id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
         )
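
The remaining hunks update the Neuron server tests to read the same static shapes from neuron_config. One clarification on the prefill hunk above: the truncation loop computes a single padding length for the whole batch. The sketch below is a stand-alone restatement of that loop, not part of the commit; truncates stands for the per-slot truncate values and prefill_limit for the value returned by max_prefill_length().

def padded_length(truncates: list, prefill_limit: int) -> int:
    # Restates the loop from the prefill hunk above: a truncate of 0 means
    # "no per-request truncation", which forces padding to the static limit;
    # otherwise the largest truncate value below that limit wins.
    max_length = 0
    for truncate in truncates:
        if truncate == 0:
            max_length = prefill_limit
        elif truncate > max_length and truncate < prefill_limit:
            max_length = truncate
    return max_length


print(padded_length([0, 128, 256], prefill_limit=4096))   # 4096 (one slot is untruncated)
print(padded_length([128, 256, 64], prefill_limit=4096))  # 256
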
@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
     """
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    assert generator.model.batch_size > 1
+    assert generator.model.neuron_config.batch_size > 1
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
     tokens = {0: [], 1: []}
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     assert next_batch.size == 1
@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
     request = create_request(
         id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
     )
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
@@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
     max_batch_size = 4
-    assert generator.model.batch_size >= max_batch_size
+    assert generator.model.neuron_config.batch_size >= max_batch_size
     for num_requests in [1, max_batch_size]:
         for do_sample in [True, False]:
             mode = "sample" if do_sample else "greedy"
@@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
             )
         )
     # Let's be pessimistic when estimating max_tokens
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -70,7 +70,7 @@ def test_prefill_truncate(neuron_model_config):
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    batch_size = generator.model.batch_size
+    batch_size = generator.model.neuron_config.batch_size
     # We apply truncation to all requests but the first one
     truncate = [
         None,
@@ -83,7 +83,7 @@ def test_prefill_truncate(neuron_model_config):
     requests = []
     for i in range(batch_size):
         requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
|