diff --git a/Dockerfile b/Dockerfile
index 99fe6dde..93846d77 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
- MODEL_BASE_PATH=/data \
+ HUGGINGFACE_HUB_CACHE=/data \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
- NUM_GPUS=1 \
+ NUM_SHARD=1 \
SAFETENSORS_FAST_GPU=1 \
PORT=80 \
- CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
-SHELL ["/bin/bash", "-c"]
-
RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \
@@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
# Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
-CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output
\ No newline at end of file
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]
\ No newline at end of file
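With `text-generation-launcher` as the image ENTRYPOINT and `["--json-output"]` as the default CMD, anything appended after the image name in `docker run` is forwarded straight to the launcher. A minimal sketch of such an invocation driven from Python via `subprocess`; the image tag, port mapping, and shard count mirror the README example further down in this diff and are only illustrative:

```python
import subprocess

model = "bigscience/bloom-560m"  # matches the MODEL_ID default baked into the image
num_shard = "2"                  # illustrative shard count
volume = "./data"                # mounted at /data, the new HUGGINGFACE_HUB_CACHE

# Everything after the image name overrides the default CMD and is passed to
# text-generation-launcher as CLI flags.
subprocess.run(
    [
        "docker", "run", "--gpus", "all",
        "-p", "8080:80",
        "-v", f"{volume}:/data",
        "ghcr.io/huggingface/text-generation-inference:latest",
        "--model-id", model,
        "--num-shard", num_shard,
        "--json-output",
    ],
    check=True,
)
```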
diff --git a/Makefile b/Makefile
index d427ff87..39017944 100644
--- a/Makefile
+++ b/Makefile
@@ -16,16 +16,16 @@ router-dev:
cd router && cargo run
run-bloom-560m:
- text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
+ text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
run-bloom-560m-quantize:
- text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
+ text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
download-bloom:
text-generation-server download-weights bigscience/bloom
run-bloom:
- text-generation-launcher --model-name bigscience/bloom --num-shard 8
+ text-generation-launcher --model-id bigscience/bloom --num-shard 8
run-bloom-quantize:
- text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
+ text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
diff --git a/README.md b/README.md
index d092781a..74d7a988 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,42 @@
+
# Text Generation Inference
-
+
-A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Bloom, BloomZ and MT0-XXL api-inference widgets.
+A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power LLM api-inference widgets.
+## Table of contents
+
+- [Features](#features)
+- [Officially Supported Models](#officially-supported-models)
+- [Get Started](#get-started)
+ - [Docker](#docker)
+ - [Local Install](#local-install)
+ - [OpenAPI](#api-documentation)
+ - [CUDA Kernels](#cuda-kernels)
+- [Run BLOOM](#run-bloom)
+ - [Download](#download)
+ - [Run](#run)
+ - [Quantization](#quantization)
+- [Develop](#develop)
+
## Features
+- Token streaming using Server-Sent Events (SSE)
- [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
@@ -36,30 +62,63 @@ or
`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
-## Load Tests for BLOOM
+## Get started
-See `k6/load_test.js`
+### Docker
-| | avg | min | med | max | p(90) | p(95) | RPS |
-|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
-| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
-
-## Install
+The easiest way to get started is with the official Docker container:
```shell
-make install
+model=bigscience/bloom-560m
+num_shard=2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
```
-## Run
-
-### BLOOM 560-m
+You can then query the model using either the `/generate` or `/generate_stream` routes:
```shell
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+ -H 'Content-Type: application/json'
+```
+
+```shell
+curl 127.0.0.1:8080/generate_stream \
+ -X POST \
+ -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+ -H 'Content-Type: application/json'
+```
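
A rough Python counterpart to the curl examples above, assuming the `requests` package and the 8080→80 port mapping from the Docker command; the streaming loop just prints raw Server-Sent Events lines without assuming a particular payload schema:

```python
import json
import requests

base_url = "http://127.0.0.1:8080"
payload = {"inputs": "Testing API", "parameters": {"max_new_tokens": 9}}

# One-shot generation via /generate
response = requests.post(f"{base_url}/generate", json=payload)
response.raise_for_status()
print(json.dumps(response.json(), indent=2))

# Token streaming via /generate_stream (Server-Sent Events)
with requests.post(f"{base_url}/generate_stream", json=payload, stream=True) as stream:
    stream.raise_for_status()
    for line in stream.iter_lines():
        if line:  # skip SSE keep-alive blank lines
            print(line.decode("utf-8"))
```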
+
+To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally. You will need to have cargo and Python installed on your
+machine.
+
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformers fork with CUDA kernels
make run-bloom-560m
```
-### BLOOM
+### CUDA Kernels
+
+The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can skip
+building them by setting the `BUILD_EXTENSIONS=False` environment variable.
+
+Be aware that the official Docker image has them enabled by default.
+
+## Run BLOOM
+
+### Download
First you need to download the weights:
@@ -67,26 +126,20 @@ First you need to download the weights:
make download-bloom
```
+### Run
+
```shell
make run-bloom # Requires 8xA100 80GB
```
+### Quantization
+
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell
make run-bloom-quantize # Requires 8xA100 40GB
```
-## Test
-
-```shell
-curl 127.0.0.1:3000/generate \
- -v \
- -X POST \
- -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
- -H 'Content-Type: application/json'
-```
-
## Develop
```shell
diff --git a/aml/deployment.yaml b/aml/deployment.yaml
index 67690722..16ef3dc7 100644
--- a/aml/deployment.yaml
+++ b/aml/deployment.yaml
@@ -4,9 +4,9 @@ endpoint_name: bloom-inference
model: azureml:bloom:1
model_mount_path: /var/azureml-model
environment_variables:
- MODEL_BASE_PATH: /var/azureml-model/bloom
+ HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
MODEL_ID: bigscience/bloom
- NUM_GPUS: 8
+ NUM_SHARD: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
inference_config:
diff --git a/assets/architecture.jpg b/assets/architecture.jpg
index e0a5f7c6..c4a511c9 100644
Binary files a/assets/architecture.jpg and b/assets/architecture.jpg differ
diff --git a/docs/index.html b/docs/index.html
index e00e7446..16d143d8 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -3,7 +3,7 @@
- My New API
+ Text Generation Inference API
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index dea6fcc8..3df6e911 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
- model_name: String,
+ model_id: String,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
@@ -49,7 +49,7 @@ struct Args {
fn main() -> ExitCode {
// Pattern match configuration
let Args {
- model_name,
+ model_id,
revision,
num_shard,
quantize,
@@ -92,7 +92,7 @@ fn main() -> ExitCode {
// Start shard processes
for rank in 0..num_shard {
- let model_name = model_name.clone();
+ let model_id = model_id.clone();
let revision = revision.clone();
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
@@ -101,7 +101,7 @@ fn main() -> ExitCode {
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
- model_name,
+ model_id,
revision,
quantize,
uds_path,
@@ -167,7 +167,7 @@ fn main() -> ExitCode {
"--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path),
"--tokenizer-name".to_string(),
- model_name,
+ model_id,
];
if json_output {
@@ -256,7 +256,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
- model_name: String,
+ model_id: String,
revision: Option<String>,
quantize: bool,
uds_path: String,
@@ -278,7 +278,7 @@ fn shard_manager(
let mut shard_argv = vec![
"text-generation-server".to_string(),
"serve".to_string(),
- model_name,
+ model_id,
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs
index 9336f36e..b70b1628 100644
--- a/launcher/tests/integration_tests.rs
+++ b/launcher/tests/integration_tests.rs
@@ -29,11 +29,11 @@ struct GeneratedText {
details: Details,
}
-fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
+fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
- "--model-name".to_string(),
- model_name.clone(),
+ "--model-id".to_string(),
+ model_id.clone(),
"--num-shard".to_string(),
num_shard.to_string(),
"--port".to_string(),
@@ -75,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
launcher.terminate().unwrap();
launcher.wait().unwrap();
- panic!("failed to launch {}", model_name)
+ panic!("failed to launch {}", model_id)
}
fn test_model(
- model_name: String,
+ model_id: String,
num_shard: usize,
port: usize,
master_port: usize,
) -> GeneratedText {
- let mut launcher = start_launcher(model_name, num_shard, port, master_port);
+ let mut launcher = start_launcher(model_id, num_shard, port, master_port);
let data = r#"
{
diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py
index b133cb0a..1418a803 100644
--- a/server/text_generation/cli.py
+++ b/server/text_generation/cli.py
@@ -13,7 +13,7 @@ app = typer.Typer()
@app.command()
def serve(
- model_name: str,
+ model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
@@ -46,16 +46,16 @@ def serve(
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
- server.serve(model_name, revision, sharded, quantize, uds_path)
+ server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command()
def download_weights(
- model_name: str,
+ model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
):
- utils.download_weights(model_name, revision, extension)
+ utils.download_weights(model_id, revision, extension)
if __name__ == "__main__":
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 15d8e97e..943d45e9 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True
def get_model(
- model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+ model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
- config = AutoConfig.from_pretrained(model_name, revision=revision)
+ config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom":
if sharded:
- return BLOOMSharded(model_name, revision, quantize=quantize)
+ return BLOOMSharded(model_id, revision, quantize=quantize)
else:
- return BLOOM(model_name, revision, quantize=quantize)
+ return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox":
if sharded:
- return GPTNeoxSharded(model_name, revision, quantize=quantize)
+ return GPTNeoxSharded(model_id, revision, quantize=quantize)
else:
- return GPTNeox(model_name, revision, quantize=quantize)
- elif model_name.startswith("facebook/galactica"):
+ return GPTNeox(model_id, revision, quantize=quantize)
+ elif model_id.startswith("facebook/galactica"):
if sharded:
- return GalacticaSharded(model_name, revision, quantize=quantize)
+ return GalacticaSharded(model_id, revision, quantize=quantize)
else:
- return Galactica(model_name, revision, quantize=quantize)
- elif "santacoder" in model_name:
- return SantaCoder(model_name, revision, quantize)
+ return Galactica(model_id, revision, quantize=quantize)
+ elif "santacoder" in model_id:
+ return SantaCoder(model_id, revision, quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
- return CausalLM(model_name, revision, quantize=quantize)
+ return CausalLM(model_id, revision, quantize=quantize)
except Exception:
- return Seq2SeqLM(model_name, revision, quantize=quantize)
+ return Seq2SeqLM(model_id, revision, quantize=quantize)
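A small usage sketch of the renamed dispatcher, assuming it is called the same way `server.serve` calls it; the model id is just the default used elsewhere in this diff:

```python
from text_generation.models import get_model

# Dispatches on config.model_type (bloom, gpt_neox, ...), on the "facebook/galactica"
# and "santacoder" id patterns, and otherwise falls back to CausalLM / Seq2SeqLM.
model = get_model(
    model_id="bigscience/bloom-560m",
    revision=None,
    sharded=False,
    quantize=False,
)
print(type(model).__name__)  # e.g. BLOOM
```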
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 4f55afc0..992d7b5b 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -57,10 +57,10 @@ class BLOOM(CausalLM):
class BLOOMSharded(BLOOM):
def __init__(
- self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+ self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
- if not model_name.startswith("bigscience/bloom"):
- raise ValueError(f"Model {model_name} is not supported")
+ if not model_id.startswith("bigscience/bloom"):
+ raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
config = AutoConfig.from_pretrained(
- model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+ model_id, revision=revision, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
# Only download weights for small models
- if self.master and model_name == "bigscience/bloom-560m":
- download_weights(model_name, revision=revision, extension=".safetensors")
+ if self.master and model_id == "bigscience/bloom-560m":
+ download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
- filenames = weight_files(
- model_name, revision=revision, extension=".safetensors"
- )
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 1d1945cd..f21423ea 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
class CausalLM(Model):
- def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+ def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -244,10 +244,10 @@ class CausalLM(Model):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
self.model = AutoModelForCausalLM.from_pretrained(
- model_name,
+ model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
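Stripped of the batching machinery, the loading pattern `CausalLM.__init__` applies to `model_id` looks roughly like the standalone snippet below; the `load_in_8bit` flag and `.eval()` call are assumed from the sibling classes in this diff, and the exact kwargs may differ:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-560m"
quantize = False

# bfloat16 on capable GPUs, float32 otherwise -- same choice as the class above
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map="auto" if torch.cuda.is_available() else None,
    load_in_8bit=quantize,
).eval()
```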
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index be9b1699..f1dc8a30 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -149,10 +149,10 @@ class Galactica(CausalLM):
class GalacticaSharded(Galactica):
def __init__(
- self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+ self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
- if not model_name.startswith("facebook/galactica"):
- raise ValueError(f"Model {model_name} is not supported")
+ if not model_id.startswith("facebook/galactica"):
+ raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
config = AutoConfig.from_pretrained(
- model_name, revision=revision, tp_parallel=True
+ model_id, revision=revision, tp_parallel=True
)
tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
- if self.master and model_name == "facebook/galactica-125m":
- download_weights(model_name, revision=revision, extension=".safetensors")
+ if self.master and model_id == "facebook/galactica-125m":
+ download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
- filenames = weight_files(
- model_name, revision=revision, extension=".safetensors"
- )
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation/models/gpt_neox.py
index a8f7f365..2d467f4c 100644
--- a/server/text_generation/models/gpt_neox.py
+++ b/server/text_generation/models/gpt_neox.py
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
class GPTNeoxSharded(GPTNeox):
def __init__(
- self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+ self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
- model_name, revision=revision, tp_parallel=True
+ model_id, revision=revision, tp_parallel=True
)
# Only master download weights
if self.master:
- download_weights(model_name, revision=revision, extension=".safetensors")
+ download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
- filenames = weight_files(
- model_name, revision=revision, extension=".safetensors"
- )
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
diff --git a/server/text_generation/models/santacoder.py b/server/text_generation/models/santacoder.py
index 6c1a250f..fb496197 100644
--- a/server/text_generation/models/santacoder.py
+++ b/server/text_generation/models/santacoder.py
@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
- def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+ def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
tokenizer.add_special_tokens(
{
@@ -43,7 +43,7 @@ class SantaCoder(CausalLM):
self.model = (
AutoModelForCausalLM.from_pretrained(
- model_name,
+ model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize,
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 1ae266d8..27cbe1c0 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
class Seq2SeqLM(Model):
- def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+ def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
dtype = torch.float32
self.model = AutoModelForSeq2SeqLM.from_pretrained(
- model_name,
+ model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
).eval()
tokenizer = AutoTokenizer.from_pretrained(
- model_name, revision=revision, padding_side="left"
+ model_id, revision=revision, padding_side="left"
)
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
diff --git a/server/text_generation/server.py b/server/text_generation/server.py
index 852deebf..a8a9da6c 100644
--- a/server/text_generation/server.py
+++ b/server/text_generation/server.py
@@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
- model_name: str,
+ model_id: str,
revision: Optional[str],
sharded: bool,
quantize: bool,
uds_path: Path,
):
async def serve_inner(
- model_name: str,
+ model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: bool = False,
@@ -89,7 +89,7 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
- model = get_model(model_name, revision, sharded, quantize)
+ model = get_model(model_id, revision, sharded, quantize)
server = aio.server(interceptors=[ExceptionInterceptor()])
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -109,4 +109,4 @@ def serve(
logger.info("Signal received. Shutting down")
await server.stop(0)
- asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+ asyncio.run(serve_inner(model_id, revision, sharded, quantize))
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index ea97ed4a..83458969 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -182,20 +182,20 @@ def initialize_torch_distributed():
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
- info = api.model_info(model_name, revision=revision)
+ info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
- object_id = model_name.replace("/", "--")
+ object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
@@ -230,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
return str(cached_file) if cached_file.is_file() else None
-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
- filenames = weight_hub_files(model_name, revision, extension)
+ filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
- model_name, revision=revision, filename=filename
+ model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
- f"File {filename} of model {model_name} not found in "
+ f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
- f"Please run `text-generation-server download-weights {model_name}` first."
+ f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
- filenames = weight_hub_files(model_name, revision, extension)
+ filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
- repo_id=model_name,
+ repo_id=model_id,
local_files_only=False,
)
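To make the renamed helpers concrete, a small sketch of the cache layout `try_to_load_from_cache` resolves against; only the `models--org--name` mapping and the `/data` default come from this diff, while the snapshot layout and filename are assumptions based on `huggingface_hub`:

```python
from pathlib import Path

hub_cache = Path("/data")           # HUGGINGFACE_HUB_CACHE as set in the Dockerfile
model_id = "bigscience/bloom-560m"

# model ids map to a models--org--name directory inside the hub cache
repo_cache = hub_cache / f"models--{model_id.replace('/', '--')}"
print(repo_cache)                   # /data/models--bigscience--bloom-560m

# revision defaults to "main" when not given, and weight files are looked up per
# filename returned by weight_hub_files(); a typical cached path (layout assumed
# from huggingface_hub) looks like:
#   /data/models--bigscience--bloom-560m/snapshots/<commit-sha>/<file>.safetensors
```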