import os
import shutil
import time
from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from transformers import AutoConfig

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils import get_hub_cached_entries


def get_export_kwargs_from_env():
    """Read the neuron export parameters from the environment."""
    batch_size = os.environ.get("MAX_BATCH_SIZE", None)
    if batch_size is not None:
        batch_size = int(batch_size)
    sequence_length = os.environ.get("MAX_TOTAL_TOKENS", None)
    if sequence_length is not None:
        sequence_length = int(sequence_length)
    num_cores = os.environ.get("HF_NUM_CORES", None)
    if num_cores is not None:
        num_cores = int(num_cores)
    auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
    return {
        "task": "text-generation",
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_cores": num_cores,
        "auto_cast_type": auto_cast_type,
    }


def is_cached(model_id, neuron_config):
    # Look for cached entries for the specified model
    in_cache = False
    entries = get_hub_cached_entries(model_id, "inference")
    # Look for compatible entries
    for entry in entries:
        compatible = True
        for key, value in neuron_config.items():
            # Only weights can be different
            if key in ["checkpoint_id", "checkpoint_revision"]:
                continue
            if entry[key] != value:
                compatible = False
        if compatible:
            in_cache = True
            break
    return in_cache


def log_cache_size():
    path = HF_HUB_CACHE
    if os.path.exists(path):
        usage = shutil.disk_usage(path)
        gb = 2**30
        logger.info(
            f"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G"
        )
    else:
        raise ValueError(f"The cache directory ({path}) does not exist.")


def fetch_model(
    model_id: str,
    revision: Optional[str] = None,
) -> str:
    """Fetch a neuron model.

    Args:
        model_id (`str`):
            The *model_id* of a model on the HuggingFace hub or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub.

    Returns:
        A string corresponding to the model_id or path.
    """
    if not os.path.isdir("/sys/class/neuron_device/"):
        raise SystemError("No neuron cores detected on the host.")
    if os.path.isdir(model_id) and revision is not None:
        logger.warning(f"Revision {revision} ignored for local model at {model_id}")
        revision = None
    # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
    # Note that the model may already be present in the cache.
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    neuron_config = getattr(config, "neuron", None)
    if neuron_config is not None:
        # The model has already been exported to neuron format
        if os.path.isdir(model_id):
            return model_id
        # Prefetch the neuron model from the Hub
        logger.info(
            f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}"
        )
        log_cache_size()
        return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
    # Model needs to be exported: look for compatible cached entries on the hub
    export_kwargs = get_export_kwargs_from_env()
    export_config = NeuronModelForCausalLM.get_export_config(
        model_id, config, revision=revision, **export_kwargs
    )
    neuron_config = export_config.neuron
    if not is_cached(model_id, neuron_config):
        hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
        neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
        error_msg = (
            f"No cached version found for {model_id} with {neuron_config}. "
            f"You can start a discussion to request it on {hub_cache_url}. "
            f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}."
        )
        raise ValueError(error_msg)
    logger.warning(
        f"{model_id} is not a neuron model: it will be exported using cached artifacts."
    )
    if os.path.isdir(model_id):
        return model_id
    # Prefetch weights, tokenizer and generation config so that they are in cache
    log_cache_size()
    start = time.time()
    snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
    end = time.time()
    logger.info(f"Model weights fetched in {end - start:.2f} s.")
    log_cache_size()
    return model_id