From 775115e3a571bc5d983852ce31c071a7734c1430 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 1 Feb 2023 16:22:10 +0100
Subject: [PATCH] feat(server): allow the server to use a local weight cache
 (#49)

---
 launcher/src/main.rs            | 6 ++++++
 server/text_generation/utils.py | 7 +++++++
 2 files changed, 13 insertions(+)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index bd449e28..20ec7faa 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -313,6 +313,12 @@ fn shard_manager(
         env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
 
+    // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
+    // Useful when running inside a HuggingFace Inference Endpoint
+    if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
+        env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
+    };
+
     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
         env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 91e6b7b7..8590f85c 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -25,6 +25,7 @@ from transformers.generation.logits_process import (
 
 from text_generation.pb import generate_pb2
 
+WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
@@ -230,6 +231,9 @@ def try_to_load_from_cache(model_name, revision, filename):
 
 def weight_files(model_name, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
+    if WEIGHTS_CACHE_OVERRIDE is not None:
+        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
+
     filenames = weight_hub_files(model_name, revision, extension)
     files = []
     for filename in filenames:
@@ -249,6 +253,9 @@ def weight_files(model_name, revision=None, extension=".safetensors"):
 
 def download_weights(model_name, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
+    if WEIGHTS_CACHE_OVERRIDE is not None:
+        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
+
     filenames = weight_hub_files(model_name, revision, extension)
 
     download_function = partial(
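
Note: WEIGHTS_CACHE_OVERRIDE is captured once at import time via os.getenv, so it must be present in the
environment before text_generation.utils is imported (the launcher change above forwards it to each shard
for exactly this reason). A minimal sketch of how the override could be exercised directly; the cache
directory and model name below are made-up examples, not values from this patch:

    import os

    # Must be set before text_generation.utils is imported, because the module
    # reads the variable once at import time with os.getenv.
    os.environ["WEIGHTS_CACHE_OVERRIDE"] = "/data/weights/bloom-560m"  # hypothetical local path

    from text_generation import utils

    # With the override set, weight_files() skips the Hugging Face Hub cache lookup
    # and simply globs *.safetensors files from the override directory;
    # download_weights() likewise returns those local files without downloading anything.
    print(utils.weight_files("bigscience/bloom-560m"))  # example model name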