# File: text-generation-inference/backends/neuron/server/text_generation_server/model.py
# (129 lines, 4.7 KiB, Python)

import os
import shutil
import time
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from transformers import AutoConfig
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import get_hub_cached_entries
def get_export_kwargs_from_env():
    """Build neuron export keyword arguments from the server environment.

    ``MAX_BATCH_SIZE``, ``MAX_TOTAL_TOKENS`` and ``HF_NUM_CORES`` are parsed as
    integers; ``HF_AUTO_CAST_TYPE`` is passed through as a string. Any variable
    that is not set maps to `None`.

    Returns:
        `dict`: the export arguments for a ``text-generation`` task.

    Raises:
        ValueError: if one of the integer variables is set to a non-integer value.
    """

    def _optional_int(var_name):
        # Unset -> None, otherwise int() (which raises on malformed input).
        raw_value = os.environ.get(var_name, None)
        return None if raw_value is None else int(raw_value)

    # Dict values are evaluated in order, so variables are read and validated
    # in the same sequence as the keys below.
    return {
        "task": "text-generation",
        "batch_size": _optional_int("MAX_BATCH_SIZE"),
        "sequence_length": _optional_int("MAX_TOTAL_TOKENS"),
        "num_cores": _optional_int("HF_NUM_CORES"),
        "auto_cast_type": os.environ.get("HF_AUTO_CAST_TYPE", None),
    }
def is_cached(model_id, neuron_config):
    """Check whether a compatible compiled entry exists in the hub neuron cache.

    Args:
        model_id (`str`):
            The model id whose cached entries are looked up.
        neuron_config (mapping):
            The neuron export configuration a cached entry must match. The
            ``checkpoint_id`` and ``checkpoint_revision`` keys are ignored, since
            only the weights are allowed to differ.

    Returns:
        `bool`: `True` if at least one cached entry matches every other key of
        ``neuron_config``, `False` otherwise (including when there are no entries).
    """
    # Only weights can be different: every other key must match exactly.
    ignored_keys = ("checkpoint_id", "checkpoint_revision")
    entries = get_hub_cached_entries(model_id, "inference")

    def _is_compatible(entry):
        for key, value in neuron_config.items():
            if key in ignored_keys:
                continue
            # An entry lacking one of the keys cannot be shown to match: treat it
            # as incompatible instead of letting a KeyError abort the whole lookup.
            if key not in entry or entry[key] != value:
                return False
        return True

    # any() stops at the first compatible entry.
    return any(_is_compatible(entry) for entry in entries)
def log_cache_size():
    """Log the total and free disk space of the volume holding the hub cache.

    Sizes are reported in units of 2**30 bytes.

    Raises:
        ValueError: if the ``HF_HUB_CACHE`` directory does not exist.
    """
    cache_dir = HF_HUB_CACHE
    if not os.path.exists(cache_dir):
        raise ValueError(f"The cache directory ({cache_dir}) does not exist.")
    disk = shutil.disk_usage(cache_dir)
    gib = 1 << 30
    total_gib = disk.total / gib
    free_gib = disk.free / gib
    logger.info(
        f"Cache disk [{cache_dir}]: total = {total_gib:.2f} G, free = {free_gib:.2f} G"
    )
def fetch_model(
    model_id: str,
    revision: Optional[str] = None,
) -> str:
    """Fetch a neuron model.

    If the model config has a ``neuron`` section, the model is already compiled.
    Its hub snapshot is downloaded, unless ``model_id`` is a local directory, and
    the snapshot path is returned.

    Otherwise the model must be exported. The export configuration is derived
    from the environment, and a compatible entry must exist in the hub neuron
    cache. The weights, tokenizer and generation config are then prefetched into
    the local cache, and ``model_id`` itself is returned.

    Args:
        model_id (`str`):
            The *model_id* of a model on the HuggingFace hub or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub. Ignored, with a
            warning, for a local model.

    Returns:
        A string corresponding to the model_id or path.

    Raises:
        SystemError: if no neuron device is visible on the host.
        ValueError: if a non-neuron model has no compatible cached export, or if
            the hub cache directory does not exist (raised by `log_cache_size`).
    """
    if not os.path.isdir("/sys/class/neuron_device/"):
        raise SystemError("No neuron cores detected on the host.")
    is_local = os.path.isdir(model_id)
    if is_local and revision is not None:
        logger.warning(f"Revision {revision} ignored for local model at {model_id}")
        revision = None
    # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
    # Note that the model may already be present in the cache.
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    neuron_config = getattr(config, "neuron", None)
    if neuron_config is not None:
        if is_local:
            return model_id
        # Prefetch the neuron model from the Hub
        logger.info(
            f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}"
        )
        log_cache_size()
        return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
    # Model needs to be exported: look for compatible cached entries on the hub
    export_kwargs = get_export_kwargs_from_env()
    export_config = NeuronModelForCausalLM.get_export_config(
        model_id, config, revision=revision, **export_kwargs
    )
    neuron_config = export_config.neuron
    if not is_cached(model_id, neuron_config):
        hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
        neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
        # Adjacent f-strings are concatenated verbatim, so each sentence carries
        # its own trailing separator.
        error_msg = (
            f"No cached version found for {model_id} with {neuron_config}. "
            f"You can start a discussion to request it on {hub_cache_url}. "
            f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
        )
        raise ValueError(error_msg)
    logger.warning(
        f"{model_id} is not a neuron model: it will be exported using cached artifacts."
    )
    if is_local:
        return model_id
    # Prefetch weights, tokenizer and generation config so that they are in cache
    log_cache_size()
    # perf_counter is monotonic, so the measured duration is immune to clock changes.
    start = time.perf_counter()
    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
    end = time.perf_counter()
    logger.info(f"Model weights fetched in {end - start:.2f} s.")
    log_cache_size()
    return model_id