diff --git a/Dockerfile b/Dockerfile
index 99fe6dde..93846d77 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \
     DEBIAN_FRONTEND=noninteractive \
-    MODEL_BASE_PATH=/data \
+    HUGGINGFACE_HUB_CACHE=/data \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
-    NUM_GPUS=1 \
+    NUM_SHARD=1 \
     SAFETENSORS_FAST_GPU=1 \
     PORT=80 \
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
     NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
     LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
     CONDA_DEFAULT_ENV=text-generation \
     PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
 
-SHELL ["/bin/bash", "-c"]
-
 RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
 
 RUN cd ~ && \
@@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
 
 # Install launcher
 COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
 
-CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output
\ No newline at end of file
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]
\ No newline at end of file
diff --git a/Makefile b/Makefile
index d427ff87..39017944 100644
--- a/Makefile
+++ b/Makefile
@@ -16,16 +16,16 @@ router-dev:
     cd router && cargo run
 
 run-bloom-560m:
-    text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
+    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
 
 run-bloom-560m-quantize:
-    text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
+    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
 
 download-bloom:
     text-generation-server download-weights bigscience/bloom
 
 run-bloom:
-    text-generation-launcher --model-name bigscience/bloom --num-shard 8
+    text-generation-launcher --model-id bigscience/bloom --num-shard 8
 
 run-bloom-quantize:
-    text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
+    text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
diff --git a/README.md b/README.md
index d092781a..74d7a988 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,42 @@
+<div align="center">
+
+# Text Generation Inference
-<div align="center">
+
+  GitHub Repo stars
+
+
+  License
+
+
+  Swagger API documentation
+
+
 ![architecture](assets/architecture.jpg)
 
 </div>
 
