mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Merge branch 'main' into feat/prefix_chunking
This commit is contained in:
commit
5c8c5ac81a
6
.github/workflows/build.yaml
vendored
6
.github/workflows/build.yaml
vendored
@ -75,10 +75,10 @@ jobs:
|
|||||||
export label_extension="-intel-cpu"
|
export label_extension="-intel-cpu"
|
||||||
export docker_devices="none"
|
export docker_devices="none"
|
||||||
export docker_volume="/mnt/cache"
|
export docker_volume="/mnt/cache"
|
||||||
export runs_on="ubuntu-latest"
|
# export runs_on="ubuntu-latest"
|
||||||
# export runs_on="aws-highmemory-32-plus-priv"
|
export runs_on="aws-highmemory-32-plus-priv"
|
||||||
export platform="cpu"
|
export platform="cpu"
|
||||||
export extra_pytest="-k test_flash_llama_load"
|
export extra_pytest="-k test_flash_gemma_simple"
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
echo $dockerfile
|
echo $dockerfile
|
||||||
|
@ -112,6 +112,8 @@ ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/
|
|||||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||||
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
|
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
|
||||||
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
|
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
|
||||||
|
ENV TORCH_LLM_ALLREDUCE=1
|
||||||
|
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
|
||||||
|
|
||||||
# Install benchmarker
|
# Install benchmarker
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
@ -128,12 +130,22 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||||||
curl \
|
curl \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
make \
|
make \
|
||||||
g++ \
|
g++-12 \
|
||||||
|
gcc-12 \
|
||||||
git \
|
git \
|
||||||
wget \
|
wget \
|
||||||
cmake \
|
cmake \
|
||||||
libnuma-dev
|
libnuma-dev
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
|
||||||
|
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
|
||||||
|
RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
|
||||||
|
RUN update-alternatives --set cc /usr/bin/gcc
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
|
||||||
|
RUN update-alternatives --set c++ /usr/bin/g++
|
||||||
|
|
||||||
|
|
||||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
PORT=80
|
PORT=80
|
||||||
@ -165,16 +177,17 @@ RUN case ${TARGETPLATFORM} in \
|
|||||||
|
|
||||||
RUN conda install -c conda-forge gperftools mkl
|
RUN conda install -c conda-forge gperftools mkl
|
||||||
|
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
|
|
||||||
RUN pip install triton py-libnuma
|
RUN pip install triton py-libnuma
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
|
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
|
||||||
|
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
|
||||||
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
|
|
||||||
|
|
||||||
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
||||||
|
|
||||||
|
@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \
|
|||||||
|
|
||||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||||
|
|
||||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
|
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above.
|
||||||
|
|
||||||
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
||||||
```
|
```
|
||||||
@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
|||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
token=<your cli READ token>
|
token=<your cli READ token>
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
|
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
### A note on Shared Memory (shm)
|
### A note on Shared Memory (shm)
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
title: Text Generation Inference
|
title: Text Generation Inference
|
||||||
- local: quicktour
|
- local: quicktour
|
||||||
title: Quick Tour
|
title: Quick Tour
|
||||||
|
- local: supported_models
|
||||||
|
title: Supported Models
|
||||||
- local: installation_nvidia
|
- local: installation_nvidia
|
||||||
title: Using TGI with Nvidia GPUs
|
title: Using TGI with Nvidia GPUs
|
||||||
- local: installation_amd
|
- local: installation_amd
|
||||||
@ -15,8 +17,7 @@
|
|||||||
title: Using TGI with Intel GPUs
|
title: Using TGI with Intel GPUs
|
||||||
- local: installation
|
- local: installation
|
||||||
title: Installation from source
|
title: Installation from source
|
||||||
- local: supported_models
|
|
||||||
title: Supported Models and Hardware
|
|
||||||
- local: architecture
|
- local: architecture
|
||||||
title: Internal Architecture
|
title: Internal Architecture
|
||||||
- local: usage_statistics
|
- local: usage_statistics
|
||||||
|
@ -19,6 +19,6 @@ docker run --gpus all \
|
|||||||
--shm-size 1g \
|
--shm-size 1g \
|
||||||
-e HF_TOKEN=$token \
|
-e HF_TOKEN=$token \
|
||||||
-p 8080:80 \
|
-p 8080:80 \
|
||||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
|
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
|
|||||||
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
|
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes
|
||||||
```
|
```
|
||||||
|
|
||||||
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
|
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
|
||||||
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
|
|||||||
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
|
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4
|
||||||
```
|
```
|
||||||
|
|
||||||
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
|
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
|
||||||
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
|
|||||||
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
|
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
|
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
|
||||||
|
@ -17,8 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
If you want to serve gated or private models, which provide
|
If you want to serve gated or private models, please refer to
|
||||||
controlled access to sensitive or proprietary content, refer to
|
|
||||||
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
|
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
|
||||||
for detailed instructions.
|
for detailed instructions.
|
||||||
|
|
||||||
@ -97,7 +96,7 @@ curl 127.0.0.1:8080/generate \
|
|||||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
|
docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help
|
||||||
```
|
```
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
|
|
||||||
# Supported Models and Hardware
|
# Supported Models
|
||||||
|
|
||||||
Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.
|
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
|
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
|
||||||
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
|
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
|
||||||
@ -38,6 +36,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
|||||||
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
|
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -575,7 +575,10 @@ def launcher(event_loop):
|
|||||||
print(container_output, file=sys.stderr)
|
print(container_output, file=sys.stderr)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
container.remove()
|
try:
|
||||||
|
container.remove()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if DOCKER_IMAGE is not None:
|
if DOCKER_IMAGE is not None:
|
||||||
return docker_launcher
|
return docker_launcher
|
||||||
|
@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle):
|
|||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_gemma(flash_gemma, response_snapshot):
|
async def test_flash_gemma_simple(flash_gemma, response_snapshot):
|
||||||
response = await flash_gemma.generate(
|
response = await flash_gemma.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama(flash_llama, response_snapshot):
|
async def test_flash_llama_simple(flash_llama, response_snapshot):
|
||||||
response = await flash_llama.generate(
|
response = await flash_llama.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
)
|
)
|
||||||
|
@ -94,7 +94,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||||||
prefix_caching = Some("0".to_string());
|
prefix_caching = Some("0".to_string());
|
||||||
}
|
}
|
||||||
match config.model_type.as_deref() {
|
match config.model_type.as_deref() {
|
||||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
Some("falcon") | Some("deepseek_v2") => {
|
||||||
// Required because gemma2 needs bfloat16 which is not supported by
|
// Required because gemma2 needs bfloat16 which is not supported by
|
||||||
// flashinfer ?
|
// flashinfer ?
|
||||||
if attention.is_none() {
|
if attention.is_none() {
|
||||||
@ -944,17 +944,19 @@ fn shard_manager(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
// We read stdin in another thread as it seems that lines() can block in some cases
|
// We read stdin in another thread as it seems that lines() can block in some cases
|
||||||
thread::spawn(move || {
|
if LevelFilter::current() >= tracing::Level::DEBUG {
|
||||||
let mut stdin = io::stdin(); // We get `Stdin` here.
|
thread::spawn(move || {
|
||||||
loop {
|
let mut stdin = io::stdin(); // We get `Stdin` here.
|
||||||
let mut buffer = vec![0; 4096];
|
loop {
|
||||||
if let Ok(n) = stdin.read(&mut buffer) {
|
let mut buffer = vec![0; 4096];
|
||||||
if n > 0 {
|
if let Ok(n) = stdin.read(&mut buffer) {
|
||||||
let _ = pstdin.write_all(&buffer[..n]);
|
if n > 0 {
|
||||||
|
let _ = pstdin.write_all(&buffer[..n]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
});
|
}
|
||||||
|
|
||||||
let mut ready = false;
|
let mut ready = false;
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
use crate::infer::Infer;
|
use crate::infer::Infer;
|
||||||
use crate::server::{generate_internal, ComputeType};
|
use crate::server::{generate_internal, ComputeType};
|
||||||
use crate::{
|
use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
|
||||||
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
|
|
||||||
StreamOptions, Tool, ToolChoice,
|
|
||||||
};
|
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
@ -21,162 +18,12 @@ pub(crate) struct GenerateVertexInstance {
|
|||||||
pub parameters: Option<GenerateParameters>,
|
pub parameters: Option<GenerateParameters>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
|
||||||
pub(crate) struct VertexChat {
|
|
||||||
messages: Vec<Message>,
|
|
||||||
// Messages is ignored there.
|
|
||||||
#[serde(default)]
|
|
||||||
parameters: VertexParameters,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
|
||||||
pub(crate) struct VertexParameters {
|
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
|
||||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
|
||||||
pub model: Option<String>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
|
||||||
/// decreasing the model's likelihood to repeat the same line verbatim.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "1.0")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// UNUSED
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
|
||||||
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
|
||||||
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
|
||||||
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
|
||||||
/// result in a ban or exclusive selection of the relevant token.
|
|
||||||
#[serde(default)]
|
|
||||||
pub logit_bias: Option<Vec<f32>>,
|
|
||||||
|
|
||||||
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
|
||||||
/// output token returned in the content of message.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "false")]
|
|
||||||
pub logprobs: Option<bool>,
|
|
||||||
|
|
||||||
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
|
||||||
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "5")]
|
|
||||||
pub top_logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens that can be generated in the chat completion.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "32")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// UNUSED
|
|
||||||
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
|
||||||
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "2")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
|
||||||
/// increasing the model's likelihood to talk about new topics
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 0.1)]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub stop: Option<Vec<String>>,
|
|
||||||
|
|
||||||
#[serde(default = "bool::default")]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
#[schema(nullable = true, example = 42)]
|
|
||||||
pub seed: Option<u64>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
|
|
||||||
/// lower values like 0.2 will make it more focused and deterministic.
|
|
||||||
///
|
|
||||||
/// We generally recommend altering this or `top_p` but not both.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 1.0)]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
|
|
||||||
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 0.95)]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
|
||||||
/// functions the model may generate JSON inputs for.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub tools: Option<Vec<Tool>>,
|
|
||||||
|
|
||||||
/// A prompt to be appended before the tools
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(
|
|
||||||
nullable = true,
|
|
||||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
|
||||||
)]
|
|
||||||
pub tool_prompt: Option<String>,
|
|
||||||
|
|
||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub tool_choice: ToolChoice,
|
|
||||||
|
|
||||||
/// Response format constraints for the generation.
|
|
||||||
///
|
|
||||||
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
|
||||||
pub response_format: Option<GrammarType>,
|
|
||||||
|
|
||||||
/// A guideline to be used in the chat_template
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
|
||||||
pub guideline: Option<String>,
|
|
||||||
|
|
||||||
/// Options for streaming response. Only set this when you set stream: true.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<VertexChat> for ChatRequest {
|
|
||||||
fn from(val: VertexChat) -> Self {
|
|
||||||
Self {
|
|
||||||
messages: val.messages,
|
|
||||||
frequency_penalty: val.parameters.frequency_penalty,
|
|
||||||
guideline: val.parameters.guideline,
|
|
||||||
logit_bias: val.parameters.logit_bias,
|
|
||||||
logprobs: val.parameters.logprobs,
|
|
||||||
max_tokens: val.parameters.max_tokens,
|
|
||||||
model: val.parameters.model,
|
|
||||||
n: val.parameters.n,
|
|
||||||
presence_penalty: val.parameters.presence_penalty,
|
|
||||||
response_format: val.parameters.response_format,
|
|
||||||
seed: val.parameters.seed,
|
|
||||||
stop: val.parameters.stop,
|
|
||||||
stream_options: val.parameters.stream_options,
|
|
||||||
stream: val.parameters.stream,
|
|
||||||
temperature: val.parameters.temperature,
|
|
||||||
tool_choice: val.parameters.tool_choice,
|
|
||||||
tool_prompt: val.parameters.tool_prompt,
|
|
||||||
tools: val.parameters.tools,
|
|
||||||
top_logprobs: val.parameters.top_logprobs,
|
|
||||||
top_p: val.parameters.top_p,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum VertexInstance {
|
pub(crate) enum VertexInstance {
|
||||||
Generate(GenerateVertexInstance),
|
Generate(GenerateVertexInstance),
|
||||||
Chat(VertexChat),
|
Chat(ChatRequest),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, ToSchema)]
|
#[derive(Deserialize, ToSchema)]
|
||||||
@ -257,9 +104,8 @@ pub(crate) async fn vertex_compatibility(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
VertexInstance::Chat(instance) => {
|
VertexInstance::Chat(instance) => {
|
||||||
let chat_request: ChatRequest = instance.into();
|
|
||||||
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
||||||
chat_request.try_into_generate(&infer)?;
|
instance.try_into_generate(&infer)?;
|
||||||
generate_request
|
generate_request
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -305,34 +151,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vertex_deserialization() {
|
fn vertex_deserialization() {
|
||||||
let string = serde_json::json!({
|
|
||||||
|
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
|
||||||
"parameters": {
|
|
||||||
"max_tokens": 128,
|
|
||||||
"top_p": 0.95,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
|
||||||
|
|
||||||
let string = serde_json::json!({
|
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
|
||||||
});
|
|
||||||
|
|
||||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
|
||||||
|
|
||||||
let string = serde_json::json!({
|
let string = serde_json::json!({
|
||||||
|
|
||||||
"instances": [
|
"instances": [
|
||||||
{
|
{
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||||
"parameters": {
|
"max_tokens": 128,
|
||||||
"max_tokens": 128,
|
"top_p": 0.95,
|
||||||
"top_p": 0.95,
|
"temperature": 0.7
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -341,18 +167,16 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
request,
|
request,
|
||||||
VertexRequest {
|
VertexRequest {
|
||||||
instances: vec![VertexInstance::Chat(VertexChat {
|
instances: vec![VertexInstance::Chat(ChatRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
},],
|
},],
|
||||||
parameters: VertexParameters {
|
max_tokens: Some(128),
|
||||||
max_tokens: Some(128),
|
top_p: Some(0.95),
|
||||||
top_p: Some(0.95),
|
temperature: Some(0.7),
|
||||||
temperature: Some(0.7),
|
..Default::default()
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
})]
|
})]
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -52,3 +52,53 @@ class Seqlen:
|
|||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
# Flash decoding doesn't need to clamp
|
# Flash decoding doesn't need to clamp
|
||||||
return self
|
return self
|
||||||
|
=======
|
||||||
|
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Seqlen:
|
||||||
|
input_lengths: torch.Tensor
|
||||||
|
prefix_lengths: torch.Tensor
|
||||||
|
cu_seqlen_q: Optional[torch.Tensor]
|
||||||
|
cu_seqlen_k: Optional[torch.Tensor]
|
||||||
|
max_q: int
|
||||||
|
max_k: int
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_lengths,
|
||||||
|
prefix_lengths,
|
||||||
|
cu_seqlen_q=None,
|
||||||
|
max_q=None,
|
||||||
|
max_k=None,
|
||||||
|
):
|
||||||
|
self.input_lengths = input_lengths
|
||||||
|
self.prefix_lengths = prefix_lengths
|
||||||
|
device = self.input_lengths.device
|
||||||
|
shape = self.input_lengths.shape
|
||||||
|
if cu_seqlen_q is None:
|
||||||
|
cu_seqlen_q = torch.arange(
|
||||||
|
shape[0] + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
max_q = 1
|
||||||
|
else:
|
||||||
|
assert max_q is not None
|
||||||
|
assert max_k is not None
|
||||||
|
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||||
|
# Although FA2 might not want the clamping
|
||||||
|
# cu_seqlen_k[0] = 0
|
||||||
|
total = self.input_lengths + self.prefix_lengths
|
||||||
|
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||||
|
|
||||||
|
self.cu_seqlen_q = cu_seqlen_q
|
||||||
|
self.cu_seqlen_k = cu_seqlen_k
|
||||||
|
self.max_q = max_q
|
||||||
|
self.max_k = max_k
|
||||||
|
|
||||||
|
def clamp(self, max):
|
||||||
|
# Flash decoding doesn't need to clamp
|
||||||
|
return self
|
@ -518,14 +518,13 @@ class CausalLM(Model):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = default_dtype if dtype is None else dtype
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
device = torch.device("cpu")
|
||||||
device = torch.device(f"xpu:{rank}")
|
# Float16 doesn't exist on target.
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
@ -594,8 +593,14 @@ class CausalLM(Model):
|
|||||||
if speculator:
|
if speculator:
|
||||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
|
device_count = 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device("xpu")
|
||||||
|
device_count = torch.xpu.device_count()
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
@ -615,20 +620,12 @@ class CausalLM(Model):
|
|||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=("auto" if device_count > 1 else None),
|
||||||
"auto"
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if (
|
if device_count == 1 and quantize != "bitsandbytes":
|
||||||
torch.cuda.is_available()
|
model = model.to(device)
|
||||||
and torch.cuda.device_count() == 1
|
|
||||||
and quantize != "bitsandbytes"
|
|
||||||
):
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
if model.config.pad_token_id is not None:
|
if model.config.pad_token_id is not None:
|
||||||
|
@ -559,14 +559,13 @@ class Seq2SeqLM(Model):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = default_dtype if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = default_dtype if dtype is None else dtype
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
device = torch.device("cpu")
|
||||||
device = torch.device(f"xpu:{rank}")
|
# Float16 doesn't exist on target.
|
||||||
dtype = default_dtype if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
@ -631,8 +630,14 @@ class Seq2SeqLM(Model):
|
|||||||
if speculator:
|
if speculator:
|
||||||
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
|
device_count = 0
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device("xpu")
|
||||||
|
device_count = torch.xpu.device_count()
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
if quantize:
|
||||||
@ -645,16 +650,12 @@ class Seq2SeqLM(Model):
|
|||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=("auto" if device_count > 1 else None),
|
||||||
"auto"
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
if device_count == 1:
|
||||||
model = model.cuda()
|
model = model.to(device)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -66,6 +66,11 @@ elif is_ipex_available():
|
|||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = noop
|
||||||
get_free_memory = get_cpu_free_memory
|
get_free_memory = get_cpu_free_memory
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
SYSTEM = "xpu"
|
||||||
|
empty_cache = torch.xpu.empty_cache
|
||||||
|
synchronize = torch.xpu.synchronize
|
||||||
|
get_free_memory = get_xpu_free_memory
|
||||||
else:
|
else:
|
||||||
SYSTEM = "cpu"
|
SYSTEM = "cpu"
|
||||||
|
|
||||||
|
@ -5,14 +5,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
TEMPLATE = """
|
TEMPLATE = """
|
||||||
# Supported Models and Hardware
|
# Supported Models
|
||||||
|
|
||||||
Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.
|
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
SUPPORTED_MODELS
|
SUPPORTED_MODELS
|
||||||
|
|
||||||
|
|
||||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
Loading…
Reference in New Issue
Block a user