Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 19:34:53 +00:00
fix(nxd): adapt model retrieval to new APIs
This commit is contained in:
parent 39895019c8
commit b916076c72
backends/neuron/server/text_generation_server/model.py

@@ -6,13 +6,14 @@ from typing import Optional
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
-from transformers import AutoConfig
 
-from optimum.neuron import NeuronModelForCausalLM
 from optimum.neuron.cache import get_hub_cached_entries
 from optimum.neuron.configuration_utils import NeuronConfig
 
 
+from .tgi_env import check_env_and_neuron_config_compatibility
+
+
 def get_export_kwargs_from_env():
     batch_size = os.environ.get("MAX_BATCH_SIZE", None)
     if batch_size is not None:
@@ -25,7 +26,6 @@ def get_export_kwargs_from_env():
         num_cores = int(num_cores)
     auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
     return {
-        "task": "text-generation",
         "batch_size": batch_size,
         "sequence_length": sequence_length,
         "num_cores": num_cores,
@@ -33,20 +33,15 @@ def get_export_kwargs_from_env():
     }
 
 
-def is_cached(model_id, neuron_config):
+def is_cached(model_id):
     # Look for cached entries for the specified model
     in_cache = False
-    entries = get_hub_cached_entries(model_id, "inference")
+    entries = get_hub_cached_entries(model_id)
     # Look for compatible entries
     for entry in entries:
-        compatible = True
-        for key, value in neuron_config.items():
-            # Only weights can be different
-            if key in ["checkpoint_id", "checkpoint_revision"]:
-                continue
-            if entry[key] != value:
-                compatible = False
-        if compatible:
+        if check_env_and_neuron_config_compatibility(
+            entry, check_compiler_version=True
+        ):
             in_cache = True
             break
     return in_cache
@@ -108,17 +103,11 @@ def fetch_model(
         log_cache_size()
         return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
     # Model needs to be exported: look for compatible cached entries on the hub
-    export_kwargs = get_export_kwargs_from_env()
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    export_config = NeuronModelForCausalLM.get_export_config(
-        model_id, config, revision=revision, **export_kwargs
-    )
-    neuron_config = export_config.neuron
-    if not is_cached(model_id, neuron_config):
+    if not is_cached(model_id):
         hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
         neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
         error_msg = (
-            f"No cached version found for {model_id} with {neuron_config}."
+            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}."
             f"You can start a discussion to request it on {hub_cache_url}"
             f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
         )
@@ -131,8 +120,10 @@ def fetch_model(
     # Prefetch weights, tokenizer and generation config so that they are in cache
     log_cache_size()
     start = time.time()
-    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+    snapshot_path = snapshot_download(
+        model_id, revision=revision, ignore_patterns="*.bin"
+    )
     end = time.time()
     logger.info(f"Model weights fetched in {end - start:.2f} s.")
     log_cache_size()
-    return model_id
+    return snapshot_path
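Taken together, these hunks adapt the server to the newer optimum-neuron APIs: get_hub_cached_entries is now called with the model id alone, cache compatibility is delegated to check_env_and_neuron_config_compatibility from tgi_env instead of comparing a locally built neuron_config field by field, and fetch_model returns the local snapshot path instead of echoing the model id. A minimal usage sketch under those assumptions follows; the repository id and environment values are placeholders, not part of the commit.

import os

from text_generation_server.model import fetch_model, is_cached

# Placeholder serving environment; in TGI these variables are set by the launcher.
os.environ["MAX_BATCH_SIZE"] = "4"
os.environ["MAX_TOTAL_TOKENS"] = "4096"
os.environ["HF_AUTO_CAST_TYPE"] = "bf16"
os.environ["HF_NUM_CORES"] = "2"

model_id = "my-org/my-neuron-model"  # placeholder repository id
if is_cached(model_id):
    # fetch_model now returns the snapshot path of the downloaded weights.
    local_path = fetch_model(model_id)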
backends/neuron/tests/server/test_cached_model.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+import os
+import pytest
+
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.model import fetch_model, is_cached
+
+
+@pytest.fixture(scope="module")
+def cached_model_id(neuron_model_config) -> str:
+    """
+    Fixture to provide a cached model ID for testing.
+    This assumes the model is already cached in the local environment.
+    """
+    export_kwargs = neuron_model_config["export_kwargs"]
+    os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"])
+    os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"])
+    os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"]
+    os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"])
+    yield neuron_model_config["model_id"]
+    os.environ.pop("MAX_BATCH_SIZE", None)
+    os.environ.pop("MAX_TOTAL_TOKENS", None)
+    os.environ.pop("HF_AUTO_CAST_TYPE", None)
+    os.environ.pop("HF_NUM_CORES", None)
+
+
+def test_model_is_cached(cached_model_id):
+    assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached"
+
+
+def test_fetch_cached_model(cached_model_id: str):
+    model_path = fetch_model(cached_model_id)
+    assert os.path.exists(
+        model_path
+    ), f"Model {cached_model_id} was not fetched successfully"
+    assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory"
+
+
+def test_generator_from_cached_model(cached_model_id: str):
+    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
+    assert generator is not None, "Generator could not be created from cached model"
+    assert generator.model is not None, "Generator model is not initialized"
+    assert generator.tokenizer is not None, "Generator tokenizer is not initialized"
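The neuron_model_config fixture used by these tests is not part of the diff; it is expected to be provided by the test suite's conftest. A hypothetical sketch of the data shape the tests assume is shown below; only the key names are taken from the fixture body above, the values are placeholders.

# Hypothetical fixture payload, for illustration only.
example_neuron_model_config = {
    "model_id": "my-org/my-neuron-model",  # placeholder repository id
    "export_kwargs": {
        "batch_size": 4,  # mapped to MAX_BATCH_SIZE
        "sequence_length": 4096,  # mapped to MAX_TOTAL_TOKENS
        "auto_cast_type": "bf16",  # mapped to HF_AUTO_CAST_TYPE
        "num_cores": 2,  # mapped to HF_NUM_CORES
    },
}

With such a fixture in place, the module runs under pytest like any other test file in the suite.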