fix(nxd): adapt model retrieval to new APIs

David Corvoysier 2025-05-27 12:27:22 +00:00
parent 39895019c8
commit b916076c72
2 changed files with 56 additions and 23 deletions

View File

@@ -6,13 +6,14 @@ from typing import Optional
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
-from transformers import AutoConfig
-from optimum.neuron import NeuronModelForCausalLM
 from optimum.neuron.cache import get_hub_cached_entries
 from optimum.neuron.configuration_utils import NeuronConfig
+from .tgi_env import check_env_and_neuron_config_compatibility


 def get_export_kwargs_from_env():
     batch_size = os.environ.get("MAX_BATCH_SIZE", None)
     if batch_size is not None:
@@ -25,7 +26,6 @@ def get_export_kwargs_from_env():
         num_cores = int(num_cores)
     auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
     return {
-        "task": "text-generation",
         "batch_size": batch_size,
         "sequence_length": sequence_length,
         "num_cores": num_cores,
@@ -33,20 +33,15 @@ def get_export_kwargs_from_env():
     }


-def is_cached(model_id, neuron_config):
+def is_cached(model_id):
     # Look for cached entries for the specified model
     in_cache = False
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
     # Look for compatible entries
     for entry in entries:
-        compatible = True
-        for key, value in neuron_config.items():
-            # Only weights can be different
-            if key in ["checkpoint_id", "checkpoint_revision"]:
-                continue
-            if entry[key] != value:
-                compatible = False
-        if compatible:
+        if check_env_and_neuron_config_compatibility(
+            entry, check_compiler_version=True
+        ):
             in_cache = True
             break
     return in_cache
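With this change, is_cached no longer compares a precomputed neuron config against each cached entry; compatibility is delegated to check_env_and_neuron_config_compatibility, which reads the TGI environment. A minimal usage sketch, assuming the same environment variables the server consumes; the model id and values below are illustrative, not taken from the commit:

import os

from text_generation_server.model import is_cached

# Illustrative values only: compatibility is now decided from these TGI
# environment variables rather than from an explicit neuron_config argument.
os.environ["MAX_BATCH_SIZE"] = "4"
os.environ["MAX_TOTAL_TOKENS"] = "4096"
os.environ["HF_NUM_CORES"] = "2"
os.environ["HF_AUTO_CAST_TYPE"] = "bf16"

if is_cached("Qwen/Qwen2.5-0.5B"):  # hypothetical model id
    print("A compatible precompiled artifact exists in the hub cache")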
@@ -108,17 +103,11 @@ def fetch_model(
         log_cache_size()
         return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
     # Model needs to be exported: look for compatible cached entries on the hub
-    export_kwargs = get_export_kwargs_from_env()
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    export_config = NeuronModelForCausalLM.get_export_config(
-        model_id, config, revision=revision, **export_kwargs
-    )
-    neuron_config = export_config.neuron
-    if not is_cached(model_id, neuron_config):
+    if not is_cached(model_id):
         hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
         neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
         error_msg = (
-            f"No cached version found for {model_id} with {neuron_config}."
+            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
             f"You can start a discussion to request it on {hub_cache_url}"
             f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
         )
@@ -131,8 +120,10 @@ def fetch_model(
     # Prefetch weights, tokenizer and generation config so that they are in cache
     log_cache_size()
     start = time.time()
-    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+    snapshot_path = snapshot_download(
+        model_id, revision=revision, ignore_patterns="*.bin"
+    )
     end = time.time()
     logger.info(f"Model weights fetched in {end - start:.2f} s.")
     log_cache_size()
-    return model_id
+    return snapshot_path
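The last hunk also changes what fetch_model returns: the local snapshot path produced by snapshot_download rather than the model id. A short sketch of a caller relying on the new return value, with a hypothetical model id:

from text_generation_server.model import fetch_model

# fetch_model now returns the local snapshot directory once weights,
# tokenizer and generation config have been prefetched into the HF cache.
model_path = fetch_model("Qwen/Qwen2.5-0.5B")  # hypothetical model id
print(f"Neuron model ready to load from {model_path}")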

View File

@@ -0,0 +1,42 @@
import os

import pytest

from text_generation_server.generator import NeuronGenerator
from text_generation_server.model import fetch_model, is_cached


@pytest.fixture(scope="module")
def cached_model_id(neuron_model_config) -> str:
    """
    Fixture to provide a cached model ID for testing.
    This assumes the model is already cached in the local environment.
    """
    export_kwargs = neuron_model_config["export_kwargs"]
    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
    yield neuron_model_config["model_id"]
    os.environ.pop("MAX_BATCH_SIZE", None)
    os.environ.pop("MAX_TOTAL_TOKENS", None)
    os.environ.pop("HF_AUTO_CAST_TYPE", None)
    os.environ.pop("HF_NUM_CORES", None)


def test_model_is_cached(cached_model_id):
    assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"


def test_fetch_cached_model(cached_model_id: str):
    model_path = fetch_model(cached_model_id)
    assert os.path.exists(
        model_path
    ), f"Model {cached_model_id} was not fetched successfully"
    assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"


def test_generator_from_cached_model(cached_model_id: str):
    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
    assert generator is not None, "Generator could not be created from cached model"
    assert generator.model is not None, "Generator model is not initialized"
    assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
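The cached_model_id fixture above depends on a neuron_model_config fixture defined elsewhere in the test suite. Judging from the keys it accesses, that fixture is expected to look roughly like the sketch below; the concrete values are assumptions for illustration, not part of the commit:

# Hypothetical shape of the neuron_model_config fixture consumed by
# cached_model_id; only the keys read above are shown, values are examples.
neuron_model_config = {
    "model_id": "Qwen/Qwen2.5-0.5B",
    "export_kwargs": {
        "batch_size": 4,
        "sequence_length": 4096,
        "num_cores": 2,
        "auto_cast_type": "bf16",
    },
}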