-A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Bloom, BloomZ and MT0-XXL api-inference widgets.
+A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power LLMs api-inference widgets.
+
+## Table of contents
+
+- [Features](#features)
+- [Officially Supported Models](#officially-supported-models)
+- [Get Started](#get-started)
+  - [Docker](#docker)
+  - [Local Install](#local-install)
+  - [OpenAPI](#api-documentation)
+  - [CUDA Kernels](#cuda-kernels)
+- [Run BLOOM](#run-bloom)
+  - [Download](#download)
+  - [Run](#run)
+  - [Quantization](#quantization)
+- [Develop](#develop)
+
 ## Features
 
+- Token streaming using Server Side Events (SSE)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
@@ -36,30 +62,63 @@ or `AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
 
-## Load Tests for BLOOM
+## Get started
 
-See `k6/load_test.js`
+### Docker
 
-| | avg | min | med | max | p(90) | p(95) | RPS |
-|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
-| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
-
-## Install
+The easiest way of getting started is using the official Docker container:
 
 ```shell
-make install
+model=bigscience/bloom-560m
+num_shard=2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
 ```
 
-## Run
-
-### BLOOM 560-m
+You can then query the model using either the `/generate` or `/generate_stream` routes:
 
 ```shell
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+```shell
+curl 127.0.0.1:8080/generate_stream \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally. You will need to have cargo and Python installed on your
+machine
+
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
 make run-bloom-560m
 ```
 
-### BLOOM
+### CUDA Kernels
+
+The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
+the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
+
+Be aware that the official Docker image has them enabled by default.
+
+## Run BLOOM
+
+### Download
 
 First you need to download the weights:
 
@@ -67,26 +126,20 @@
 make download-bloom
 ```
 
+### Run
+
 ```shell
 make run-bloom # Requires 8xA100 80GB
 ```
 
+### Quantization
+
 You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
 
 ```shell
 make run-bloom-quantize # Requires 8xA100 40GB
 ```
 
-## Test
-
-```shell
-curl 127.0.0.1:3000/generate \
-    -v \
-    -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-    -H 'Content-Type: application/json'
-```
-
 ## Develop
 
 ```shell
diff --git a/aml/deployment.yaml b/aml/deployment.yaml
index 67690722..16ef3dc7 100644
--- a/aml/deployment.yaml
+++ b/aml/deployment.yaml
@@ -4,9 +4,9 @@ endpoint_name: bloom-inference
 model: azureml:bloom:1
 model_mount_path: /var/azureml-model
 environment_variables:
-  MODEL_BASE_PATH: /var/azureml-model/bloom
+  HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
   MODEL_ID: bigscience/bloom
-  NUM_GPUS: 8
+  NUM_SHARD: 8
 environment:
   image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
 inference_config:
diff --git a/assets/architecture.jpg b/assets/architecture.jpg
index e0a5f7c6..c4a511c9 100644
Binary files a/assets/architecture.jpg and b/assets/architecture.jpg differ
diff --git a/docs/index.html b/docs/index.html
index e00e7446..16d143d8 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -3,7 +3,7 @@
-    <title>My New API</title>
+    <title>Text Generation Inference API</title>
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index dea6fcc8..3df6e911 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
 #[clap(author, version, about, long_about = None)]
 struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
-    model_name: String,
+    model_id: String,
     #[clap(long, env)]
     revision: Option<String>,
     #[clap(long, env)]
@@ -49,7 +49,7 @@ struct Args {
 fn main() -> ExitCode {
     // Pattern match configuration
     let Args {
-        model_name,
+        model_id,
         revision,
         num_shard,
         quantize,
@@ -92,7 +92,7 @@ fn main() -> ExitCode {
 
     // Start shard processes
     for rank in 0..num_shard {
-        let model_name = model_name.clone();
+        let model_id = model_id.clone();
         let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
@@ -101,7 +101,7 @@ fn main() -> ExitCode {
         let shutdown_sender = shutdown_sender.clone();
         thread::spawn(move || {
             shard_manager(
-                model_name,
+                model_id,
                 revision,
                 quantize,
                 uds_path,
@@ -167,7 +167,7 @@ fn main() -> ExitCode {
         "--master-shard-uds-path".to_string(),
         format!("{}-0", shard_uds_path),
         "--tokenizer-name".to_string(),
-        model_name,
+        model_id,
     ];
 
     if json_output {
@@ -256,7 +256,7 @@ enum ShardStatus {
 
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
-    model_name: String,
+    model_id: String,
     revision: Option<String>,
     quantize: bool,
     uds_path: String,
@@ -278,7 +278,7 @@ fn shard_manager(
     let mut shard_argv = vec![
         "text-generation-server".to_string(),
         "serve".to_string(),
-        model_name,
+        model_id,
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs
index 9336f36e..b70b1628 100644
--- a/launcher/tests/integration_tests.rs
+++ b/launcher/tests/integration_tests.rs
@@ -29,11 +29,11 @@ struct GeneratedText {
     details: Details,
 }
 
-fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
+fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
-        "--model-name".to_string(),
-        model_name.clone(),
+        "--model-id".to_string(),
+        model_id.clone(),
         "--num-shard".to_string(),
         num_shard.to_string(),
         "--port".to_string(),
@@ -75,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
 
-    panic!("failed to launch {}", model_name)
+    panic!("failed to launch {}", model_id)
 }
 
 fn test_model(
-    model_name: String,
+    model_id: String,
     num_shard: usize,
     port: usize,
     master_port: usize,
 ) -> GeneratedText {
-    let mut launcher = start_launcher(model_name, num_shard, port, master_port);
+    let mut launcher = start_launcher(model_id, num_shard, port, master_port);
 
     let data = r#"
         {
diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py
index b133cb0a..1418a803 100644
--- a/server/text_generation/cli.py
+++ b/server/text_generation/cli.py
@@ -13,7 +13,7 @@ app = typer.Typer()
 
 @app.command()
 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
@@ -46,16 +46,16 @@ def serve(
             os.getenv("MASTER_PORT", None) is not None
         ), "MASTER_PORT must be set when sharded is True"
 
-    server.serve(model_name, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, uds_path)
 
 
 @app.command()
 def download_weights(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, revision, extension)
+    utils.download_weights(model_id, revision, extension)
 
 
 if __name__ == "__main__":
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 15d8e97e..943d45e9 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True
 
 
 def get_model(
-    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name, revision=revision)
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
 
     if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, revision, quantize=quantize)
+            return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, revision, quantize=quantize)
+            return BLOOM(model_id, revision, quantize=quantize)
     elif config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_name, revision, quantize=quantize)
-    elif model_name.startswith("facebook/galactica"):
+            return GPTNeox(model_id, revision, quantize=quantize)
+    elif model_id.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, revision, quantize=quantize)
+            return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
-            return Galactica(model_name, revision, quantize=quantize)
-    elif "santacoder" in model_name:
-        return SantaCoder(model_name, revision, quantize)
+            return Galactica(model_id, revision, quantize=quantize)
+    elif "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, revision, quantize=quantize)
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 4f55afc0..992d7b5b 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -57,10 +57,10 @@ class BLOOM(CausalLM):
 
 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_id} is not supported")
 
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )
 
        config = AutoConfig.from_pretrained(
-            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3
 
        # Only download weights for small models
-        if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "bigscience/bloom-560m":
+            download_weights(model_id, revision=revision, extension=".safetensors")
 
         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 1d1945cd..f21423ea 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
 
 
 class CausalLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -244,10 +244,10 @@ class CausalLM(Model):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index be9b1699..f1dc8a30 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -149,10 +149,10 @@ class Galactica(CausalLM):
 
 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("facebook/galactica"):
+            raise ValueError(f"Model {model_id} is not supported")
 
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
 
         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )
         tokenizer.pad_token_id = config.pad_token_id
 
         # Only download weights for small models
-        if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "facebook/galactica-125m":
+            download_weights(model_id, revision=revision, extension=".safetensors")
 
         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index a8f7f365..2d467f4c 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
 
 class GPTNeoxSharded(GPTNeox):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token
 
         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )
 
         # Only master download weights
         if self.master:
-            download_weights(model_name, revision=revision, extension=".safetensors")
+            download_weights(model_id, revision=revision, extension=".safetensors")
 
         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py
index 6c1a250f..fb496197 100644
--- a/server/text_generation/models/santacoder.py
+++ b/server/text_generation/models/santacoder.py
@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
 
 
 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.add_special_tokens(
             {
@@ -43,7 +43,7 @@ class SantaCoder(CausalLM):
 
         self.model = (
             AutoModelForCausalLM.from_pretrained(
-                model_name,
+                model_id,
                 revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 1ae266d8..27cbe1c0 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
             dtype = torch.float32
 
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
diff --git a/server/text_generation/server.py b/server/text_generation/server.py
index 852deebf..a8a9da6c 100644
--- a/server/text_generation/server.py
+++ b/server/text_generation/server.py
@@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
 
 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
-        model_name: str,
+        model_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
@@ -89,7 +89,7 @@ def serve(
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]
 
-        model = get_model(model_name, revision, sharded, quantize)
+        model = get_model(model_id, revision, sharded, quantize)
 
         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -109,4 +109,4 @@ def serve(
             logger.info("Signal received. Shutting down")
             await server.stop(0)
 
-    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index ea97ed4a..83458969 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -182,20 +182,20 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
 
 
-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name, revision=revision)
+    info = api.model_info(model_id, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames
 
 
-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
     """Try to load a file from the Hugging Face cache"""
     if revision is None:
         revision = "main"
 
-    object_id = model_name.replace("/", "--")
+    object_id = model_id.replace("/", "--")
     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
 
     if not repo_cache.is_dir():
@@ -230,38 +230,38 @@ def try_to_load_from_cache(model_id, revision, filename):
     return str(cached_file) if cached_file.is_file() else None
 
 
-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
 
-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(
-            model_name, revision=revision, filename=filename
+            model_id, revision=revision, filename=filename
         )
         if cache_file is None:
             raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_name} not found in "
+                f"File {filename} of model {model_id} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_id}` first."
             )
         files.append(cache_file)
 
     return files
 
 
-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
 
-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
 
    download_function = partial(
        hf_hub_download,
-        repo_id=model_name,
+        repo_id=model_id,
        local_files_only=False,
    )