Bump neuron SDK version (#3260)

* chore(neuron): bump version to 0.2.0

* refactor(neuron): use named parameters in inputs helpers

This allows us to hide the differences between the two backends in terms of
input parameters.
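
A minimal, self-contained sketch of the calling convention this enables. The helper below is a stand-in, not the actual optimum-neuron API; the keyword names mirror the prepare_inputs_for_prefill/prepare_inputs_for_decode calls in the diff below.

    from typing import Optional

    import torch

    def prepare_inputs_for_prefill(
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        seq_ids: Optional[torch.Tensor] = None,
        sampling_params: Optional[torch.Tensor] = None,
    ) -> dict:
        # Stand-in helper: every optional input is passed by keyword, so a backend
        # that has no use for e.g. sampling_params simply receives None and drops it.
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if seq_ids is not None:
            inputs["seq_ids"] = seq_ids
        if sampling_params is not None:
            inputs["sampling_params"] = sampling_params
        return inputs

    ids = torch.ones(2, 8, dtype=torch.int64)
    print(prepare_inputs_for_prefill(ids, attention_mask=torch.ones_like(ids)).keys())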

* refactor(neuron): remove obsolete code paths

* fix(neuron): use neuron_config whenever possible

* fix(neuron): use new cache import path

* fix(neuron): neuron config is not stored in config anymore

* fix(nxd): adapt model retrieval to new APIs

* fix(generator): emulate greedy in sampling parameters

When on-device sampling is enabled, we need to emulate greedy decoding by
setting top_k=1, top_p=1.0 and temperature=1.0 in the sampling parameters.
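
As a rough, self-contained sketch: the three-column layout below matches the sampling_params tensor built in the generator diff further down (one row per slot, columns [top_k, top_p, temperature]). With top_k=1, on-device sampling always picks the argmax token, which is exactly greedy decoding.

    import torch

    def greedy_sampling_params(batch_size: int) -> torch.Tensor:
        # One row per slot, columns: [top_k, top_p, temperature].
        # top_k=1 keeps only the most likely token, making the sampling deterministic;
        # top_p=1.0 and temperature=1.0 leave the logits untouched.
        params = torch.zeros(batch_size, 3)
        params[:, 0] = 1.0  # top_k = 1
        params[:, 1] = 1.0  # top_p = 1.0
        params[:, 2] = 1.0  # temperature = 1.0
        return params

    print(greedy_sampling_params(batch_size=4))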

* test(neuron): update models and expectations

* feat(neuron): support on-device sampling

* fix(neuron): adapt entrypoint

* tests(neuron): remove obsolete models

* fix(neuron): adjust test expectations for llama on nxd

Author: David Corvoysier, 2025-06-10 17:56:25 +02:00 (committed by GitHub)
Parent: 1ff9d185d5
Commit: 79183d1647
14 changed files with 393 additions and 264 deletions

View File

@@ -5,7 +5,7 @@ RUN mkdir -p /tgi
 # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
 FROM alpine AS optimum-neuron
 RUN mkdir -p /optimum-neuron
-ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz
 RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
 
 # Build cargo components (adapted from TGI original Dockerfile)
@@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
 # Install neuronx packages
 RUN apt-get update -y \
  && apt-get install -y --no-install-recommends \
-    aws-neuronx-dkms=2.19.64.0 \
-    aws-neuronx-collectives=2.23.135.0-3e70920f2 \
-    aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
-    aws-neuronx-tools=2.20.204.0 \
+    aws-neuronx-dkms=2.20.28.0 \
+    aws-neuronx-collectives=2.24.59.0-838c7fc8b \
+    aws-neuronx-runtime-lib=2.24.53.0-f239092cc \
+    aws-neuronx-tools=2.22.61.0 \
     libxml2 \
  && rm -rf /var/lib/apt/lists/* \
  && apt-get clean
@@ -125,11 +125,10 @@ RUN pip3 install \
     --index-url https://download.pytorch.org/whl/cpu
 RUN pip3 install \
-    neuronx-cc==2.16.372.0 \
-    torch-neuronx==2.5.1.2.4.0 \
-    transformers-neuronx==0.13.322 \
-    neuronx-distributed==0.10.1 \
-    libneuronxla==2.1.681.0 \
+    neuronx-cc==2.17.194.0 \
+    torch-neuronx==2.5.1.2.6.0 \
+    neuronx-distributed==0.11.0 \
+    libneuronxla==2.2.1630.0 \
     --extra-index-url=https://pip.repos.neuron.amazonaws.com
 
 # Install HuggingFace packages
@@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz
 # Final image
 FROM neuron
 
-COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
 COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

View File

@@ -7,7 +7,8 @@ from typing import List, Optional, Tuple
 import torch
 from loguru import logger
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from optimum.neuron.configuration_utils import NeuronConfig
 from transformers.generation import GenerationConfig
 
 from optimum.neuron import NeuronModelForCausalLM
@@ -175,6 +176,12 @@ class Slot:
                 self._generation_config.top_p = request.parameters.top_p
             if request.parameters.typical_p != 0:
                 self._generation_config.typical_p = request.parameters.typical_p
+        else:
+            # Set the sampling parameters to emulate greedy decoding when using on-device sampling
+            self._generation_config.temperature = 1.0
+            self._generation_config.top_k = 1
+            self._generation_config.top_p = 1.0
+            self._generation_config.typical_p = 1.0
         if request.parameters.repetition_penalty != 0:
             self._generation_config.repetition_penalty = (
                 request.parameters.repetition_penalty
@@ -211,19 +218,11 @@ class Slot:
         self._mask = attention_mask.clone()
         self._selector = selector
 
-    def pause(self, reset_on_pause: bool):
+    def pause(self):
         """Mark the current slot as paused for generation.
 
         Note that the KV cache for this slot will still be filled.
         """
-        if reset_on_pause:
-            # Drop the last token as it will be added back when resuming the slot
-            self._generated_tokens -= 1
-            # Since generated tokens are now part of the prefill, we need to reevaluate
-            # max_new_tokens for the next generation
-            self._generation_config.max_new_tokens = (
-                self._max_new_tokens - self._generated_tokens
-            )
         self._state = Slot.State.PAUSE
 
     def resume(self):
@@ -340,16 +339,27 @@ class NeuronGenerator(Generator):
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
-        self.rebuild_cache_on_prefill = not self.model.continuous_batching
+        if not isinstance(self.model, NeuronModelForCausalLM):
+            raise ValueError("The model must be a NeuronModelForCausalLM.")
+        if not model.neuron_config.continuous_batching:
+            raise ValueError(
+                "The neuron model must be compiled with continuous_batching=True."
+            )
         # Specify padding and truncation options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+        self.slots = [
+            Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)
+        ]
         self.batch_id = 0
 
+    @property
+    def on_device_sampling(self) -> bool:
+        return getattr(self.model.neuron_config, "on_device_sampling", False)
+
     @property
     def info(self) -> InfoResponse:
         """Returns the expected InfoResponse."""
@@ -371,14 +381,22 @@ class NeuronGenerator(Generator):
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.batch_size
+        batch_size = self.model.neuron_config.batch_size
        if len(batch.requests) > batch_size:
            raise ValueError(
-                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
            )
        self.prefill(batch)
        self.clear()
-        return self.model.batch_size * self.model.max_length
+        return (
+            self.model.neuron_config.batch_size
+            * self.model.neuron_config.sequence_length
+        )
+
+    def max_prefill_length(self) -> int:
+        if hasattr(self.model.neuron_config, "max_context_length"):
+            return self.model.neuron_config.max_context_length
+        return self.model.neuron_config.sequence_length
 
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         """Prefill new requests.
@@ -398,7 +416,7 @@ class NeuronGenerator(Generator):
         if len(empty_slots) < len(batch.requests):
             raise ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+                f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
             )
         # Assign each request to an empty slot
         logger.debug(
@@ -412,14 +430,8 @@ class NeuronGenerator(Generator):
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
            )
-        if self.rebuild_cache_on_prefill:
-            # We will clear pending slots and prefill all slots
-            prefill_slots = self.slots
-            seq_ids = None
-        else:
-            # We only need to pass inputs for the new requests
-            prefill_slots = new_slots
-            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+        prefill_slots = new_slots
+        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
         # - the inputs for new requests,
@@ -431,8 +443,10 @@ class NeuronGenerator(Generator):
             inputs.append(slot.cached_text)
             # Apply truncation, making sure we fit into static dimensions
             if slot.truncate == 0:
-                max_length = self.model.max_length
-            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+                max_length = self.max_prefill_length()
+            elif (
+                slot.truncate > max_length and slot.truncate < self.max_prefill_length()
+            ):
                 max_length = slot.truncate
         # Tokenize with padding and truncation
         padded_inputs = self.tokenizer(
@@ -444,13 +458,12 @@ class NeuronGenerator(Generator):
         )
         input_ids = padded_inputs.input_ids
         attention_mask = padded_inputs.attention_mask
+        sampling_params = (
+            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
+        )
         # Pause previously active slots during generation
-        next_tokens = []
         for slot in active_slots:
-            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
-            if self.rebuild_cache_on_prefill:
-                # The slot will be reset, so we need to store its next token
-                next_tokens.append(slot.next_token)
+            slot.pause()
         # Each slot must be reset with the padded inputs and masks
         for i, slot in enumerate(prefill_slots):
             if slot.state != slot.state.EMPTY:
@@ -464,29 +477,33 @@ class NeuronGenerator(Generator):
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
-                    self.model.max_length,
+                    self.model.neuron_config.sequence_length,
                    tokenizer=self.tokenizer,
                    seed=slot.seed,
                )
                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                slot_attention_mask = attention_mask[i]
                slot.reset(slot_input_ids, slot_attention_mask, selector)
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
         # as they have already been generated and sent back in the last decode.
         model_inputs = self.model.prepare_inputs_for_prefill(
-            input_ids, attention_mask, seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
         )
-        logits = self.model(**model_inputs)[0]
+        tokens_or_logits = self.model(**model_inputs)[0]
         generation, next_batch = self._generate_token(
-            prefill_slots, self.batch_id, logits, input_ids
+            prefill_slots, self.batch_id, tokens_or_logits, input_ids
         )
         self.batch_id += 1
         # Reactivate previously active slots for the next decode
         for i, slot in enumerate(active_slots):
             slot.resume()
-            if self.rebuild_cache_on_prefill:
-                # Append back the next token
-                slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
         if next_batch is not None:
             logger.debug(
@@ -530,12 +547,8 @@ class NeuronGenerator(Generator):
             raise ValueError(
                 "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
             )
-        if self.model.continuous_batching:
-            decode_slots = active_slots
-            seq_ids = torch.tensor([slot.id for slot in decode_slots])
-        else:
-            decode_slots = self.slots
-            seq_ids = None
+        decode_slots = active_slots
+        seq_ids = torch.tensor([slot.id for slot in decode_slots])
         # Reconstruct input_ids and attention_mask from decode slots
         n_slots = len(decode_slots)
         input_ids = torch.full(
@@ -545,22 +558,32 @@ class NeuronGenerator(Generator):
         for slot in decode_slots:
             max_length = max(max_length, slot.attention_mask.size(-1))
         attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+        sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
         for i, slot in enumerate(decode_slots):
             if slot.state != Slot.State.EMPTY:
                 # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                 input_ids[i, 0] = slot.next_token
                 attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+                if sampling_params is not None:
+                    sampling_params[i, 0] = slot.generation_config.top_k
+                    sampling_params[i, 1] = slot.generation_config.top_p
+                    sampling_params[i, 2] = slot.generation_config.temperature
         model_inputs = self.model.prepare_inputs_for_decode(
-            input_ids, attention_mask, seq_ids
+            input_ids,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
        )
-        logits = self.model(**model_inputs)[0]
-        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
+        tokens_or_logits = self.model(**model_inputs)[0]
+        return self._generate_token(
+            decode_slots, next_batch_id, tokens_or_logits, input_ids
+        )
 
     def _generate_token(
         self,
         slots: List[Slot],
         next_batch_id: int,
-        logits: torch.Tensor,
+        tokens_or_logits: torch.Tensor,
         input_ids: torch.LongTensor,
     ) -> Tuple[List[Generation], CachedBatch]:
         generations = []
@@ -569,9 +592,12 @@ class NeuronGenerator(Generator):
            if slot.state != Slot.State.READY:
                continue
            request_id = slot.request_id
-            next_token_logits = logits[i : i + 1, -1, :]
            slot_input_ids = input_ids[i : i + 1, :]
-            next_token = slot.select(slot_input_ids, next_token_logits)
+            if self.on_device_sampling:
+                next_token = tokens_or_logits[i]
+            else:
+                next_token_logits = tokens_or_logits[i : i + 1, -1, :]
+                next_token = slot.select(slot_input_ids, next_token_logits)
            next_token_text = slot.append(next_token)
            generated_text = None
            finish_reason = None
@@ -622,7 +648,7 @@ class NeuronGenerator(Generator):
     def _cached_batch(self, batch_id: int, request_ids: List):
         size = len(request_ids)
-        max_tokens = size * self.model.max_length
+        max_tokens = size * self.model.neuron_config.sequence_length
         return CachedBatch(
             id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
         )
@@ -671,8 +697,16 @@ class NeuronGenerator(Generator):
         Returns:
             A NeuronGenerator.
         """
-        config = AutoConfig.from_pretrained(model_id)
-        neuron_config = getattr(config, "neuron", None)
+        try:
+            neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+        except Exception as e:
+            logger.debug(
+                "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+                model_id,
+                revision,
+                e,
+            )
+            neuron_config = None
         start = time.time()
         if neuron_config is None:
             export_kwargs = get_export_kwargs_from_env()

View File

@@ -6,10 +6,12 @@ from typing import Optional
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
-from transformers import AutoConfig
 
-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
+
+from .tgi_env import check_env_and_neuron_config_compatibility
 
 
 def get_export_kwargs_from_env():
@@ -24,7 +26,6 @@ def get_export_kwargs_from_env():
     num_cores = int(num_cores)
     auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
     return {
-        "task": "text-generation",
         "batch_size": batch_size,
         "sequence_length": sequence_length,
         "num_cores": num_cores,
@@ -32,20 +33,15 @@ def get_export_kwargs_from_env():
     }
 
 
-def is_cached(model_id, neuron_config):
+def is_cached(model_id):
     # Look for cached entries for the specified model
     in_cache = False
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
     # Look for compatible entries
     for entry in entries:
-        compatible = True
-        for key, value in neuron_config.items():
-            # Only weights can be different
-            if key in ["checkpoint_id", "checkpoint_revision"]:
-                continue
-            if entry[key] != value:
-                compatible = False
-        if compatible:
+        if check_env_and_neuron_config_compatibility(
+            entry, check_compiler_version=True
+        ):
             in_cache = True
             break
     return in_cache
@@ -87,8 +83,16 @@ def fetch_model(
         revision = None
     # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
     # Note that the model may already be present in the cache.
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    neuron_config = getattr(config, "neuron", None)
+    try:
+        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            model_id,
+            revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         if os.path.isdir(model_id):
             return model_id
@@ -99,16 +103,11 @@ def fetch_model(
         log_cache_size()
         return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
     # Model needs to be exported: look for compatible cached entries on the hub
-    export_kwargs = get_export_kwargs_from_env()
-    export_config = NeuronModelForCausalLM.get_export_config(
-        model_id, config, revision=revision, **export_kwargs
-    )
-    neuron_config = export_config.neuron
-    if not is_cached(model_id, neuron_config):
+    if not is_cached(model_id):
         hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
         neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
         error_msg = (
-            f"No cached version found for {model_id} with {neuron_config}."
+            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
            f"You can start a discussion to request it on {hub_cache_url}"
            f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
        )
@@ -121,8 +120,10 @@ def fetch_model(
     # Prefetch weights, tokenizer and generation config so that they are in cache
     log_cache_size()
     start = time.time()
-    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+    snapshot_path = snapshot_download(
+        model_id, revision=revision, ignore_patterns="*.bin"
+    )
     end = time.time()
     logger.info(f"Model weights fetched in {end - start:.2f} s.")
     log_cache_size()
-    return model_id
+    return snapshot_path

View File

@@ -6,12 +6,11 @@ import os
 import sys
 from typing import Any, Dict, List, Optional
 
-from huggingface_hub import constants
-from transformers import AutoConfig
-
 from optimum.neuron.modeling_decoder import get_available_cores
-from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.cache import get_hub_cached_entries
+from optimum.neuron.configuration_utils import NeuronConfig
 from optimum.neuron.utils.version_utils import get_neuronxcc_version
+from optimum.neuron.utils import map_torch_dtype
 
 logger = logging.getLogger(__name__)
 
@@ -24,15 +23,9 @@ tgi_router_env_vars = [
 ]
 tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
 
-env_config_peering = [
-    ("MAX_BATCH_SIZE", "batch_size"),
-    ("MAX_TOTAL_TOKENS", "sequence_length"),
-    ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
-    ("HF_NUM_CORES", "num_cores"),
-]
-
 # By the end of this script all env var should be specified properly
-env_vars = tgi_server_env_vars + tgi_router_env_vars
+tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
 
 available_cores = get_available_cores()
 neuronxcc_version = get_neuronxcc_version()
@@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
 
 def neuron_config_to_env(neuron_config):
+    if isinstance(neuron_config, NeuronConfig):
+        neuron_config = neuron_config.to_dict()
     with open(os.environ["ENV_FILEPATH"], "w") as f:
-        for env_var, config_key in env_config_peering:
-            f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
+        f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"]))
+        f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"]))
+        f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"]))
+        config_key = (
+            "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype"
+        )
+        auto_cast_type = neuron_config[config_key]
+        f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type))
         max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
         if not max_input_tokens:
             max_input_tokens = int(neuron_config["sequence_length"]) // 2
@@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config):
 
 
 def sort_neuron_configs(dictionary):
-    return -dictionary["num_cores"], -dictionary["batch_size"]
+    return -dictionary["tp_degree"], -dictionary["batch_size"]
 
 
 def lookup_compatible_cached_model(
@@ -119,7 +120,7 @@ def lookup_compatible_cached_model(
 ) -> Optional[Dict[str, Any]]:
     # Reuse the same mechanic as the one in use to configure the tgi server part
     # The only difference here is that we stay as flexible as possible on the compatibility part
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
 
     logger.debug(
         "Found %d cached entries for model %s, revision %s",
@@ -155,15 +156,15 @@ def lookup_compatible_cached_model(
 
 def check_env_and_neuron_config_compatibility(
-    neuron_config: Dict[str, Any], check_compiler_version: bool
+    neuron_config_dict: Dict[str, Any], check_compiler_version: bool
 ) -> bool:
     logger.debug(
         "Checking the provided neuron config %s is compatible with the local setup and provided environment",
-        neuron_config,
+        neuron_config_dict,
     )
 
     # Local setup compat checks
-    if neuron_config["num_cores"] > available_cores:
+    if neuron_config_dict["tp_degree"] > available_cores:
         logger.debug(
             "Not enough neuron cores available to run the provided neuron config"
         )
@@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility(
     if (
         check_compiler_version
-        and neuron_config["compiler_version"] != neuronxcc_version
+        and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
     ):
         logger.debug(
             "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
             neuronxcc_version,
-            neuron_config["compiler_version"],
+            neuron_config_dict["neuronxcc_version"],
         )
         return False
 
-    for env_var, config_key in env_config_peering:
-        neuron_config_value = str(neuron_config[config_key])
-        env_value = os.getenv(env_var, str(neuron_config_value))
+    batch_size = os.getenv("MAX_BATCH_SIZE", None)
+    if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
+        logger.debug(
+            "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
+            os.getenv("MAX_BATCH_SIZE"),
+            neuron_config_dict["batch_size"],
+        )
+        return False
+    max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
+    if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
+        max_total_tokens
+    ):
+        logger.debug(
+            "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
+            max_total_tokens,
+            neuron_config_dict["sequence_length"],
+        )
+        return False
+    num_cores = os.getenv("HF_NUM_CORES", None)
+    if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
+        logger.debug(
+            "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
+            num_cores,
+            neuron_config_dict["tp_degree"],
+        )
+        return False
+    auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
+    if auto_cast_type is not None:
+        config_key = (
+            "auto_cast_type"
+            if "auto_cast_type" in neuron_config_dict
+            else "torch_dtype"
+        )
+        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
+        env_value = map_torch_dtype(auto_cast_type)
         if env_value != neuron_config_value:
             logger.debug(
-                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
-                env_var,
-                config_key,
+                "The provided auto cast type and the neuron config param differ (%s != %s)",
                 env_value,
                 neuron_config_value,
            )
            return False
 
     max_input_tokens = int(
         os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
     )
     if max_input_tokens > 0:
-        sequence_length = neuron_config["sequence_length"]
+        if hasattr(neuron_config_dict, "max_context_length"):
+            sequence_length = neuron_config_dict["max_context_length"]
+        else:
+            sequence_length = neuron_config_dict["sequence_length"]
         if max_input_tokens >= sequence_length:
             logger.debug(
                 "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
@@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility(
 
 def get_env_dict() -> Dict[str, str]:
     d = {}
-    for k in env_vars:
+    for k in tgi_env_vars:
         d[k] = os.getenv(k)
     return d
 
 
-def main():
-    """
-    This script determines proper default TGI env variables for the neuron precompiled models to
-    work properly
-    :return:
-    """
-    args = parse_cmdline_and_set_env()
-
-    for env_var in env_vars:
-        if not os.getenv(env_var):
-            break
-    else:
-        logger.info(
-            "All env vars %s already set, skipping, user know what they are doing",
-            env_vars,
-        )
-        sys.exit(0)
-
-    cache_dir = constants.HF_HUB_CACHE
-
-    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
-
-    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
-    neuron_config = getattr(config, "neuron", None)
+def get_neuron_config_for_model(
+    model_name_or_path: str, revision: Optional[str] = None
+) -> NeuronConfig:
+    try:
+        neuron_config = NeuronConfig.from_pretrained(
+            model_name_or_path, revision=revision
+        )
+    except Exception as e:
+        logger.debug(
+            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
+            model_name_or_path,
+            revision,
+            e,
+        )
+        neuron_config = None
     if neuron_config is not None:
         compatible = check_env_and_neuron_config_compatibility(
-            neuron_config, check_compiler_version=False
+            neuron_config.to_dict(), check_compiler_version=False
        )
        if not compatible:
            env_dict = get_env_dict()
@@ -252,17 +276,6 @@ def main():
             logger.error(msg)
             raise Exception(msg)
     else:
-        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
-
-    if not neuron_config:
-        msg = (
-            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
-        ).format(get_env_dict(), available_cores, neuronxcc_version)
-        logger.error(msg)
-        raise Exception(msg)
-
-    neuron_config_to_env(neuron_config)
-
-
-if __name__ == "__main__":
-    main()
+        neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)
+    return neuron_config

View File

@@ -4,14 +4,12 @@ import subprocess
 import sys
 from tempfile import TemporaryDirectory
 
-import huggingface_hub
+import os
 import pytest
 from transformers import AutoTokenizer
 
-from optimum.neuron import NeuronModelForCausalLM
-from optimum.neuron.utils import synchronize_hub_cache
-from optimum.neuron.version import __sdk_version__ as sdk_version
-from optimum.neuron.version import __version__ as version
+from optimum.neuron.cache import synchronize_hub_cache
 
 
 logging.basicConfig(
@@ -21,30 +19,14 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__file__)
 
 OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
 
 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
-    "gpt2": {
-        "model_id": "gpt2",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 1024,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
     "llama": {
-        "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 2048,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
-    "mistral": {
-        "model_id": "optimum/mistral-1.1b-testing",
+        "model_id": "unsloth/Llama-3.2-1B-Instruct",
         "export_kwargs": {
             "batch_size": 4,
             "sequence_length": 4096,
@@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = {
             "batch_size": 4,
             "sequence_length": 4096,
             "num_cores": 2,
-            "auto_cast_type": "fp16",
+            "auto_cast_type": "bf16",
         },
     },
     "granite": {
@@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = {
 }
 
 
-def get_hub_neuron_model_id(config_name: str):
-    return (
-        f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
-    )
-
-
 def export_model(model_id, export_kwargs, neuron_model_path):
     export_command = [
         "optimum-cli",
@@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path):
 def neuron_model_config(request):
     """Expose a pre-trained neuron model
 
-    The fixture first makes sure the following model artifacts are present on the hub:
-    - exported neuron model under optimum-internal-testing/neuron-testing-<version>-<name>,
-    - cached artifacts under optimum-internal-testing/neuron-testing-cache.
-    If not, it will export the model and push it to the hub.
-
-    It then fetches the model locally and return a dictionary containing:
+    The fixture exports a model locally and returns a dictionary containing:
     - a configuration name,
     - the original model id,
     - the export parameters,
-    - the neuron model id,
    - the neuron model local path.
 
    For each exposed model, the local directory is maintained for the duration of the
    test session and cleaned up afterwards.
-    The hub model artifacts are never cleaned up and persist accross sessions.
-    They must be cleaned up manually when the optimum-neuron version changes.
 
    """
    config_name = request.param
    model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
    model_id = model_config["model_id"]
    export_kwargs = model_config["export_kwargs"]
-    neuron_model_id = get_hub_neuron_model_id(config_name)
    with TemporaryDirectory() as neuron_model_path:
-        hub = huggingface_hub.HfApi()
-        if hub.repo_exists(neuron_model_id):
-            logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
-            hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
-        else:
-            export_model(model_id, export_kwargs, neuron_model_path)
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-            tokenizer.save_pretrained(neuron_model_path)
-            del tokenizer
-            # Create the test model on the hub
-            hub.create_repo(neuron_model_id, private=True)
-            hub.upload_folder(
-                folder_path=neuron_model_path,
-                repo_id=neuron_model_id,
-                ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
-            )
-            # Make sure it is cached
-            synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+        export_model(model_id, export_kwargs, neuron_model_path)
+        synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.save_pretrained(neuron_model_path)
+        del tokenizer
        # Add dynamic parameters to the model configuration
        model_config["neuron_model_path"] = neuron_model_path
-        model_config["neuron_model_id"] = neuron_model_id
        # Also add model configuration name to allow tests to adapt their expectations
        model_config["name"] = config_name
        # Yield instead of returning to keep a reference to the temporary directory.
        # It will go out of scope and be released only once all tests needing the fixture
        # have been completed.
        logger.info(f"{config_name} ready for testing ...")
+        os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
        yield model_config
    logger.info(f"Done with {config_name}")

View File

@@ -0,0 +1,42 @@
+import os
+
+import pytest
+
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.model import fetch_model, is_cached
+
+
+@pytest.fixture(scope="module")
+def cached_model_id(neuron_model_config) -> str:
+    """
+    Fixture to provide a cached model ID for testing.
+    This assumes the model is already cached in the local environment.
+    """
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    yield neuron_model_config["model_id"]
+    os.environ.pop("MAX_BATCH_SIZE", None)
+    os.environ.pop("MAX_TOTAL_TOKENS", None)
+    os.environ.pop("HF_AUTO_CAST_TYPE", None)
+    os.environ.pop("HF_NUM_CORES", None)
+
+
+def test_model_is_cached(cached_model_id):
+    assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
+
+
+def test_fetch_cached_model(cached_model_id: str):
+    model_path = fetch_model(cached_model_id)
+    assert os.path.exists(
+        model_path
+    ), f"Model {cached_model_id} was not fetched successfully"
+    assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
+
+
+def test_generator_from_cached_model(cached_model_id: str):
+    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
+    assert generator is not None, "Generator could not be created from cached model"
+    assert generator.model is not None, "Generator model is not initialized"
+    assert generator.tokenizer is not None, "Generator tokenizer is not initialized"

View File

@@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config):
     """
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    assert generator.model.batch_size > 1
+    assert generator.model.neuron_config.batch_size > 1
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
     tokens = {0: [], 1: []}
     request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     assert next_batch.size == 1

View File

@@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample):
     request = create_request(
         id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
     )
-    max_length = generator.model.max_length
+    max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
@@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample):
     assert output.finish_reason == 0
     if do_sample:
         expected_text = {
-            "gpt2": " The sun was set",
-            "llama": "George Orwell, 1984",
-            "mistral": "The sky was",
-            "qwen2": " A young woman with",
+            "llama": " I sat alone in the café",
+            "qwen2": " The air was so still",
             "granite": "1984, George Orwell",
         }[config_name]
         assert expected_text in output.text
     else:
         print(output.text)
         expected_text = {
-            "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
-            "llama": " George Orwells classic dystopian novel, 1984, begins with this ominous sentence. The story",
-            "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
+            "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
             "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
             "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
         }[config_name]

View File

@@ -9,7 +9,7 @@ def test_prefill(neuron_model_config):
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
     max_batch_size = 4
-    assert generator.model.batch_size >= max_batch_size
+    assert generator.model.neuron_config.batch_size >= max_batch_size
     for num_requests in [1, max_batch_size]:
         for do_sample in [True, False]:
             mode = "sample" if do_sample else "greedy"
@@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
            )
        )
     # Let's be pessimistic when estimating max_tokens
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
     assert len(generations) == batch_size
     if do_sample:
         expectations = {
-            "gpt2": [383, " The"],
-            "llama": [10058, " George"],
-            "mistral": [450, " The"],
-            "qwen2": [362, " A"],
+            "llama": [358, " I"],
+            "qwen2": [576, " The"],
             "granite": [308, " ("],
         }[config_name]
     else:
         expectations = {
-            "gpt2": [198, "\n"],
-            "llama": [10058, " George"],
-            "mistral": [13, "\n"],
+            "llama": [578, " The"],
             "qwen2": [358, " I"],
             "granite": [203, "\n"],
         }[config_name]
@@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config):
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    batch_size = generator.model.batch_size
+    batch_size = generator.model.neuron_config.batch_size
     # We apply truncation to all requests but the first one
     truncate = [
         None,
@@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config):
     requests = []
     for i in range(batch_size):
         requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
-    max_length = generator.model.max_length
+    max_length = generator.max_prefill_length()
     batch = Batch(
         id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
     )
@@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config):
     # Even if the input text is identical for all requests, the first generated token might
     # be different because of the truncation
     expectations = {
-        "gpt2": [" He", " He", "\n", " He"],
-        "llama": ["", " The", " He", " He"],
-        "mistral": [" He", "\n", " He", " He"],
+        "llama": [" He", "iens", "\x08", " He"],
         "qwen2": [" He", " The", " He", " He"],
         "granite": ["\n", "\n", " I", " He"],
     }[config_name]
     for i, g in enumerate(generations):
         tokens = g.tokens
-        assert tokens.texts[0] == expectations[i]
+        assert (
+            tokens.texts[0] == expectations[i]
+        ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"

View File

@@ -0,0 +1,63 @@
+import os
+
+import pytest
+
+from tempfile import TemporaryDirectory
+
+from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
+from optimum.neuron.utils import map_torch_dtype
+
+from text_generation_server.tgi_env import (
+    get_neuron_config_for_model,
+    lookup_compatible_cached_model,
+    neuron_config_to_env,
+)
+
+
+def test_get_neuron_config_for_model(neuron_model_config):
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    assert neuron_config is not None
+    assert neuron_config.batch_size == export_kwargs["batch_size"]
+    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
+    assert neuron_config.tp_degree == export_kwargs["num_cores"]
+    if isinstance(neuron_config, NxDNeuronConfig):
+        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+    else:
+        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
+            export_kwargs["auto_cast_type"]
+        )
+
+
+@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
+def test_lookup_compatible_cached_model(model_id: str):
+    neuron_config = lookup_compatible_cached_model(model_id, None)
+    assert neuron_config is not None
+
+
+def test_neuron_config_to_env(neuron_model_config) -> None:
+    neuron_model_path = neuron_model_config["neuron_model_path"]
+    neuron_config = get_neuron_config_for_model(neuron_model_path)
+    with TemporaryDirectory() as temp_dir:
+        os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
+        neuron_config_to_env(neuron_config)
+        with open(os.environ["ENV_FILEPATH"], "r") as env_file:
+            env_content = env_file.read()
+            assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
+            assert (
+                f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
+                in env_content
+            )
+            assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
+            if hasattr(neuron_config, "torch_dtype"):
+                auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
+                    "."
+                )[-1]
+            else:
+                auto_cast_type = neuron_config.auto_cast_type
+            assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content

View File

@@ -9,7 +9,7 @@ touch $ENV_FILEPATH
 
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
-${SCRIPT_DIR}/tgi_env.py $@
+${SCRIPT_DIR}/tgi_entry_point.py $@
 
 source $ENV_FILEPATH

View File

@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import sys
+
+from text_generation_server.tgi_env import (
+    available_cores,
+    get_env_dict,
+    get_neuron_config_for_model,
+    neuron_config_to_env,
+    neuronxcc_version,
+    parse_cmdline_and_set_env,
+    tgi_env_vars,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    """
+    This script determines proper default TGI env variables for the neuron precompiled models to
+    work properly
+    :return:
+    """
+    args = parse_cmdline_and_set_env()
+
+    for env_var in tgi_env_vars:
+        if not os.getenv(env_var):
+            break
+    else:
+        logger.info(
+            "All env vars %s already set, skipping, user know what they are doing",
+            tgi_env_vars,
+        )
+        sys.exit(0)
+
+    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
+
+    if not neuron_config:
+        msg = (
+            "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+        ).format(get_env_dict(), available_cores, neuronxcc_version)
+        logger.error(msg)
+        raise Exception(msg)
+
+    neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -28,15 +28,6 @@ logger = logging.getLogger(__file__)
 
 # All model configurations below will be added to the neuron_model_config fixture
 MODEL_CONFIGURATIONS = {
-    "gpt2": {
-        "model_id": "gpt2",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 1024,
-            "num_cores": 2,
-            "auto_cast_type": "fp16",
-        },
-    },
     "llama": {
         "model_id": "unsloth/Llama-3.2-1B-Instruct",
         "export_kwargs": {
@@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = {
             "auto_cast_type": "fp16",
         },
     },
-    "mistral": {
-        "model_id": "optimum/mistral-1.1b-testing",
-        "export_kwargs": {
-            "batch_size": 4,
-            "sequence_length": 4096,
-            "num_cores": 2,
-            "auto_cast_type": "bf16",
-        },
-    },
     "qwen2": {
         "model_id": "Qwen/Qwen2.5-0.5B",
         "export_kwargs": {

View File

@@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service):
     )
     assert response.details.generated_tokens == 17
     greedy_expectations = {
-        "gpt2": "\n\nDeep learning is a new field of research that has been around for a while",
-        "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial",
-        "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that",
+        "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial",
         "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on",
         "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art",
     }
@@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load):
     assert len(responses) == 4
     expectations = {
-        "gpt2": "Deep learning is a new field of research that has been around for a while",
         "llama": "Deep learning is a subset of machine learning that uses artificial",
-        "mistral": "Deep Learning is a type of machine learning that",
         "qwen2": "Deep Learning is a subset of Machine Learning that is based on",
         "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art",
     }