mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00

refactor doc

This commit is contained in:
parent 5de40eb078
commit a7d15c38e8

Dockerfile (10 changed lines)
@@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \
     DEBIAN_FRONTEND=noninteractive \
-    MODEL_BASE_PATH=/data \
+    HUGGINGFACE_HUB_CACHE=/data \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
-    NUM_GPUS=1 \
+    NUM_SHARD=1 \
     SAFETENSORS_FAST_GPU=1 \
     PORT=80 \
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
     NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
     LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
     CONDA_DEFAULT_ENV=text-generation \
     PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin

-SHELL ["/bin/bash", "-c"]
-
 RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*

 RUN cd ~ && \
@@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
 # Install launcher
 COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher

-CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]
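The last hunk replaces the shell-form `CMD` (which expanded `$MODEL_BASE_PATH`, `$NUM_GPUS` and `$MODEL_ID` itself) with an exec-form `ENTRYPOINT`/`CMD` pair, so anything passed after the image name on `docker run` is handed to `text-generation-launcher` in place of the default `--json-output`, while the `ENV` defaults above still apply when no flags are given. A minimal usage sketch (the `latest` tag and the flag values are illustrative assumptions, not part of the diff):

```shell
# Use the ENV defaults baked into the image (MODEL_ID, NUM_SHARD, ...);
# this effectively runs `text-generation-launcher --json-output`
docker run --gpus all -p 8080:80 -v $PWD/data:/data ghcr.io/huggingface/text-generation-inference:latest

# Override the default CMD with explicit launcher flags
docker run --gpus all -p 8080:80 -v $PWD/data:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id bigscience/bloom-560m --num-shard 2 --json-output
```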
Makefile (8 changed lines)
@@ -16,16 +16,16 @@ router-dev:
 	cd router && cargo run

 run-bloom-560m:
-	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2

 run-bloom-560m-quantize:
-	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize

 download-bloom:
 	text-generation-server download-weights bigscience/bloom

 run-bloom:
-	text-generation-launcher --model-name bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8

 run-bloom-quantize:
-	text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
README.md (105 changed lines)
@@ -1,16 +1,42 @@
+<div align="center">

 # Text Generation Inference

-<div align="center">
+<a href="https://github.com/huggingface/text-generation-inference">
+  <img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social">
+</a>
+<a href="https://github.com/huggingface/text-generation-inference/blob/main/LICENSE">
+  <img alt="License" src="https://img.shields.io/github/license/huggingface/text-generation-inference">
+</a>
+<a href="https://huggingface.github.io/text-generation-inference">
+  <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
+</a>

 ![architecture](assets/architecture.jpg)

 </div>

-A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Bloom, BloomZ and MT0-XXL api-inference widgets.
+A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power LLMs api-inference widgets.

+## Table of contents
+
+- [Features](#features)
+- [Officially Supported Models](#officially-supported-models)
+- [Get Started](#get-started)
+  - [Docker](#docker)
+  - [Local Install](#local-install)
+  - [OpenAPI](#api-documentation)
+  - [CUDA Kernels](#cuda-kernels)
+- [Run BLOOM](#run-bloom)
+  - [Download](#download)
+  - [Run](#run)
+  - [Quantization](#quantization)
+- [Develop](#develop)

 ## Features

+- Token streaming using Server Side Events (SSE)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
@@ -36,30 +62,63 @@ or

 `AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`

-## Load Tests for BLOOM
+## Get started

-See `k6/load_test.js`
+### Docker

-| | avg | min | med | max | p(90) | p(95) | RPS |
-|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
-| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
-
-## Install
+The easiest way of getting started is using the official Docker container:

 ```shell
-make install
+model=bigscience/bloom-560m
+num_shard=2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
 ```

-## Run
+You can then query the model using either the `/generate` or `/generate_stream` routes:

-### BLOOM 560-m

 ```shell
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+```shell
+curl 127.0.0.1:8080/generate_stream \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally. You will need to have cargo and Python installed on your
+machine
+
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
 make run-bloom-560m
 ```

-### BLOOM
+### CUDA Kernels
+
+The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
+the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
+
+Be aware that the official Docker image has them enabled by default.
+
+## Run BLOOM
+
+### Download

 First you need to download the weights:
@@ -67,26 +126,20 @@ First you need to download the weights:
 make download-bloom
 ```

+### Run
+
 ```shell
 make run-bloom # Requires 8xA100 80GB
 ```

+### Quantization
+
 You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:

 ```shell
 make run-bloom-quantize # Requires 8xA100 40GB
 ```

-## Test
-
-```shell
-curl 127.0.0.1:3000/generate \
-    -v \
-    -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-    -H 'Content-Type: application/json'
-```

 ## Develop

 ```shell
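The new Get started and CUDA Kernels sections describe both install paths. A compact local-install sketch based on the Makefile targets above (the flag values are illustrative, not part of the diff):

```shell
# Install without building the custom CUDA kernels (they are only tested on A100s)
BUILD_EXTENSIONS=False make install

# Or build the kernels, then start a 2-shard bloom-560m via the Makefile target
BUILD_EXTENSIONS=True make install
make run-bloom-560m
```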
@@ -4,9 +4,9 @@ endpoint_name: bloom-inference
 model: azureml:bloom:1
 model_mount_path: /var/azureml-model
 environment_variables:
-  MODEL_BASE_PATH: /var/azureml-model/bloom
+  HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
   MODEL_ID: bigscience/bloom
-  NUM_GPUS: 8
+  NUM_SHARD: 8
 environment:
   image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
 inference_config:
Binary file not shown. (Before: 132 KiB, After: 334 KiB)
@@ -3,7 +3,7 @@
 <!-- Load the latest Swagger UI code and style from npm using unpkg.com -->
 <script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
 <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"/>
-<title>My New API</title>
+<title>Text Generation Inference API</title>
 </head>
 <body>
 <div id="swagger-ui"></div> <!-- Div to hold the UI component -->
@@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
 #[clap(author, version, about, long_about = None)]
 struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
-    model_name: String,
+    model_id: String,
     #[clap(long, env)]
     revision: Option<String>,
     #[clap(long, env)]
@@ -49,7 +49,7 @@ struct Args {
 fn main() -> ExitCode {
     // Pattern match configuration
     let Args {
-        model_name,
+        model_id,
         revision,
         num_shard,
         quantize,
@@ -92,7 +92,7 @@ fn main() -> ExitCode {

     // Start shard processes
     for rank in 0..num_shard {
-        let model_name = model_name.clone();
+        let model_id = model_id.clone();
         let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
@@ -101,7 +101,7 @@ fn main() -> ExitCode {
         let shutdown_sender = shutdown_sender.clone();
         thread::spawn(move || {
             shard_manager(
-                model_name,
+                model_id,
                 revision,
                 quantize,
                 uds_path,
@@ -167,7 +167,7 @@ fn main() -> ExitCode {
         "--master-shard-uds-path".to_string(),
         format!("{}-0", shard_uds_path),
         "--tokenizer-name".to_string(),
-        model_name,
+        model_id,
     ];

     if json_output {
@@ -256,7 +256,7 @@ enum ShardStatus {

 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
-    model_name: String,
+    model_id: String,
     revision: Option<String>,
     quantize: bool,
     uds_path: String,
@@ -278,7 +278,7 @@ fn shard_manager(
     let mut shard_argv = vec![
         "text-generation-server".to_string(),
         "serve".to_string(),
-        model_name,
+        model_id,
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
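Because the clap attributes keep `env` alongside `long`, the renamed option can also be supplied as an environment variable, which is how the Docker image's `ENV MODEL_ID=...` and `NUM_SHARD=...` defaults reach the launcher. A hedged sketch (assuming `num_shard` is declared with `env` as well, as the Dockerfile implies):

```shell
# Equivalent to: text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --json-output
MODEL_ID=bigscience/bloom-560m NUM_SHARD=2 text-generation-launcher --json-output
```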
@@ -29,11 +29,11 @@ struct GeneratedText {
     details: Details,
 }

-fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
+fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
-        "--model-name".to_string(),
-        model_name.clone(),
+        "--model-id".to_string(),
+        model_id.clone(),
         "--num-shard".to_string(),
         num_shard.to_string(),
         "--port".to_string(),
@@ -75,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port

     launcher.terminate().unwrap();
     launcher.wait().unwrap();
-    panic!("failed to launch {}", model_name)
+    panic!("failed to launch {}", model_id)
 }

 fn test_model(
-    model_name: String,
+    model_id: String,
     num_shard: usize,
     port: usize,
     master_port: usize,
 ) -> GeneratedText {
-    let mut launcher = start_launcher(model_name, num_shard, port, master_port);
+    let mut launcher = start_launcher(model_id, num_shard, port, master_port);

     let data = r#"
     {
@@ -13,7 +13,7 @@ app = typer.Typer()

 @app.command()
 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
@@ -46,16 +46,16 @@ def serve(
         os.getenv("MASTER_PORT", None) is not None
     ), "MASTER_PORT must be set when sharded is True"

-    server.serve(model_name, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, uds_path)


 @app.command()
 def download_weights(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, revision, extension)
+    utils.download_weights(model_id, revision, extension)


 if __name__ == "__main__":
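The Python CLI mirrors the rename: `model_id` is now the positional argument of both `serve` and `download-weights`. A usage sketch (the `--uds-path` value is an assumption for illustration):

```shell
# Fetch safetensors weights for a model
text-generation-server download-weights bigscience/bloom-560m

# Serve a single, non-sharded shard over a unix socket, as the launcher does internally
text-generation-server serve bigscience/bloom-560m --uds-path /tmp/text-generation
```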
@@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True


 def get_model(
-    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name, revision=revision)
+    config = AutoConfig.from_pretrained(model_id, revision=revision)

     if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, revision, quantize=quantize)
+            return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, revision, quantize=quantize)
+            return BLOOM(model_id, revision, quantize=quantize)
     elif config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_name, revision, quantize=quantize)
+            return GPTNeox(model_id, revision, quantize=quantize)
-    elif model_name.startswith("facebook/galactica"):
+    elif model_id.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, revision, quantize=quantize)
+            return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
-            return Galactica(model_name, revision, quantize=quantize)
+            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_name:
-        return SantaCoder(model_name, revision, quantize)
+    elif "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, revision, quantize=quantize)
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
@@ -57,10 +57,10 @@ class BLOOM(CausalLM):

 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3

         # Only download weights for small models
-        if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "bigscience/bloom-560m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
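As the `# Only download weights for small models` comment shows, the sharded path only pre-downloads `bigscience/bloom-560m`; for larger models the weights must already be in the local cache, matching the Makefile targets above. A sketch (flag values are illustrative):

```shell
# Download the weights once, then launch the sharded model
text-generation-server download-weights bigscience/bloom
text-generation-launcher --model-id bigscience/bloom --num-shard 8
```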
@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):


 class CausalLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -244,10 +244,10 @@ class CausalLM(Model):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
@@ -149,10 +149,10 @@ class Galactica(CausalLM):

 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("facebook/galactica"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )
         tokenizer.pad_token_id = config.pad_token_id

         # Only download weights for small models
-        if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "facebook/galactica-125m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):

 class GPTNeoxSharded(GPTNeox):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )

         # Only master download weights
         if self.master:
-            download_weights(model_name, revision=revision, extension=".safetensors")
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"


 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.add_special_tokens(
             {
@@ -43,7 +43,7 @@ class SantaCoder(CausalLM):

         self.model = (
             AutoModelForCausalLM.from_pretrained(
-                model_name,
+                model_id,
                 revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
|
@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqLM(Model):
|
class Seq2SeqLM(Model):
|
||||||
def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
|
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||||
@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
|
|||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
model_name,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map="auto" if torch.cuda.is_available() else None,
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
load_in_8bit=quantize,
|
load_in_8bit=quantize,
|
||||||
).eval()
|
).eval()
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_name, revision=revision, padding_side="left"
|
model_id, revision=revision, padding_side="left"
|
||||||
)
|
)
|
||||||
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
|
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
|
||||||
|
|
||||||
|
@@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):


 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
-        model_name: str,
+        model_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
@@ -89,7 +89,7 @@ def serve(
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]

-        model = get_model(model_name, revision, sharded, quantize)
+        model = get_model(model_id, revision, sharded, quantize)

         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -109,4 +109,4 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
@@ -182,20 +182,20 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name, revision=revision)
+    info = api.model_info(model_id, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames


-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
     """Try to load a file from the Hugging Face cache"""
     if revision is None:
         revision = "main"

-    object_id = model_name.replace("/", "--")
+    object_id = model_id.replace("/", "--")
     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

     if not repo_cache.is_dir():
@@ -230,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
     return str(cached_file) if cached_file.is_file() else None


-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(
-            model_name, revision=revision, filename=filename
+            model_id, revision=revision, filename=filename
         )
         if cache_file is None:
             raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_name} not found in "
+                f"File {filename} of model {model_id} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_id}` first."
             )
         files.append(cache_file)

     return files


-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)

     download_function = partial(
         hf_hub_download,
-        repo_id=model_name,
+        repo_id=model_id,
         local_files_only=False,
     )
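These helpers resolve weights through the Hugging Face cache (or `WEIGHTS_CACHE_OVERRIDE` when set), so pointing `HUGGINGFACE_HUB_CACHE` at a shared volume, as the Dockerfile and AzureML config now do, lets downloads be reused across runs. A hedged sketch (the `/data` path is illustrative):

```shell
# Populate a shared cache, then serve from it without re-downloading
HUGGINGFACE_HUB_CACHE=/data text-generation-server download-weights bigscience/bloom-560m
HUGGINGFACE_HUB_CACHE=/data text-generation-launcher --model-id bigscience/bloom-560m --num-shard 1
```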