Compare commits

..

481 Commits
v2.2.0 ... main

Author SHA1 Message Date
Nicolas Patry
8f8819795f
Fixing CI (#3184) 2025-04-18 13:07:18 +02:00
Alvaro Bartolome
95ccba3705
Bump sccache to 0.10.0 (#3179)
* Ensure that `sccache` version is 0.10.0 or higher

* Rename `ACTIONS_CACHE_URL` to `ACTIONS_RESULTS_URL`
2025-04-18 12:45:32 +02:00
Hyeongchan Kim
b400c275e4
Get opentelemetry trace id from request headers instead of creating a new trace (#2648)
feature: get trace id from req headers

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-04-18 09:06:41 +02:00
Daniël de Kok
84ab88d843
Support flashinfer for Gemma3 prefill (#3167)
* launcher: ensure correct detection of Gemma 3 head size

* Support flashinfer for Gemma3 prefill

Gemma3 uses bidirectional attention for images. Flashinfer
supports custom masks. Hook up the mask with flashinfer, so that we do
not have to use the slower SDPA implementation for prefills with images.

* Update Gemma3 test outputs

* Fixed unused import
2025-04-17 18:07:41 +02:00
Nicolas Patry
4645678ff0
Hotfix gaudi2 with newer transformers. (#3176) 2025-04-15 12:39:28 +02:00
Nicolas Patry
ad765cd06b
Hotfixing gaudi deps. (#3174) 2025-04-15 11:55:28 +02:00
Nicolas Patry
16b4b7974a
Upgrading the dependencies in Gaudi backend. (#3170)
* Upgrading the dependencies in Gaudi backend.

* Upgrading transformers version.
2025-04-15 11:49:06 +02:00
Wang, Yi
459fbdebe3
transformers flash llm/vlm enabling in ipex (#3152)
* transformers flash llm/vlm enabling in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ipex cpu could also support in function

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-04-15 11:08:01 +02:00
Nicolas Patry
449cee49ca
setuptools <= 70.0 is vulnerable: CVE-2024-6345 (#3171) 2025-04-15 10:09:37 +02:00
Mohit Sharma
73e797528d
L4 fixes (#3161)
add fix
2025-04-14 22:13:53 +05:30
Nicolas Patry
fe56f760df
Upgrading the python client deps (still deprecated, but used for
integration-tests)
2025-04-14 17:18:43 +02:00
Wang, Yi
d62c941c56
Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)
* clean cuda/rocm code in hpu backend, enable flat_hpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix TP in pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust block table in hpu to improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable all the model. not testet yet

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use tensor cache in hpu graph to avoid replay issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix phimoe issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* gpt_bigcode could also go pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable dbrx remove some unused code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* multi-modality initial PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust warmup and enable vlm

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix incorrect output in qwen2 idefics if hpu graph is used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove unused quantization code and enable awq/gptq int4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix gptq issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable fp8

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup prefill

remove model where pageattn is not used, set block table to None since it's not used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add warmup_decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove block_tables and prefill_cache_indices which will lead to dynamic shape

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* missing gptj change...

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix some issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove torch.where to fix incorrect output in hpu graph model

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-04-14 15:58:13 +02:00
Nicolas Patry
9a8d0462e1
Fixing tokenization like https://github.com/huggingface/text-embeddin… (#3156)
Fixing tokenization like https://github.com/huggingface/text-embeddings-inference/issues/525
2025-04-09 18:42:25 +02:00
Nicolas Patry
5861da1ad7
Fixing Qwen 2.5 VL (32B). (#3157)
Reduce the config constraints, and use common ground between the 8B and
32B.
2025-04-09 17:07:30 +02:00
Nicolas Patry
0b28aabb94
3.2.3 (#3151) 2025-04-08 10:16:37 +02:00
oOraph
24bec29ffc
fix: compute type typo (#3150)
Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
Co-authored-by: Raphael Glon <oOraph@users.noreply.github.com>
2025-04-07 17:24:11 +02:00
Baptiste Colle
37104acd75
Gaudi: Add Integration Test for Gaudi Backend (#3142)
* feat(gaudi): add integration test

* feat(test): add more models to integration tests

* remove debug comments

* fix typos
2025-04-07 16:55:03 +02:00
Mohit Sharma
87a0af4ec2
Update transformers to 4.51 (#3148)
* update transformres

* Upgrading the nix deps too.

* Forcing torchvision to be in there.

* Fixing bug in mllama.

* Those tests cannot be run in CI.

* Lint.

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-04-07 12:55:43 +02:00
Mohit Sharma
9c26b52940
Use ROCM 6.3.1 (#3141)
* update dockerfile

* add updated makefile

* fix docker

* Lint.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-04-07 12:55:11 +02:00
Nicolas Patry
d23b385eee
Preparing for release. (#3147)
* Preparing for release.

* Adding hf-xet dependency.

* Merged tgi-nix update.
2025-04-06 11:36:00 +02:00
Mohit Sharma
d9bb9bebc9
Add llama4 (#3145)
* initial changes

* Add support for other vlm

* cleanup comment

* Improve attn_implementation

* Add comments for support of models

* add model

* add model

* fixes and improvements

* update docker

* Add cache position

* Add tests

* remove redundant changes

* remove tr version

* Upgrade doc + fix linting.

* Fixing the CI.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-04-06 10:20:22 +02:00
Yuan Wu
3d059f91ab
Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE (#3131)
* Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Remove debug modifications

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2025-04-03 10:34:53 +02:00
Corentin REGAL
0142550096
nix-v3.2.1 -> v3.2.1-nix (#3129)
make it easier to check for version using semver semantic (same major
and minor)
2025-03-26 15:36:43 +01:00
Yuan Wu
f5f14dc660
Gaudi: Fix llava-next and mllama crash issue (#3127)
Signed-off-by: yuanwu <yuan.wu@intel.com>
2025-03-25 15:08:15 +01:00
Nicolas Patry
54d15462dc
Torch 2.6 (#3134)
* Torch 2.6

* Upgrade the toolchain.

* Don't upgrade just yet.

* Upgrade toolchain.

* Time upgrade.

* TGI-nix main.

* Upgrade to transformers 4.50
2025-03-24 11:55:49 +01:00
Baptiste Colle
2e60a8dd65
CI: enable server tests for backends (#3128)
add test for backends
2025-03-20 16:07:31 +01:00
Erik Kaunismäki
e5503eba78
configurable termination timeout (#3126)
* make shard and webserver termination timeouts configurable

* Updating documentation.

* Fmt.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-03-20 14:25:56 +01:00
Nicolas Patry
e497bc09f6
Minor fixes. (#3125) 2025-03-18 15:42:35 +01:00
Nicolas Patry
67ce543e04
Intel docker. (#3121)
* Intel docker.

* torchaudio ?

* Fixing dockerfile ?
2025-03-18 15:12:11 +01:00
Nicolas Patry
83fe45c15e
Prepare for patch release. (#3124) 2025-03-18 15:11:55 +01:00
Nicolas Patry
11f2eec10e
Publish nix docker image. (#3122)
* Publish nix docker image.

* Run during PR.

* Something else.

* Forgot to push.

* Build zstd.

* Pushing with skopeo

* Testing the PR.

* Runnign from nix.

* Cleaner tags.
2025-03-18 12:58:21 +01:00
Mohit Sharma
a35fbdb925
Bug Fix: Sliding Window Attention (#3112)
* (fix) sliding window attention

* (fix) flashinfer

* (typo) collection link

* Add window_size_left param ipex rocm

* Update window size rocm flash decoding

* fix: bump snapshots and improve exceed window test case

* feat: add tests for image types and remove alpha from png

* Upgrading `from_env` to get token from file when necessary + fix
pali_gemma.

* fix: add pillow dependency and bump lock+requirements

* fix: bump org name in gemma3 test

* Fix qwen2.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-03-18 10:37:33 +01:00
Baptiste Colle
8c2c348f3c
Gaudi: Sync TGI with the latest changes from the TGI-Gaudi fork (#3117)
feat(gaudi): add all the changes from tgi-gaudi fork up to PR #289
2025-03-18 09:45:52 +01:00
Daniël de Kok
095775e05c
launcher: correctly get the head dimension for VLMs (#3116)
* launcher: correctly get the head dimension for VLMs

For most (?) VLMs, the head dimension is in the `text_config`
configuration section. However, since we only queried the top-level
`head_dim` (which typically doesn't exist in VLMs), we would never use
flashinfer. This change adds a method that gets the head dimension from
the top-level `Config` struct or `text_config` when that fails.

* fix: bump org name in gemma3 test

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
2025-03-17 18:19:37 +01:00
Wang, Yi
0b3e3db043
xpu 2.6 update (#3051)
* xpu 2.6 update

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* install whl

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update get xpu memory api

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* int

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix awq crash if modules_to_not_convert is None

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-03-17 13:48:48 +01:00
Daniël de Kok
f91434e99b
Make the Nix-based Docker container work on non-NixOS (#3109)
On NixOS, the CUDA driver shim gets mounted on /run/opengl-driver,
where Nix packages expect the shim to be. However, on other
distributions, some FHS paths are mounted. This is a small change
to make the dynamic loader find the shim.
2025-03-13 14:02:45 +01:00
Nicolas Patry
8b91f92978
Fixing the docker build. (#3108)
* Fixing the docker build.

* Apply suggestions from code review
2025-03-13 11:26:44 +01:00
Baptiste Colle
27ed848676
Release of Gaudi Backend for TGI (#3091)
* feat(gaudi): release ready (docs, docker image and vlm ready)

* fix(gaudi): add default argument for the dockerfile

* fix(gaudi): remove use of latest for gaudi docker image + redid gaudi benchmarking section to include best practices
2025-03-13 10:56:01 +01:00
Nicolas Patry
83ef364177
We need gcc during runtime to enable triton to compile kernels. (#3103)
* We need gcc during runtime to enable triton to compile kernels.

* Fixing the docker build.
2025-03-13 10:45:47 +01:00
Daniël de Kok
83b7b7bb92
Router: add gemma3-text model type (#3107) 2025-03-13 10:41:33 +01:00
Daniël de Kok
c73ae0bd88
Update to kernels 0.2.1 (#3084)
* Update to `kernels` 0.2.1

The package was renamed from `hf-kernels` to `kernels`. The new version
also updates the lockfile format.

* Download kernels in `install-cuda` target
2025-03-13 10:36:29 +01:00
Nicolas Patry
d4c6faa67b
Try to fix on main CI color. (#3101) 2025-03-12 10:12:24 +01:00
Nicolas Patry
4ac06ddf56
Preparing relase 3.2.0 (#3100)
* Preparing relase 3.2.0

* Forgot the README.

* Update doc.
2025-03-12 10:11:33 +01:00
David Corvoysier
f01dc9e743
Update neuron backend (#3098)
* feat(neuron): use AWS Neuron SDK 2.21.1

* feat(neuron): bump optimum-neuron version

* feat(neuron): tag latest image for local tests

* test(neuron): simplify sampling test
2025-03-12 09:53:15 +01:00
Nicolas Patry
5c5528e362
Fix tool call4 (#3094)
* Removing the no_tool content information.

* Removing a lot of NO_TOOL shenanigans.

* Update the tests.
2025-03-12 09:28:47 +01:00
Mohit Sharma
ed46c2c414
Add gemma3 model (#3099) 2025-03-12 09:25:51 +01:00
Nicolas Patry
f74c36fe0d
Fix tool call3 (#3086)
* Fixing the tool calling convention.

* Update tehe doc.

* Fixing some corner cases.

* Fixing the tool call id.

* Fmt.

* Snapshot update with the new updated tool_call_id.

* More qwen2.
2025-03-12 09:22:53 +01:00
celsowm
ae4451c3da
Update README.md (#3095)
space between param and value
2025-03-11 11:05:21 +01:00
Nicolas Patry
b447f7e821
Fix qwen vl (#3096)
* Fixing qwen2.5 VL.

* Fixing the CI.
2025-03-11 11:00:41 +01:00
Adrien Gallouët
094975c3a8
Update the llamacpp backend (#3022)
* Build faster

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Make --model-gguf optional

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable mmap, offload_kqv & flash_attention by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Better error message

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update installed packages

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Save gguf in models/MODEL_ID/model.gguf

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix build with Mach-O

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Quantize without llama-quantize

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp and switch to ggml-org

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove make-gguf.sh

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Cargo.lock

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Support HF_HUB_USER_AGENT_ORIGIN

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --build-arg llamacpp_native & llamacpp_cpu_arm_arch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-03-11 09:19:01 +01:00
drbh
dc5f05f8e6
Pr 3003 ci branch (#3007)
* change ChatCompletionChunk to align with "OpenAI Chat Completions streaming API"

Moving after tool_calls2

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

add in Buffering..

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

fix: handle usage outside of stream state and add tests

Simplifying everything quite a bit.

Remove the unused model_dump.

Clippy.

Clippy ?

Ruff.

Uppgrade the flake for latest transformers.

Upgrade after rebase.

Remove potential footgun.

Fix completion test.

* Clippy.

* Tweak for multi prompt.

* Ruff.

* Update the snapshot a bit.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-03-10 17:56:19 +01:00
Daniël de Kok
124398fa57
hotfix: qwen2 formatting (#3093)
* hotfix: qwen2 formatting

* cargo fmt
2025-03-10 16:19:50 +01:00
Daniël de Kok
c5ecc7a4de
Small test and typing fixes (#3078)
* test_weights: add modules_to_not_convert

* More typing fixes
2025-03-10 15:08:23 +01:00
jiqing-feng
cae0cbe87d
Add modules_to_not_convert in quantized model (#3053)
* fix modules_to_not_convert

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tp quant skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert unquantized changes

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* use DefaultWeightsLoader in skip modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-03-10 15:03:51 +01:00
EachSheep
bbe218a4f7
Add qwen2 multi lora layers support (#3089)
add qwen2 multi lora layers support to solve problem like https://github.com/huggingface/text-generation-inference/issues/2881, the similar PR are at https://github.com/huggingface/text-generation-inference/pull/2883

Co-authored-by: hjs <hjs@pku.edu.cn>
2025-03-10 12:42:59 +01:00
Alex Weston
58a65f7914
Add request parameters to OTel span for /v1/chat/completions endpoint (#3000)
Record request parameters in OTel span for /v1/chat/completions endpoint
2025-03-10 12:26:57 +01:00
Daniël de Kok
976eae216f
Nix: the launcher needs a Python env with Torch for GPU detection (#3085)
This makes `nix run .` in the repository work again. Should fix #3025.
2025-03-10 12:11:10 +01:00
Nicolas Patry
622908deab
Fix tool call2 (#3076)
* Making `tool_calls` a vector.

* Arguments output is a string.

* Update all the integration tests.

* Add the requirements.

* Upgrade other tests.

* Clippy.

* Update the old test.
2025-03-07 19:45:57 +01:00
Alvaro Bartolome
55a6618434
Update --max-batch-total-tokens description (#3083)
* Update `--max-batch-total-tokens` description

* Update docstring in `launcher/src/main.rs` instead
2025-03-07 14:24:26 +01:00
Daniël de Kok
036d802b62
Nix: add openai to impure shell for integration tests (#3081) 2025-03-07 13:04:21 +01:00
Nicolas Patry
8e92942a18
Making tool_calls a vector. (#3075)
* Making `tool_calls` a vector.

* Update doc.

* Fixing the nix overlay with updated version.

* Add openai dependency.

* Updating the old tests.

* Trying to reduce the logs in the case of errors.

* Less spammy logs too.
2025-03-05 22:32:31 +01:00
Nicolas Patry
3208d1cd1d
Revert "Trying to reduce the logs in the case of errors."
This reverts commit cdf70d6a28.
2025-03-05 20:52:38 +01:00
Nicolas Patry
cdf70d6a28
Trying to reduce the logs in the case of errors. 2025-03-05 20:50:43 +01:00
Nicolas Patry
ab9dafc68f
Making sure Olmo (transformers backend) works. (#3074) 2025-03-05 17:46:47 +01:00
Nicolas Patry
31766dad77
Force upgrade transformers version for olmo. 2025-03-05 12:17:09 +01:00
Nicolas Patry
ec35976f82
Only add token when it is defined. (#3073)
* Only add token when it is defined.

* Update router/src/server.rs
2025-03-05 11:59:52 +01:00
David Corvoysier
cb42b3ad83
fix(neuron): explicitly install toolchain (#3072)
* fix(neuron): explicitly install toolchain

* ci(neuron): trigger CI when Dockerfile is modified
2025-03-05 11:46:58 +01:00
Nicolas Patry
491ed9e11d
Patch rust release. (#3069)
* Patch rust release.

* Trying to remove the rust-toolchain hardcoded in action.

* Upgrade rust toolchain.

* Put back the toolchain ?

* Fix neuron dockerfile.

* Move to the proper version of Rust.

* 1.85 since the GH action doesn't respect the override.

* Typo.

* Fixing the github action.

* Fixing docker llamacpp.

* Fixing the github action.

* Update clippy.
2025-03-04 18:07:33 +01:00
Sadra Barikbin
144d99c147
Fix a tiny typo in monitoring.md tutorial (#3056)
Update monitoring.md
2025-03-04 17:06:26 +01:00
Nicolas Patry
08bbfa16a1
Preparing for release. (#3060)
* Preparing for release.

* Upgrade doc.

* Fix docs auto-generated.

* Fix update doc along.
2025-03-04 16:47:10 +01:00
Hugo Larcher
d8ff7f2623
feat: add support for HF_HUB_USER_AGENT_ORIGIN to add user-agent Origin field in Hub requests. (#3061)
* feat: add support for HF_HUB_USER_AGENT_ORIGIN to add user-agent Origin field in Hub requests.

* fix: Rust version for Neuron

* fix: PR comments, use rust-toolchain.toml
2025-03-04 16:43:50 +01:00
Daniël de Kok
e88f6f6ee9
Add property-based testing for RadixAllocator (#3068) 2025-03-04 15:09:46 +01:00
Daniël de Kok
fa4e9511f8
Fix two edge cases in RadixTrie::find (#3067)
- Always return a node, not its parent.
- Do not recurse when a node does not represent a full prefix of the
  input.
2025-03-04 13:23:27 +01:00
Nicolas Patry
a914a21899
Revert "Patch rust release."
This reverts commit aad9c2b0bd.
2025-03-04 12:16:18 +00:00
Nicolas Patry
aad9c2b0bd
Patch rust release. 2025-03-04 12:14:58 +00:00
Nicolas Patry
1f35cc7a31
Updating patch rust release. 2025-03-04 12:13:58 +00:00
Baptiste Colle
683ff53fa3
Add Gaudi Backend (#3055)
* wip(gaudi): import server and dockerfile from tgi-gaudi fork

* feat(gaudi): new gaudi backend working

* fix: fix style

* fix prehooks issues

* fix(gaudi): refactor server and implement requested changes
2025-02-28 12:14:58 +01:00
David Corvoysier
5eec3a8bb6
Avoid running neuron integration tests twice (#3054)
* test(neuron): refactor to prepare batch export

* test(neuron): add helper to batch export models

Also rename fixture file fro clarity.

* ci(neuron): do not run tests twice

* ci(neuron): rename precompilation job

* test(neuron): remove redundant subdirectory

* test(neuron): remove erroneous line

* doc(neuron): update links to installation page

* feat(neuron): cleanup Dockerfile

CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse is not required anymore.

* test(neuron): try to reduce download errors
2025-02-26 12:15:01 +01:00
drbh
b0069e0485
fix: run linters and fix formatting (#3057) 2025-02-25 16:11:34 -05:00
Wang, Yi
d7a24c03cf
some minor fix (#3048)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-25 12:07:55 +01:00
Nicolas Patry
cea9dbc971
You need to seek apparently. (#3049) 2025-02-24 14:58:23 +01:00
David Corvoysier
c00add9c03
Add Neuron backend (#3033)
* feat: add neuron backend

* feat(neuron): add server standalone installation

* feat(neuron): add server and integration tests

* fix(neuron): increase ulimit when building image

The base image used to compile the rust components seems to have a low
ulimit for opened files, which leads to errors during compilation.

* test(neuron): merge integration tests and fixtures

* test: add --neuron option

* review: do not use latest tag

* review: remove ureq pinned version

* review: --privileged should be the exception

* feat: add neuron case to build ci

* fix(neuron): export models from container in test fixtures

The neuron tests require models to have been previously exported and
cached on the hub. This is done automatically by the neuron.model
fixture the first time the tests are ran for a specific version.
This fixture used to export the models using optimum-neuron directly,
but this package is not necessarily present on the system.
Instead, it is now done through the neuron TGI itself, since it
contains all the tools required to export the models.
Note that since the CI runs docker in docker (dind) it does not seem
possible to share a volume between the CI container and the container
used to export the model.
For that reason, a specific image with a modified entrypoint is built
on-the-fly when a model export is required.

* refactor: remove sagemaker entry-point

The SageMaker image is built differently anyway.

* fix(neuron): avoid using Levenshtein

* test(neuron): use smaller llama model

* feat(neuron): avoid installing CUDA in image

* test(neuron): no error anymore when requesting too many tokens

* ci: doing a precompilation step (with a different token).

* test(neuron): avoid using image sha when exporting models

We now manually evaluate the apparent hash of the neuron backend by
combining the hash of the neuron backend directory and Dockerfile.
This new hash is used to identify exported neuron models instead of the
image sha.
This has two benefits:
- it changes less frequently (only hwen the neuron backend changes),
  which means less neuron models being pushed to the hub,
- it can be evaluated locally, meaning that running the tests once
  locally will export the models before the CI uses them.

* test(neuron): added a small script to prune test models

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-02-24 09:10:05 +01:00
Daniël de Kok
97c5f7e685
Use rotary kernel from the Hub (#3041) 2025-02-21 13:55:31 +01:00
drbh
1cae3197c4
Improve tool call message processing (#3036)
* make content field optional in chat request

* add tool_calls field to Message struct

* feat: add test and serialize tool messages

* fix: bump utopia, openapi doc version and improve test

* fix: rerun update docs

* fix: suppoer tool call id in template and remove unnecessary changes

* fix: ruff lint remove unused import

* fix: adjust message types in tests

---------

Co-authored-by: sailesh duddupudi <saileshradar@gmail.com>
2025-02-21 10:30:29 +01:00
Adrien Gallouët
3498f6085e
Update Gradio ChatInterface configuration in consuming_tgi.md (#3042)
The current code does not work and gives the following message:

    UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
      warnings.warn(
    Traceback (most recent call last):
      File "/Users/angt/hf/tgi/test-gradio.py", line 22, in <module>
        gr.ChatInterface(
    TypeError: ChatInterface.__init__() got an unexpected keyword argument 'retry_btn'

Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2025-02-21 10:11:28 +01:00
Nicolas Patry
142a49a80d
Simplify logs2. (#3045)
* Simplify logs2.

* Changing the scope from module to session to fix the event_loop issue.
2025-02-21 10:03:40 +01:00
Wang, Yi
06dfe9abfe
fix qwen2 vl crash in continous batching (#3004)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-20 18:36:45 -05:00
Daniël de Kok
ed96ba6503
flashinfer 0.2.0.post1 -> post2 (#3040)
* flashinfer 0.2.0.post1 -> post2

* Fix ruff stuff.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-02-20 12:34:20 +01:00
Wang, Yi
feaa2477b7
update ipex and torch to 2.6 for cpu (#3039)
ipex cpu 2.6 support topk_group in moe fusion ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-20 09:12:28 +01:00
Hugo Larcher
230aa25641
feat: Add the parsing of HF_HUB_USER_AGENT_ORIGIN environment variable for telemetry (#3027)
* feat: Add the parsing of HF_HUB_USER_AGENT_ORIGIN environment variable to add info about the environment running TGI. That is useful to track usage in case of collaborations for example.

* fix: trufflehog
2025-02-19 21:09:12 +01:00
Nicolas Patry
9c89d0070e
Having less logs in case of failure for checking CI more easily. (#3037)
* Having less logs in case of failure for checking CI more easily.

* Cleaning up the versions to uv for the client.

* Ignore entirely the API.
2025-02-19 17:01:33 +01:00
Nicolas Patry
fde3234cbc
Using public external registry (to use external runners for CI). (#3031)
* Using public external registry (to use external runners for CI).

* Fix build.

* Fixing the external registry.

* Fixing trtllm tests.
2025-02-19 14:53:14 +01:00
drbh
d6a0c67e2f
feat: add initial qwen2.5-vl model and test (#2971)
* feat: support qwen2.5 vl model

* fix: bump support models doc

* feat: check before rope type adjustment and small refactors

* fix: add transformer overlay for processor support

* fix: vendor processor and config from transformers

* fix: refactor/simplify conditionals
2025-02-19 12:38:20 +01:00
Cyril Vallez
a7448661f7
Improve Transformers support (#2970)
* Much better support

* add gpt neox

* bump transformers version

* bump version
2025-02-18 19:04:34 +01:00
Nicolas Patry
5543fdc765
It's find in some machine. using hf_hub::api::sync::Api to download c… (#3030)
It's find in some machine. using hf_hub::api::sync::Api to download config is not successful which will make warmup fail since attribute like max_position_embeddings could not be got. update hf-hub to the latest version could fix it

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-18 12:19:51 +01:00
Nicolas Patry
b8a4928d0e
Pinning trufflehog. (#3032) 2025-02-18 12:03:41 +01:00
Alvaro Bartolome
8a1cfd6122
Add loop_controls feature to minijinja to handle {% break %} (#2998)
* Add `loop_controls` feature to `minijinja`

* Add `test_chat_template_loop_controls` to test `break`
2025-02-18 10:33:22 +01:00
celsowm
794ec58b75
Update README.md (#3024)
only way to avoid:
error: experimental Nix feature 'nix-command' is disabled; add '--extra-experimental-features nix-command' to enable it
2025-02-18 10:08:28 +01:00
Daniël de Kok
f0ed76583c
Use eetq kernel from the hub (#3029)
* Use eetq kernel from the hub

* Fixing the CI.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-02-18 10:03:53 +01:00
Adrien Gallouët
cfd4fbb479
[Backend] Add Llamacpp backend (#2975)
* Add llamacpp backend

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Get rid of llama_batch_get_one()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Use max_batch_total_tokens

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle max_batch_size

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add some input validation checks

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle ctx args & fix sampling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add GPU args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --defrag-threshold

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add a stupid batch mechanism

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --numa

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable flash attention by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --offload-kqv

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix batch_pos

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* backend(llama): add CUDA Dockerfile_llamacpp for now

* Only export the latest logits

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Output real logprobs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix batching

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix seq iterations

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Auto-detect n_threads when not provided

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Clear request cache after completion

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove warmup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* backend(llama): add CUDA architectures build argument for Dockerfile

* Add specific args for batch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --type-v & --type-k

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llamacpp to b4623

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Disable graceful shutdown in debug mode

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Dockerfile_llamacpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup Dockerfile

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Cargo.lock

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Simplify batching logic

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Set TGI_LLAMA_PKG_CUDA from CUDA_VERSION

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Rename bindings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove n_ctx

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Make max_batch_total_tokens optional

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Ensure all samplers are freed on error

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Initialize penalty_last_n with llamacpp default value

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Improve default settings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update docs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Thanks clippy

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Thanks cargo fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update docs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Do not use HOSTNAME env

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp & cuda

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix requirements.txt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable KQV offload by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove Ngrok tunneling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove .cargo/config.toml

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix Dockerfile

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add missing cuda prefix

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle custom llama.cpp dir

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add README.md

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add HF transfer

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix bool args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
2025-02-14 13:40:57 +01:00
Daniël de Kok
6df0fc0b55
Support sigmoid scoring function in GPTQ-MoE (#3017) 2025-02-14 11:33:49 +01:00
Nicolas Patry
d6881c37ab
Putting back the NCCL forced upgrade. (#2999)
* Putting back the NCCL forced upgrade.

* .

* ...

* Ignoring conda.

* Dropping conda from the buidl system + torch 2.6

* Cache min.

* Rolling back torch version.

* Reverting the EETQ modification.

* Fix flash attention ?

* Actually stay on flash v1.

* Patching flash v1.

* Torch 2.6, fork of rotary, eetq updated.

* Put back nccl latest (override torch).

* Slightly more reproducible build and not as scary.
2025-02-14 11:31:59 +01:00
Nicolas Patry
8a211dc7fc
Preventing single user hugging the server to death by asking (#3016)
for way too many tokens.
2025-02-13 11:23:17 +01:00
Nicolas Patry
4cccce4b44
Update the flaky mllama test. (#3015) 2025-02-12 12:26:52 +01:00
Wang, Yi
76bcb4948d
fix Qwen VL break in intel platform (#3002)
* fix Qwen VL break in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* could use PositionRotaryEmbedding impl so rocm and ipex could all work

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-12 11:31:34 +01:00
Nicolas Patry
b86c3947ab
Revert "Update the flaky mllama test."
This reverts commit 8a870b31b9.
2025-02-11 17:13:06 +01:00
Nicolas Patry
8a870b31b9
Update the flaky mllama test. 2025-02-11 17:10:36 +01:00
Daniël de Kok
571ac9b507
Use kernels from the kernel hub (#2988)
* Use Hub kernels for Marlin and cutlass quantization kernels

* Use hub kernels for MoE/GPTQ-Marlin MoE

* Use attention kernels from the Hub

* Cache the kernels in the Docker image

* Update moe kernels

* Support loading local kernels for development

* Support latest moe kernels

* Update to moe 0.1.1

* CI: download locked kernels for server tests

* Fixup some imports

* CI: activate venv

* Fix unused imports

* Nix: add attention/moe/quantization kernels

* Update hf-kernels to 0.1.5

* Update kernels

* Update tgi-nix flake for hf-kernels

* Fix EOF

* Take `load_kernel` out of a frequently-called function

* Hoist another case of kernel loading out of a somewhat hot function

* marlin-kernels -> quantization

* attention -> paged-attention

* EOF fix

* Update hf-kernels, fixup Docker

* ipex fix

* Remove outdated TODO
2025-02-10 19:19:25 +01:00
Nicolas Patry
4b8cda684b
Updating mllama after strftime. (#2993)
* Updating mllama after strftime.

* Town instead village.

* Forgot the integration snapshot.

* Attempt to fix intel CPU.

* Intel extension fix.

* Workaround intel.

* Moving those deps directly into pyproject.

* Revert "Moving those deps directly into pyproject."

This reverts commit 98c1496ea6.

* Non system uv.

* Fixing the docker environment hopefully.

* Missed a step.

* Move workdir up a bit.

* Bailing out of reproducible python env.

* Triton version.
2025-02-07 10:38:13 +01:00
Funtowicz Morgan
856709d5c3
[Backend] Bump TRTLLM to v.0.17.0 (#2991)
* backend(trtllm): bump TRTLLM to v.0.17.0

* backend(trtllm): forget to bump dockerfile

* backend(trtllm): use arg instead of env

* backend(trtllm): use correct library reference decoder_attention_src

* backend(trtllm): link against decoder_attention_{0|1}

* backend(trtllm): build against gcc-14 with cuda12.8

* backend(trtllm): use return value optimization flag as as error if available

* backend(trtllm): make sure we escalade all warnings as errors on the backend impl in debug mode

* backend(trtllm): link against CUDA 12.8
2025-02-06 16:45:03 +01:00
Wang, Yi
36223f834e
Triton fix (#2995)
fix triton to 3.1.0 to fix ipex import issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-02-06 12:28:41 +01:00
Nicolas Patry
0ef8c8a97a
Using the "lockfile". (#2992)
* Using the "lockfile".

* Revert dummy modifications.

* Lock on python 3.11

* Another attempt.

* ..

* Bad cache hits.

* The good old monkey.

* How in the world...

* We need the launcher still.

* .

* ..

* Attempt #42

* Don't break all other builds.

* Mode max.

* Applying to other builds.
2025-02-06 12:28:24 +01:00
drbh
c1cf36c0dc
Improve qwen vl impl (#2943)
* feat: refactor model, improve startup and re enable tests

* fix: improve multimodal rotary embed caching

* fix: limit vision flop calc to qwen2 vl models and update config typing

* fix: include clippy lint

* feat: refactor position ids in warmup and bump tests

* fix: prefer default dtype

* fix: enable all cuda graphs and bump snapshots

* fix: adjust rotaty init path

* fix: simplify get position ids and remove usused vision config

* fix: update position ids so first dim is batch, simplify rotary and bump vlm default token limit

* fix: improve position id init during cuda warmup for mrope and simplfy rotary forward

* fix: check existance before accessing rope type in cuda warmup

* fix: check key before access

* fix: improve mrope check in cuda graph warmup

* fix: remove check for default rope type

* fix: add more test and improve model generation

* fix: improve and simplify get_cos_sin, refactors and cleanup  get_position_ids

* fix: adjust signatures with types
2025-02-04 12:44:18 -05:00
Daniël de Kok
dd2bd5fdb3
impureWithCuda: fix gcc version (#2990)
* impureWithCuda: fix gcc version

* trufflehog: do not fail on unverified results
2025-02-04 17:01:59 +01:00
Alvaro Bartolome
88fd56f549
Add strftime_now callable function for minijinja chat templates (#2983)
* Add `chrono` and `strftime_now` function callable

* Fix `test_chat_template_valid_with_strftime_now`

* Fix `test_chat_template_valid_with_strftime_now`
2025-02-03 15:30:48 +01:00
Hugo Larcher
e3f2018cb5
hotfix: fix trtllm CI build on release (#2981)
* hotfix: fix trtllm CI build on release

* fix: test release.

* fix: test release.

* fix: test release. env not recognized https://github.com/actions/runner/issues/1661

* fix: test release. Works.
2025-02-03 11:11:15 +01:00
Nicolas Patry
bb69c5b199
Back on nix main. (#2979) 2025-01-31 14:39:52 +01:00
Nicolas Patry
c9d68945cc
Prepare for release 3.1.0 (#2972)
* Prepare for release 3.1.0

* Back on main flake.

* Fixing stuff.

* Upgrade to moe-kernels 0.8.2 for Hip support.

* Deactivating the flaky test.
2025-01-31 14:19:01 +01:00
Mohit Sharma
c07a2cc82b
Update moe-kernel to 0.8.2 for rocm (#2977)
update moe-kernel for amd
2025-01-31 11:40:00 +01:00
Hugo Larcher
065aabb13d
doc: Update TRTLLM deployment doc. (#2960)
* doc: Update TRTLLM deployment doc. Update TRTLLM CI to allow release builds when tagging TGI.

* doc: Update TRTLLM deployment doc. Update TRTLLM CI to allow release builds when tagging TGI.

* fix: PR comments
2025-01-30 18:04:42 +01:00
Nicolas Patry
cb747b33da
Add deepseekv3 (#2968)
* Add fp8 support moe models

add deepseekv3

format codfe'

update dockerfile

update doc

* Small modifications.

* Moe kernels 0.8.1

* Upgrade to 0.8.1

* Fixing moe import.

* Black.

* Apply suggestions from code review

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>

* Fixing Mixtral + Nits.

* Put link to ref.

* Fix other call locations.

* Scoring func `softmax` is the only one that works.

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
2025-01-30 16:40:25 +01:00
Nicolas Patry
80e7d98f88
Hotfixing intel-cpu (not sure how it was working before). (#2967)
* Hotfixing intel-cpu (not sure how it was working before).

* Do not fail on missing moe-kernels (Intel-cpu).
2025-01-29 22:34:41 +01:00
Daniël de Kok
ee0dffcd14
Update to moe-kernels 0.8.0 (#2966) 2025-01-29 18:19:55 +01:00
Mohit Sharma
4ef2e045c9
Add fp8 support moe models (#2928)
* Add fp8 support moe models

* flatten condition
2025-01-29 13:56:32 +01:00
Hugo Larcher
73b7cf83f6
Add backend name to telemetry (#2962)
* feat: Add backend name to telemetry
2025-01-28 16:53:16 +01:00
Nicolas Patry
eb3df0f46f
Fixing the oom maybe with 2.5.1 change. (#2958) 2025-01-28 10:30:28 +01:00
Hugo Larcher
c690da5973
fix: Telemetry (#2957)
* fix: add telemetry regular pings and fix unhandled errors avoid not sending telemetry stop events.

* fix: simplify error handling

* fix: update ping delay and update doc.

* fix: clippy

* doc: Rephrase properly.
2025-01-28 10:29:18 +01:00
Daniël de Kok
db922eb77e
Update to attention-kernels 0.2.0 (#2950)
This version removes our patches/custom API. Makes it simpler to
get changes from upstream. One of which is that we can enable FP8
KV cache for paged attention as well.
2025-01-27 11:42:36 +01:00
Funtowicz Morgan
40b00275b2
Attempt to remove AWS S3 flaky cache for sccache (#2953)
* backend(trtllm): attempt to remove AWS S3 flaky cache for sccache

* backend(trtllm): what if we expose ENV instead of inline?

* backend(trtllm): and with the right env var for gha sccache

* backend(trtllm): relax the way to detect sccache

* backend(trtllm): make sccache definition manually

* backend(trtllm): ok let's try to define the launchers in build.rs when rustc_wrapper is present

* backend(trtllm): export env variable in run mb?

* backend(trtllm): Cache mode max to cache intermediate layers

* backend(trtllm): inject ompi_version build arg in dependent step
2025-01-27 11:21:48 +01:00
Nicolas Patry
6cb41a80a1
Revert "Remove AWS credentials?"
This reverts commit d2ff68e98d.
2025-01-24 14:34:17 +01:00
Nicolas Patry
d2ff68e98d
Remove AWS credentials? 2025-01-24 12:18:28 +01:00
Nicolas Patry
d9dda11726
Trying to put back the archlist (to fix the oom). (#2947) 2025-01-24 09:32:17 +01:00
Nicolas Patry
d937eb64da
Fixing cargo lock. 2025-01-23 18:54:34 +01:00
Cyril Vallez
18c4607d46
Transformers backend TP fix (#2945)
* init dispatch

* cohere fix
2025-01-23 18:09:57 +01:00
Nicolas Patry
29a0893b67
Tmp tp transformers (#2942)
* Upgrade the version number.

* Remove modifications in Lock.

* Tmp branch to test transformers backend with 2.5.1 and TP>1

* Fixing the transformers backend.

inference_mode forces the use of `aten.matmul` instead of `aten.mm` the
former doesn't have sharding support crashing the transformers TP
support.

`lm_head.forward` also crashes because it skips the hook that
cast/decast the DTensor.

Torch 2.5.1 is required for sharding support.

* Put back the attention impl.

* Revert the flashinfer (this will fails).

* Building AOT.

* Using 2.5 kernels.

* Remove the archlist, it's defined in the docker anyway.
2025-01-23 18:07:30 +01:00
Funtowicz Morgan
0a89902663
[TRTLLM] Expose finish reason (#2841)
* feat(trtllm): expose finish reason to Rust

* misc(llamacpp): fix typo

* misc(backend): update deps
2025-01-23 16:48:26 +01:00
Nikolai Kolodziej
4e172028aa
Add NVIDIA A40 to known cards (#2941)
feat: add NVIDIA A40 to known cards
2025-01-23 14:19:21 +01:00
Alvaro Bartolome
6ab02931cf
Set alias for max_completion_tokens in ChatRequest (#2932) 2025-01-23 14:18:47 +01:00
Funtowicz Morgan
cc212154e0
Bump TensorRT-LLM backend dependency to v0.16.0 (#2931)
* backend(trtllm): update to 0.16.0

* backend(trtllm): do not use shallow clone

* backend(trtllm): use tag instead

* backend(trtllm): move to nvidia remote instead of hf

* backend(trtllm): reenable shallow clone

* backend(trtllm): attempt to use ADD instead of RUN for openmpi

* backend(trtllm): make sure we are using correct path for openmpi ADD in dockerfile

* backend(trtllm): add correctly untar it
2025-01-23 13:54:40 +01:00
Daniël de Kok
1dd346666a
Clarify FP8-Marlin use on capability 8.9 (#2940)
The log message stated that the GPU does not support FP8 on capability
8.9. However we use FP8-Marlin on that capability because it is faster.
2025-01-22 18:18:11 +01:00
Wang, Yi
1d3c9beba8
fix moe in quantization path (#2935)
update ipex xpu to support moe for mixtral

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-01-22 14:36:15 +01:00
Nicolas Patry
2dfe3b3ee6
Upgrading the deps to have transformers==4.48.0 necessary (#2937) 2025-01-22 12:20:15 +01:00
Alvaro Bartolome
64a33c1f05
Run pre-commit run --all-files to fix CI (#2933) 2025-01-21 17:33:33 +01:00
Nicolas Patry
bdb3e488e4
Trying to avoid the random timeout. (#2929)
* Trying to avoid the random timeout.

* More read timeout ?

* Longer timeout ?

* Remove legacy ENV directive.

* Remove the dummy test, only increase the read timeout.

* Wat?
2025-01-21 11:06:10 +01:00
Funtowicz Morgan
17367438f3
Give TensorRT-LLMa proper CI/CD 😍 (#2886)
* test(ctest) enable address sanitizer

* feat(trtllm): expose finish reason to Rust

* feat(trtllm): fix logits retrieval

* misc(ci): enabe building tensorrt-llm

* misc(ci): update Rust action toolchain

* misc(ci): let's try to build the Dockerfile for trtllm

# Conflicts:
#	Dockerfile_trtllm

* misc(ci): provide mecanism to cache inside container

* misc(ci): export aws creds as output of step

* misc(ci): let's try this way

* misc(ci): again

* misc(ci): again

* misc(ci): add debug profile

* misc(ci): add debug profile

* misc(ci): lets actually use sccache ...

* misc(ci): do not build with ssl enabled

* misc(ci): WAT

* misc(ci): WAT

* misc(ci): WAT

* misc(ci): WAT

* misc(ci): WAT

* misc(backend): test with TGI S3 conf

* misc(backend): test with TGI S3 conf

* misc(backend): once more?

* misc(backend): let's try with GHA

* misc(backend): missing env directive

* misc(backend): make sure to correctly set IS_GHA_BUILD=true in wf

* misc(backend): ok let's debug smtg

* misc(backend): WWWWWWWWWWWWWAAAAAAAA

* misc(backend): kthxbye retry s3

* misc(backend): use session token

* misc(backend): add more info

* misc(backend): lets try 1h30

* misc(backend): lets try 1h30

* misc(backend): increase to 2h

* misc(backend): lets try...

* misc(backend): lets try...

* misc(backend): let's build for ci-runtime

* misc(backend): let's add some more tooling

* misc(backend): add some tags

* misc(backend): disable Werror for now

* misc(backend): added automatic gha detection

* misc(backend): remove leak sanitizer which is included in asan

* misc(backend): forward env

* misc(backend): forward env

* misc(backend): let's try

* misc(backend): let's try

* misc(backend): again

* misc(backend): again

* misc(backend): again

* misc(backend): again

* misc(backend): again

* misc(backend): fix sscache -> sccache

* misc(backend): fix sscache -> sccache

* misc(backend): fix sscache -> sccache

* misc(backend): let's actually cache things now

* misc(backend): let's actually cache things now

* misc(backend): attempt to run the testS?

* misc(backend): attempt to run the tests?

* misc(backend): attempt to run the tests?

* change runner size

* fix: Correctly tag docker images (#2878)

* fix: Correctly tag docker images

* fix: Correctly tag docker images

* misc(llamacpp): maybe?

* misc(llamacpp): maybe?

* misc(llamacpp): maybe?

* misc(ci): gogogo

* misc(ci): gogogo

* misc(ci): gogogo

* misc(ci): gogogo

* misc(ci): gogogo

* misc(ci): gogogo

* misc(ci): go

* misc(ci): go

* misc(ci): go

* misc(ci): use bin folder

* misc(ci): make the wf callable for reuse

* misc(ci): make the wf callable for reuse (bis)

* misc(ci): make the wf callable for reuse (bis)

* misc(ci): give the wf a name

* Create test-trtllm.yml

* Update test-trtllm.yml

* Create build-trtllm2

* Rename build-trtllm2 to 1-build-trtllm2

* Rename test-trtllm.yml to 1-test-trtllm2.yml

* misc(ci): fw secrets

* Update 1-test-trtllm2.yml

* Rename 1-build-trtllm2 to 1-build-trtllm2.yml

* Update 1-test-trtllm2.yml

* misc(ci): use ci-build.yaml as main dispatcher

* Delete .github/workflows/1-test-trtllm2.yml

* Delete .github/workflows/1-build-trtllm2.yml

* misc(ci): rights?

* misc(ci): rights?

* misc(ci): once more?

* misc(ci): once more?

* misc(ci): baby more time?

* misc(ci): baby more time?

* misc(ci): try the permission above again?

* misc(ci): try the permission above again?

* misc(ci): try the permission scoped again?

* misc(ci): install tensorrt_llm_executor_static

* misc(ci): attempt to rebuild with sccache?

* misc(ci):run the tests on GPU instance

* misc(ci): let's actually setup sccache in the build.rs

* misc(ci): reintroduce variables

* misc(ci): enforce sccache

* misc(ci): correct right job name dependency

* misc(ci): detect dev profile for debug

* misc(ci): detect gha build

* misc(ci): detect gha build

* misc(ci): ok debug

* misc(ci): wtf

* misc(ci): wtf2

* misc(ci): wtf3

* misc(ci): use commit HEAD instead of merge commit for image id

* misc(ci): wtfinfini

* misc(ci): wtfinfini

* misc(ci): KAMEHAMEHA

* Merge TRTLLM in standard CI

* misc(ci): remove input machine

* misc(ci): missing id-token for AWS auth

* misc(ci): missing id-token for AWS auth

* misc(ci): missing id-token for AWS auth

* misc(ci): again...

* misc(ci): again...

* misc(ci): again...

* misc(ci): again...

* misc(ci): missing benchmark

* misc(ci): missing backends

* misc(ci): missing launcher

* misc(ci): give everything aws needs

* misc(ci): give everything aws needs

* misc(ci): fix warnings

* misc(ci): attempt to fix sccache not building trtllm

* misc(ci): attempt to fix sccache not building trtllm again

---------

Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com>
Co-authored-by: Hugo Larcher <hugo.larcher@huggingface.co>
Co-authored-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com>
2025-01-21 10:19:16 +01:00
Cyril Vallez
b980848abf
Flash Transformers modeling backend support (#2913)
* add transformers_flash

* inits

* switch version to make it work

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* runnable version

* working

* push change

* fix high dim

* init

* default

* latest transformers changes

* revert

* simplify check

* remove flag

* improve type hints + required args

* Update based on transformers PR

* small fix

* Remove Warpers for Processor

* fix compatibility version issue

* raise error if needed

* Simplify with monkey patch

* revert + style + minor improvements

* update comment

* device check

* move the import to avoid device issue

* Update __init__.py

* check for non-native models

* oupsi

---------

Co-authored-by: System administrator <root@ip-10-90-0-159.ec2.internal>
2025-01-21 10:01:51 +01:00
Nicolas Patry
447a5b2f87
Fixing TRTLLM dockerfile. (#2922)
* Fixing TRTLLM dockerfile.

* Fixed.

* Creating a dummy modification to chekc CI runs.

* Removing the cache directive.

* Modifying this should cache hit.

* Revert "Modifying this should cache hit."

This reverts commit 46a2bde108.

* Modifying this should cache hit.

* Unwanted files.
2025-01-20 11:13:46 +01:00
Daniël de Kok
630f198624
flashinfer: switch to plan API (#2904)
This change doesn't switch `forward` to `run` yet, since it requires
that we have access to the softmax scale and the logit softcap outside
the model.
2025-01-17 18:18:02 +01:00
drbh
8f6146f11a
Revert "feat: improve qwen2-vl startup " (#2924)
Revert "feat: improve qwen2-vl startup  (#2802)"

This reverts commit eecca27113.
2025-01-17 12:09:05 -05:00
drbh
eecca27113
feat: improve qwen2-vl startup (#2802)
* feat: tokenize each request individually and increase warmup image size

* feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller

* fix: address image resize and rebase changes

* feat: update to run qwen2-vl tests

* fix: tweak param types
2025-01-17 11:50:41 -05:00
Wang, Yi
6e982f43a1
fix the crash of meta-llama/Llama-3.2-1B (#2918)
* fix the crash of meta-llama/Llama-3.2-1B

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Simpler fix (which doesn't break vlms).

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-01-17 15:50:58 +01:00
Mohit Sharma
c20025dbf7
Add fp8 kv cache for ROCm (#2856)
* add fp8 kv cache for rocm

* improvements

* update log statement

* remove bookkeeping field
2025-01-17 18:43:29 +05:30
Nicolas Patry
de19e7e844
Moving to uv instead of poetry. (#2919)
* Moving to `uv` instead of `poetry`.

More in the standard, faster, seemingly better lockfile.

* Creating venv if not created.

* Create the venv.

* Fix ?

* Fixing the test by activating the environment ?

* Install system  ?

* Add the cli entry point.

* docker install on system

* Monkeying this...

* `--system` is redundant.

* Trying to force-include this pb folder.

* TRying to check that pb is imported correctly.

* Editable install necessary ?

* Non editable?

* Editable it is.
2025-01-17 12:32:00 +01:00
Daniël de Kok
d61f14f271
nix: update to PyTorch 2.5.1 (#2921) 2025-01-17 12:12:11 +01:00
Wang, Yi
885144166f
Flash decoding kernel adding and prefill-chunking and prefix caching enabling in intel cpu/xpu (#2815)
* flash decoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable xpu flashdecoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set flashdecoding blocksize as 64

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable flashdecoding, prefill chunking and prefix caching

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-01-17 12:04:57 +01:00
drbh
82f6ea1b71
feat: improve star coder to support multi lora layers (#2883)
* feat: improve star coder to support multi lora layers

* feat: improve weight that support adapters and add tests for starcoder with lora

* fix: bump snapshot for added tests

* fix: rerun pre commit lints

* fix: bump adapter test for added later names
2025-01-16 16:23:55 -05:00
Daniël de Kok
5f78ec32a5
Do not convert weight scale to e4m3fnuz on CUDA (#2917) 2025-01-16 13:44:32 +01:00
Nicolas Patry
922cc38fbc
Upgrading bitsandbytes. (#2910)
* Upgrading bitsandbytes.

Co-Authored-By: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Tighter lock.

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
2025-01-15 20:07:21 +01:00
Nicolas Patry
120bd3e3bb
Removing the github runner. (#2912) 2025-01-15 19:20:44 +01:00
Baptiste Colle
1470aec9d9
Fix typo in TPU docs (#2911)
docs(tpu): fix typo
2025-01-15 18:32:07 +01:00
Nicolas Patry
203cade244
Upgrading our rustc version. (#2908)
* Upgrading our rustc version.

* Fixing the rust tests to proper version.

* Clippy everything.
2025-01-15 17:04:03 +01:00
Baptiste Colle
46994b34fb
📝 add guide on using TPU with TGI in the docs (#2907) 2025-01-15 16:26:11 +01:00
Alvaro Bartolome
dc9b8e9814
Fix docker run in README.md (#2861)
* Fix `docker run` in `README.md`

* Add line-break in `docker run` for readability

Co-authored-by: Daniël de Kok <danieldk@users.noreply.github.com>

* Add line-break in `docker run` for readability

Co-authored-by: Daniël de Kok <danieldk@users.noreply.github.com>

---------

Co-authored-by: Daniël de Kok <danieldk@users.noreply.github.com>
2025-01-15 16:07:10 +01:00
Guspan Tanadi
3c7ae48f7f
docs(conceptual/speculation): available links Train Medusa (#2863) 2025-01-15 16:05:54 +01:00
Wang, Yi
cc8b9650bd
Baichuan2-13B does not have max_position_embeddings in config (#2903)
* Baichuan2-13B does not have max_position_embeddings in config
see https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/config.json

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Update server/text_generation_server/models/flash_causal_lm.py

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

* fmt

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
2025-01-15 15:56:52 +01:00
Mohit Sharma
e07acc7f68
Enable FP8 Per-Tensor Scales and Integrate Marlin/MoE Kernels Repo for ROCm (#2825)
* (feat) convert tscales to tensorwise

* (fix) fp8 scaling for cuda

* (kernel) add marlin-kernels

* add moe-kernels

* fix moe kernel comit

* fix scaling

* nm changes
2025-01-15 11:38:58 +05:30
Mohit Sharma
880ab9c2f3
Add Flash decoding kernel ROCm (#2855)
* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* add flash decoding
2025-01-13 11:12:35 +01:00
Wang, Yi
1660154ae6
fix crash in torch2.6 if TP=1 (#2885)
error like "ValueError: Expecting a ProcessGroup, but got a <class
'text_generation_server.utils.dist.FakeGroup'>. rank=0"

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-01-13 11:11:31 +01:00
Nicholas Broad
2e22164d4a
Update using_guidance.md (#2901)
deletes one copy of a sentence that repeated twice
2025-01-13 11:09:35 +01:00
lazariv
83624a07be
Add possible variants for A100 and H100 GPUs for auto-detecting flops (#2837)
* Update main.rs with A100 and H100 variants

* Add another variant "nvidia-h100-nvl"

* Update main.rs

Add nvidia-a100-sxm4-40gb
2025-01-10 16:12:02 +01:00
Dmitry Dygalo
01067f8ba8
chore: Update jsonschema to 0.28.0 (#2870)
* chore: Update jsonschema to 0.28.0

Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>

* chore: Enable blocking feature for reqwest

Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>

---------

Signed-off-by: Dmitry Dygalo <dmitry@dygalo.dev>
2025-01-10 15:01:54 +01:00
Daniël de Kok
4f7e00f4ce
Update to marlin-kernels 0.3.7 (#2882)
This fixes a race condition. See:

https://github.com/vllm-project/vllm/pull/11493
2025-01-10 12:43:44 +01:00
drbh
da5ab46705
Improve vlm support (add idefics3 support) (#2437)
* feat: expand vlm support and add image token logic and tests

* fix: avoid unused perceiver config

* feat: integrate image tokens into inputs embeds

* feat: add simple idefics3 test

* feat: update docs, image token logic and weight names

* fix: improve image processing

* feat: improve prefix for idefics3

* fix: bump idefics3 tests and snapshots

* fix: improve text model loading

* feat: consolidate changes with existing vlms and add support and test for smolvlm

* fix: create new idefic3 file, simplify logic and adjust llama weight loading

* fix: lint with ruff

* fix: clean up idefics 3 and improve prefix handling

* fix: improve typing

* fix: improve prompt_split_image with ref to original impl

* fix: adjust ruff lints and small refactors

* fix: adjust FlashLlamaModel prefix logic
2025-01-09 10:35:32 -05:00
Daniël de Kok
a9c7d2e3b6
Basic flashinfer 0.2 support (#2862)
* Basic flashinfer 0.2 support

This change does not use any of the new features yet, but makes
some small compatibility changes.

* Update to flashinfer 0.2.0.post1

* flashinfer: remove `contiguous` calls

* Fix flashinfer install

* flashinfer: fixup kv cache dtype

* Fix some annoying perturbations

* More output changes
2025-01-09 16:25:00 +01:00
Wang, Yi
afb6c728d8
update ipex xpu to fix issue in ARC770 (#2884)
* update ipex xpu to fix issue in ARC770

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add ats support

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-01-09 10:11:03 +01:00
Ruida Zeng
d37a43e581
chore: fixed some typos and attribute issues in README (#2891)
* chore: fixed html repeated attribute in README

* chore: fix minor grammar/capitalization

* chore: fixed spelling mistakes in README
2025-01-09 10:09:23 +01:00
drbh
23bc38b10d
fix: include add_special_tokens in kserve request (#2859)
merging as this patch is already used, and fully limit to the kserve feature
2024-12-19 16:55:17 -05:00
Wang, Yi
ab5f616920
change xpu lib download link (#2852)
Signed-off-by: Wang,Yi A <yi.a.wang@intel.com>
2024-12-19 12:18:58 +01:00
Mohit Sharma
8f66d323d0
Update vllm kernels for ROCM (#2826)
* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* update moe-kernels commit
2024-12-18 12:44:42 +01:00
janne-alatalo
7eeefa3b57
Qwen2-VL runtime error fix when prompted with multiple images (#2840)
* Fix runtime error when Qwen2-VL was prompted with multiple images

Fix runtime error when Qwen2-VL model is prompted with prompt with more
than one image. The runtime error was:

 File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 459, in get_position_ids
    text_pos_ids = torch.arange(text_length, device=d)
RuntimeError: upper bound and larger bound inconsistent with step sign

The error was caused by text_length variable going to negative value
when multiple images caused multiple loops in the get_position_ids
function's main loop.

The error is a simple logic mistake where next_image_pos is initialized
as relative offset from current_pos, but was used like it was absolute
position from zero.

* Fix runtime error when Qwen2-VL was prompted with multiple images

Fix runtime error when Qwen2-VL model is prompted with prompt with more
than one image. The runtime error was:

File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 534, in forward
    inputs_embeds[input_ids == self.image_token_id] = image_embeds
RuntimeError: shape mismatch: value tensor of shape [512, 3584] cannot be broadcast to indexing result of shape [1024, 3584]

(The error message shape numbers can be different depending on the input
image resolutions)

The error was caused by adding the wrong number of <|image_pad|> tokens
to the tokenized input in the image_text_replacement function.

The error is a simple logical mistake where the number of image pad
tokens is checked from pixel_value_shape tensor's first dimension
length. However, the pixel_value_shape contains patches from all of the
images. Therefore the code added the total number of required image pad
tokens for the whole input to each of the images locations. This
resulted to extra image pad tokens to be present in the tokenized input.

The fix was to check the number of required tokens from the
image_grid_thw tensor. The tensor includes grid_t, grid_h, and grid_w
values for each image. grid_t * grid_h * grid_w results to the total
number of patches for the image [1]. The number of required image pad
tokens is number_of_patches // 4.

[1] 31f9a289a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py (L311)

---------

Co-authored-by: Janne Alatalo <janne.alatalo@jamk.fi>
2024-12-16 22:55:11 -05:00
drbh
a72f339c79
fix: lint backend and doc files (#2850) 2024-12-16 16:12:34 -05:00
Nicolas Patry
11ab329883
Fixing CI. (#2846) 2024-12-16 10:58:15 +01:00
Nicolas Patry
6f0b8c947d
New arg. (#2845) 2024-12-16 10:34:50 +01:00
Hugo Larcher
1708865fdc
Feat/trtllm cancellation dev container (#2795)
Add devcontainers for TRTLLM backend.

---------

Co-authored-by: Morgan Funtowicz <morgan@huggingface.co>
2024-12-13 16:19:06 +01:00
Funtowicz Morgan
ea7f4082c4
TensorRT-LLM backend bump to latest version + misc fixes (#2791)
* misc(cmake) update dependencies

* feat(hardware) enable new hardware.hpp and unittests

* test(ctest) enable address sanitizer

* feat(backend): initial rewrite of the backend for simplicity

* feat(backend): remove all the logs from hardware.hpp

* feat(backend): added some logging

* feat(backend): enable compiler warning if support for RVO not applying

* feat(backend): missing return statement

* feat(backend): introduce backend_workspace_t to store precomputed information from the engine folder

* feat(backend): delete previous backend impl

* feat(backend): more impl

* feat(backend): use latest trtllm main version to have g++ >= 13 compatibility

* feat(backend): allow overriding which Python to use

* feat(backend): fix backend_exception_t -> backend_error_t naming

* feat(backend): impl missing generation_step_t as return value of pull_tokens

* feat(backend): make backend_workspace_t::engines_folder constexpr

* feat(backend): fix main.rs retrieving the tokenizer

* feat(backend): add guard to multiple header definitions

* test(backend): add more unittest

* feat(backend): remove constexpr from par

* feat(backend): remove constexpig

* test(backend): more test coverage

* chore(trtllm): update dependency towards 0.15.0

* effectively cancel the request on the executor

* feat(backend) fix moving backend when pulling

* feat(backend): make sure we can easily cancel request on the executor

* feat(backend): fix missing "0" field access

* misc(backend): fix reborrowing Pin<&mut T> as described in the doc https://doc.rust-lang.org/stable/std/pin/struct.Pin.html#method.as_mut

* chore: Add doc and CI for TRTLLM (#2799)

* chore: Add doc and CI for TRTLLM

* chore: Add doc and CI for TRTLLM

* chore: Add doc and CI for TRTLLM

* chore: Add doc and CI for TRTLLM

* doc: Formatting

* misc(backend): indent

---------

Co-authored-by: Hugo Larcher <hugo.larcher@huggingface.co>
2024-12-13 15:50:59 +01:00
Nicolas Patry
3bb3fd19ae
Fixup opt to reduce the amount of odd if statements. (#2833)
* Fixup opt to reduce the amount of odd if statements.

* Fixing cargo lock
2024-12-12 18:20:13 +01:00
Wang, Yi
bf59118a93
fix facebook/opt-125m not working issue (#2824)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-12-12 14:41:30 +01:00
Nicolas Patry
c3bd7212c2
Fixing latest flavor by disabling it. (#2831) 2024-12-12 14:09:35 +01:00
Guspan Tanadi
f01f2fb6e7
docs(README): supported hardware links TGI AMD GPUs (#2814) 2024-12-12 13:49:33 +01:00
Nicolas Patry
07b01293c5
Prepare patch release. (#2829) 2024-12-11 21:03:50 +01:00
RodriMora
cc66dccbe8
Update README.md (#2827)
Added instructions to clone the repo and change directory into it. 

In following steps there is a "make install" step that would fail if people have not cloned the repo and cd into it, so it may be confusing for some

Added python venv alternative to conda too.
2024-12-11 19:45:49 +01:00
Nicolas Patry
82c24f7420
Using both value from config as they might not be correct. (#2817)
* Using both value from config as they might not be correct.

* Fixing max_position_embeddings for falcon.

* Simple attempt to fix the healthcheck block allocation.

* Much simpler solution.

* Default value for Backend start_health
2024-12-10 19:37:09 +01:00
Nicolas Patry
a2d878fa0f
Small update to docs (#2816) 2024-12-10 10:46:26 +01:00
Nicolas Patry
b2fac5d947
Hotfix link2 (#2812)
2nd hotfix ?
2024-12-09 20:57:18 +01:00
Nicolas Patry
a70dd2998b
Hotfixing the link. (#2811) 2024-12-09 20:50:07 +01:00
Nicolas Patry
042791fbd5
Prep new version (#2810)
* New version.

* Link fixup.

* Update docs.

* FIxup.
2024-12-09 20:42:42 +01:00
Nicolas Patry
27fa83ca5b
V3 doc (#2809)
* V3 document.

* Updating asset.
2024-12-09 19:58:07 +01:00
Nicolas Patry
a04356fb8c
Attempt for cleverer auto batch_prefill values (some simplifications). (#2808)
* Attempt for cleverer auto batch_prefill values (some simplifications).

* Less flaky tests.

* Fixing typo insertion.

* Update launcher/src/main.rs

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* Adding small comment for source of calculation.

* Adding L40.

* Adding L40s.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-12-09 19:44:32 +01:00
drbh
9f5c9a5e22
Enable paligemma2 (#2807)
* feat: support loading gemma2 as vlm text model

* feat: add test for paligemma2
2024-12-06 14:41:49 -05:00
Nicolas Patry
08f6fa0b59
Removing experimental to prefill chunking. 2024-12-06 19:09:40 +01:00
Nicolas Patry
d96dcb1797
Adding A100 compute. (#2806) 2024-12-06 18:19:15 +01:00
Nicolas Patry
5df8059037
Auto max prefill (#2797)
* Attempt at automatic max batch prefill.

* Taking into account number of shards.

* Adding more cards.

* Adding A100 + H100

* Adding a few more cards.

* Logprobs cost too much.

* h100 better name, and keep factor of 2

* Damn inflated sparse tflops.

* Typo in h100.

* Updated the flops calculation (checked with fvcore).

* chunking by default.

* Fix prefix caching for chat completion since we removed logprobs.

* More tests.

* Dropping all the prefill logprobs.

* Add a flag that enables users to get logprobs back.

* Repairing prompt token counting.

* Fixing a few tests.

* Remove some scaffolding.

* Attempting to reduces the issues (workarounds for now).
2024-12-06 05:52:00 +01:00
OlivierDehaene
8c3669b287
feat: auto max_new_tokens (#2803)
* feat: auto max_new_tokens

* update default

* Fixing the tests.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-12-06 05:50:35 +01:00
Wang, Yi
6685e8fcda
use oneapi 2024 docker image directly for xpu (#2793)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-12-06 09:36:23 +05:30
drbh
e0db633396
fix: avoid setting use_sgmv if no kernels present (#2796) 2024-12-04 15:26:09 -05:00
Nicolas Patry
b57f370386
Saving some VRAM. (#2790)
* Saving some VRAM.

- 8B on 4xL4 attention=flashdecoding . Before 4.28GB left, After 4.32GB
  left, so 400MB saved.

- Effect not as visible on attention=flashinfer and n_shard=1. I suspect
  it's linked to the torch allocator.

* Adding assertion.
2024-12-03 04:04:21 +01:00
Daniël de Kok
2003d8be0c
Sync (most) server dependencies with Nix (#2782)
* Sync (most) server dependencies with Nix

Skipped most grpcio packages, because of protobuf version
incompatibility with the opentelemetry packages.

* Add a primitive script to generate Poetry commands to sync with Nix

This is not fully automated, since getting the Nix versions may be
unresolvable. However, it does take most of the work out of doing
this manually.

* Upgrade eetq ?

* Fmt.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-12-03 04:04:06 +01:00
Dmitry Rogozhkin
535149d872
fix: only use eos_token_id as pad_token_id if int (#2774)
LLama 3 has a list of values as eos_token_id:
  "['<|end_of_text|>', '<|eom_id|>', '<|eot_id|>']"
This breaks tokenizer since it expects single value. This
commit uses tokenizer.eos_token_id instead in such a case.

Fixes: #2440

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
2024-12-02 06:26:37 +01:00
drbh
2c74c55637
fix: add merge-lora arg for model id (#2788) 2024-12-02 05:52:02 +01:00
Torsten Raudssus
a35d1e6fe5
Removing ../ that broke the link (#2789) 2024-12-02 05:48:55 +01:00
Nicolas Patry
1d2cb356b9
Fix doc. (#2792) 2024-12-02 05:28:26 +01:00
drbh
d471805134
Support continue final message (#2733)
* feat: support continue_final_message param in chat request

* feat: add test for continue final message

* fix: bump openapi docs

* fix: remove continue_final_message chat request param

* fix: remove unneeded launcher args in continue test

* fix: bump test output

* fix: remove accidentally included guideline from rebase

* fix: remove guideline tests

* fix: adjust continuation tests expected text

* fix: replace expected output for continue test
2024-11-27 19:13:30 -05:00
jp
caff779dd4
Fix: docs typo (#2777)
Fix: typo in model loading code

Fix typo in model loading code
2024-11-26 14:28:58 +01:00
Wang, Yi
892a26e549
upgrade ipex cpu to fix coredump in tiiuae/falcon-7b-instruct (pageat… (#2778)
upgrade ipex cpu to fix coredump in tiiuae/falcon-7b-instruct (pageattention)

Signed-off-by: Wang,Yi A <yi.a.wang@intel.com>
2024-11-26 14:28:11 +01:00
Daniël de Kok
72ab60fdd5
Use FP8 KV cache when specified by compressed-tensors (#2761)
The compressed-tensors configuration can specify the configuration of
the KV cache as well. Use an FP8 KV cache when the configuration tells
us to do so (all other options and types are ignored for now).
2024-11-26 08:27:41 +01:00
Daniël de Kok
289aa48554
Move JSON grammar -> regex grammar conversion to the router (#2772)
* Move JSON grammar -> regex grammar conversion to the router

This change moves the JSON grammar -> regex grammar conversion to the
router by adding a dependency on the `outlines-core` Rust crate. In
contrast to the Python implementation, the conversions are not LRU-cached
since they seem to be fast enough:

simple schema           time:   [5.8293 µs 5.8307 µs 5.8320 µs]
                        change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05)
                        Performance has improved.

complex schema          time:   [14.875 µs 14.881 µs 14.887 µs]
                        change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05)
                        Performance has improved.

Using the schemas from:
https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
2024-11-25 18:47:34 +01:00
drbh
c637d68d74
feat: concat the adapter id to the model id in chat response (#2779)
* feat: concat the adapter id to the model id in chat response

* fix: updated to include only the adapter id in chat response
2024-11-25 12:36:31 -05:00
OlivierDehaene
780531ec77
chore: prepare 2.4.1 release (#2773)
* chore: prepare 2.4.1 release

* fix tests

* fmt
2024-11-22 17:26:15 +00:00
Daniël de Kok
e87893d38e
chore: Update to marlin-kernels 0.3.6 (#2771)
This fixes a bug in 2:4 Marlin:
https://github.com/vllm-project/vllm/pull/10464
2024-11-22 14:44:47 +00:00
OlivierDehaene
ab7ccf5bc3
feat: add payload limit (#2726)
* feat: add payload limit

* update launcher
2024-11-21 18:20:15 +00:00
Hugo Larcher
d5bc6a20bd
feat: Add automatic nightly benchmarks (#2591)
* feat: Add automatic nightly benchmarks

* fix: Update runners group

* fix: add created_at field to results

* fix: Add variable results file location
2024-11-21 17:11:42 +00:00
Lucain
d012f229c6
Remove guideline from API (#2762) 2024-11-21 16:56:38 +00:00
Daniël de Kok
c5b5b3a11c
docs: Add a README section about using Nix (#2767) 2024-11-21 16:53:27 +00:00
drbh
faa10ad0bc
fix: tweak grammar test response (#2769) 2024-11-21 16:46:00 +00:00
OlivierDehaene
8e0c161d0a
fix: incomplete generations w/ single tokens generations and models that did not support chunking (#2770)
* Incomplete generation stream fix (#2754)

entries.len() could > batch.size in prefill, so need to filter as well.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* entries was wrongly extended for model that did not support chunking

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi <yi.a.wang@intel.com>
2024-11-21 16:37:55 +00:00
Daniël de Kok
3c54488638
nix: downgrade to outlines 0.1.3 (#2768) 2024-11-21 13:00:26 +01:00
drbh
6ee8d6dd3b
fix: set outlines version to 0.1.3 to avoid caching serialization issue (#2766)
fix: set outlines version to 0.1.3 to avoid bug
2024-11-20 18:09:39 -05:00
Daniël de Kok
07bed530f7
nix: build and cache impure devshells (#2765)
* nix: build and cache all devshells

* nix: add poetry to the impure shell

This shouldn't be used to manage dependencies in a Nix devshell, but can
be handy to update `poetry.lock`.

* Fix Nix build, disable pure shell (covered by Nix tests)
2024-11-20 20:56:11 +01:00
Daniël de Kok
46a5a7e73e
Add support for wNa16 int 2:4 compressed-tensors checkpoints (#2758)
This change adds support for wNa16 int checkpoints with 2:4 sparsity
using Marlin 2:4 kernels.
2024-11-20 18:25:23 +01:00
Daniël de Kok
2fda8845a7
nix: update for outlines 0.1.4 (#2764) 2024-11-20 18:24:29 +01:00
Daniël de Kok
45013b60a4 Install compressed-tensors in Docker CPU builds 2024-11-20 14:17:47 +00:00
drbh
bd6e8b3c13
fix: adjust llama MLP name from dense to mlp to correctly apply lora (#2760) 2024-11-19 15:10:22 -05:00
drbh
5489406c4a
PR 2634 CI - Fix the tool_choice format for named choice by adapting OpenAIs scheme (#2645)
* add OpenAI like tool_choice for named choice

* add tests

* fix: run linter and bump api docs

* fix: consolidate changes and remove old tool type

* feat: improve, simplify and rename tool choice struct add required support and refactor

* fix: simplify tool choice logic, improve tests, openapi and rust docs

* fix: refactor away prepare_chat_input and improve tool grammar apply control flow

* feat: update docs and add tool choice configuration section

* fix: simplify naming, tool choice default and improve test

* fix: adjust tool choice none logic, add test and small refactors

* fix: add missing snapshot file

* fix: adjust tool choice type in test

* fix: adjust default when json tool choice is

* fix: remove trailing space lint after rebase

* fix: remove mostly mocked unit test

---------

Co-authored-by: Linus Bierhoff <linus.bierhoff@icloud.com>
2024-11-19 13:31:59 -05:00
Daniël de Kok
2007a9473a
Update to moe-kernels 0.7.0 (#2720)
This version syncs with the vLLM kernels and brings some performance
improvements.
2024-11-19 14:55:29 +01:00
Daniël de Kok
b4ec427ad0
Simplify two ipex conditions (#2755) 2024-11-19 08:04:23 +01:00
drbh
38cff84a3e
feat: support flash attention 2 in qwen2 vl vision blocks (#2721)
* feat: support flash attention 2 in qwen2 vl vision blocks

* fix: calc max_seqlen once and small refactors
2024-11-18 12:46:40 -05:00
Daniël de Kok
3c9df21ff8
Add support for compressed-tensors w8a8 int checkpoints (#2745)
* Add support for compressed-tensors w8a8 int checkpoints

This change adds a loader for w8a8 int checkpoints. One large benefit of
int8 support is that the corresponding cutlass matmul kernels also work on
compute capability 7.5.

Evaluation on neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8:

|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8431|±  |0.0100|
|               |       |strict-match    |     8|exact_match            |↑  |0.8393|±  |0.0101|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8597|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8201|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7967|±  |0.0173|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7468|±  |0.0187|

Which is the same ballpark as vLLM.

As usual, lots of thanks to Neural Magic/vLLM for the kernels.

* Always use dynamic input quantization for w8a8 int

It's far less flaky and gives better output.

* Use marlin-kernels 0.3.5

* Fix a typo

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Small fixes

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
2024-11-18 17:20:31 +01:00
Wang, Yi
a5ecd6e586
add ipex moe implementation to support Mixtral and PhiMoe (#2707)
* add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update to ipex xpu 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* torch has xpu support in 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix oneapi basekit version

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
2024-11-18 17:16:55 +01:00
drbh
fea62e928f
fix: improve find_segments via numpy diff (#2686) 2024-11-18 09:51:06 -05:00
Daniël de Kok
52e48739a5
Remove vLLM dependency for CUDA (#2751)
* Remove vLLM dependency for CUDA

This change adds `attention-kernels` as a dependency for paged
attention and cache reshaping. With that, we don't use vLLM
anywhere for CUDA.

Tested run (since we don't have paged attention in CI):

```
❯ ATTENTION=paged python -m pytest integration-tests -k "llama and awq" --release
[...]
5 snapshots passed.
```

* Fix clippy warning
2024-11-17 17:34:50 +01:00
drbh
6489f85269
feat: return streaming errors as an event formatted for openai's client (#2668)
* feat: return streaming errors as an event formatted for openai's client

* fix: propagate completions error events to stream

* fix: improve stream api error format and add status code

* fix: improve streamin error to include error_type

* Revert "fix: improve streamin error to include error_type"

This reverts commit 2b1a360b15.

* Reworked the implementation.

* Revert "Reworked the implementation."

This reverts commit 7c3f29777f17411ae4ade57e2f88e73cde704ee5.

* Small lifting.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-11-15 14:49:19 +01:00
Nicolas Patry
34a3bdedc3
Upgrading our deps. (#2750)
* Upgrading our deps.

* fixup.

* Fixup.
2024-11-15 14:03:27 +01:00
Alex Weston
4580ced091
Upgrade outlines to 0.1.1 (#2742)
* Upgrade outlines to 0.1.1

* Update for new API

* Check if allowed tokens is None

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-11-15 13:22:52 +01:00
jito
003eaec0fb
fix response type of document for Text Generation Inference (#2743)
Signed-off-by: jitokim <pigberger70@gmail.com>
2024-11-15 13:21:50 +01:00
Billel Mokeddem
4f4857a4ac
Fix: Change embeddings to embedding (#2738)
fix: change embeddings to embedding

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:16:15 +01:00
Billel Mokeddem
f9ee46f740
Fix: Change model_type from ssm to mamba (#2740)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:15:36 +01:00
Daniël de Kok
8442f1ac85
benchmark: fix prefill throughput (#2741) 2024-11-15 13:14:55 +01:00
Daniël de Kok
ca4f46ddfc
nix: update nixpkgs (#2746)
Updates from Triton 2.1.0 to 3.1.0 (among other things).
2024-11-14 18:48:20 +01:00
Daniël de Kok
a785000842
Add initial support for compressed-tensors checkpoints (#2732)
compressed-tensors is a safetensors extension for sparse, quantized
tensors. The format is more powerful than earlier AWQ/GPTQ/FP8
quantization, because

- Different quantizer configurations can be used for different targets.
- The format can specify input/output quantizers in addition to weight
  quantizers.
- Configurable exclusions for quantization.

This change adds a dependency on the `compressed-tensors` package for
its configuration parsing and layer matching functionality.

The following types of quantization are supported in this PR:

- W8A16 and W4A16 INT using GPTQ-Marlin kernels.
- W8A8 and W8A16 FP using FP8-Marlin and cutlass kernels.

Support for other quantization types will be added in subsequent PRs.
2024-11-10 13:54:07 +01:00
Wang, Yi
97f7a22f0b
add trust_remote_code in tokenizer to fix baichuan issue (#2725)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-07 14:43:38 +01:00
Wang, Yi
b1f9044d6c
fix incorrect output of Qwen2-7B-Instruct-GPTQ-Int4 and Qwen2-7B-Inst… (#2717)
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
Close stale issues and PRs / stale (push) Has been cancelled
Nightly load test / load-tests (push) Has been cancelled
fix incorrect output of Qwen2-7B-Instruct-GPTQ-Int4 and Qwen2-7B-Instruct-AWQ
ipex kernel provide func like add_bias, so no need add it outside

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-04 16:07:51 +01:00
Daniël de Kok
5eedb2ec7a
nix: move to tgi-nix main (#2718) 2024-11-04 15:40:13 +01:00
Nicolas Patry
9fde566602
Fixing linting on main. (#2719) 2024-11-04 15:21:41 +01:00
Travis Addair
aadc9cb485
Fix prefix caching + speculative decoding (#2711) 2024-11-04 15:08:43 +01:00
Nicolas Patry
a5593ba83e
Hotfixing auto length (warmup max_s was wrong). (#2716)
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2024-11-04 09:55:54 +01:00
drbh
08c4184eb2
fix: add chat_tokenize endpoint to api docs (#2710) 2024-11-04 06:44:59 +01:00
drbh
6e3220529d
fix: create position ids for text only input (#2714)
* fix: create position ids for text only input

* fix: prefer repeat over expand to avoid clone
2024-11-02 08:40:05 +08:00
drbh
01dacf8e8f
fix cuda graphs for qwen2-vl (#2708)
* feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl

* fix: only check model type if config exists

* fix: adjust sharding and lm head logic

* fix qwen2 failure in intel cpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix: return correct shape logits and add streaming test

* fix: remove unused import and refactor test

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-01 03:05:34 +01:00
drbh
befd9f6735
Support qwen2 vl (#2689)
* feat: add support for qwen2 vl model

* feat: fix token padding, enable warmup and process basic request

* fix: improve get_position_ids, add lift embed_tokens

* fix: remove get_cos_sin_hack dev function

* feat: add simple test chat with meesage and text

* fix: lint test

* fix: adjust positional embeddings for multi dimensional position ids

* fix: update docs and lint unused vars

* fix: include linted file

* fix: add norm after text output

* fix: format model file

* fix: adjust for ruff lints

* fix: remove unused rotate_half

* feat: refactors and calc num features

* fix: prefer position_ids passed from vlm causal lm and reset ids on batch

* fix: adjust get_position_ids if not available and add required args to signatures

* fix: adjust resize case for qwen2_vl warmup

* fix: avoid qwen2 vl specific paths with qwen2
2024-10-30 12:40:51 -04:00
Wang, Yi
46aeb0860d
add xpu triton in dockerfile, or will show "Could not import Flash At… (#2702)
add xpu triton in dockerfile, or will show "Could not import Flash Attention enabled models: No module named 'triton'"

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-30 14:18:50 +01:00
Nicolas Patry
98330df65e
Monkey patching as a desperate measure. (#2704)
* Monkey patching as a desperate measure.

* New snapshot ?
2024-10-28 11:25:13 +01:00
Nicolas Patry
513d19b955
More timeout on docker start ? (#2701)
* More timeout on docker start ?

* Latest upgrade.
2024-10-28 08:57:22 +01:00
Nicolas Patry
3a9cdc3241
Fixing auto bloom test. (#2699) 2024-10-28 06:14:11 +01:00
Nicolas Patry
78ce618c70
Update poetry lock. (#2698) 2024-10-28 06:11:33 +01:00
Nicolas Patry
90b226db29
We can have a tokenizer anywhere. (#2527)
* We can have a tokenizer anywhere.

* Handling potential lack of offsets (python tokenizer)

* Remove redundancy.

* Fixing the tests.

* Flake.lock update ?

* Fixing the  GIL locking.

* Fixing mamba by using the transformers version.

* Adding the legacy handle.

* Ellide lifetime.

* Lint.

* Deprecation message.

* Fixing bad rebase.
2024-10-28 05:00:24 +01:00
Nicolas Patry
0c9b6cdd76
Choosing input/total tokens automatically based on available VRAM? (#2673)
* Choosing input/total tokens automatically based on available VRAM?

* Update doc.

* Remove generated files.

* Trying to fix non chunking targets.

* Attempt #2

* fix.

* QuantLinear is rocm compatible.

* Much simpler logic after the overhead.

* Updating logic + non flash.

* Revert doc text.

* Simple updates.

* Fix integration mt0 (transformers update).
2024-10-28 04:59:49 +01:00
Nicolas Patry
2e4f4ba1bb
Green main (#2697) 2024-10-28 04:59:32 +01:00
Nicolas Patry
8a8794a672
Avoiding timeout for bloom tests. (#2693)
* Avoiding timeout for bloom tests.

* Skip the test let's see if it's always the first tests that fails.

* Fail early.

* Pulling ?

* No early exit.
2024-10-26 05:35:28 +02:00
OlivierDehaene
a6b02da971
chore: prepare 2.4.0 release (#2695) 2024-10-25 21:10:49 +00:00
OlivierDehaene
6f88bd9390
feat: add triton kernels to decrease latency of large batches (#2687)
* feat: add triton kernels to decrease latency of large batches

* cast to int32

* fix kernel

* fix kernel

* disable triton on rocm

* fix speculation

* add slots filtering kernel
2024-10-25 21:10:00 +00:00
Daniël de Kok
0f346a3296
Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels (#2688)
* Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels

Performance and accuracy of these kernels are on par (tested with Llama
70B and 405B). Removes a dependency and resolves some stability issues
we have been seeing.

* Update test snapshots
2024-10-25 16:40:47 +02:00
Funtowicz Morgan
ba5fc7d922
Add support for stop words in TRTLLM (#2678)
* feat(trtllm): rewrite health to not account for current state

* chore(looper): cleanup a bit more

* feat(post_processing): max_new_tokens is const evaluated now

* chore(ffi):formatting

* feat(trtllm): add stop words handling

# Conflicts:
#	backends/trtllm/lib/backend.cpp

* chore(trtllm): create specific parallelconfig factory and logging init methods

* chore(trtllm): define a macro for SizeType cast

* chore(trtllm): use GetParallelConfig

* chore(trtllm): minor refactoring

* chore(trtllm): validate there are enough GPus on the system for the desired model

* chore(trtllm): ensure max throughput scheduling policy is selected

* chore(trtllm): minor fix

* chore(router): minor refactorings

* feat(docker): build with-slurm ompi

* feat(docker): add python3.10 dev to runtime deps

* chore(docker): add mpi to ld_library_path

* chore(docker): install transformers

* feat(trtllm): detect stop_words from generation_config.json
2024-10-25 10:58:34 +02:00
Nicolas Patry
db68bd0524
Fixing mt0 test. (#2692) 2024-10-25 09:46:39 +02:00
Nicolas Patry
cece8635f8
Fixing rocm gptq by using triton code too (renamed cuda into triton). (#2691) 2024-10-25 09:17:57 +02:00
Funtowicz Morgan
43df056eee
[TENSORRT-LLM] - Implement new looper thread based backend (#2357)
* (backend) use parking_lot crate for RwLock fairness

# Conflicts:
#	backends/trtllm/src/backend.rs

* (launcher) default new server::run parameters to false for now

* (chore) fmt ... why?

* (ffi) use const for GetSamplingConfig

* (server) expose new SchedulingError

* (trt)

* (build) setup ccache if available

* (ffi) add max_new_tokens parameters

* (backend) cleanup a bit

* (backend) expose PullNewTokens

* (ffi) cleanup again

* (ffi) add missing headers imports

* (ffi) add template specialization to catch and convert to Rust Result<T, tensorrt_llm::common::TllmException>

* (looper) new looper initial implementation

* (ffi) remove narrowing type warning

* (ffi) encode the provided user prompt within each request thread

* (misc) change scope identifiers

* (backend) implement the post_processor background thread

* (misc) missing Result types for Rust

* use blocking_recv in looper to consume awaiting_requests at max before pulling in a single step

* (server) forward auth_token to server::run

* (build) fetchcontent use archives instead of git

* (ffi) fix usage of wrong vector constructor making a capacity fill call

* (ffi) missing namespace for tle::Response

* (ffi) do not use reference capture in lambda as we are not capturing anything

* (backend) refactor & cleanup

* (Dockerfile.trtllm) delete for now

* (misc) simplify [make_]move_iterator by using c++20 type inference

* (misc) no need to move for uint32_t items

* (scheduler) rework submit/pull logic

* (post) impl postprocessing

* (misc) delete backend.rs

* (misc) rerun-if-changed all the cmake modules

* (misc) move to latest trtllm

* (fix): HOPPER_SM_MAJOR is 9 not 8

* (misc: build for sm_{75,80,86,89,90} by default

* (misc): build with trtllm 0.13.0

* (misc): increase verbosity of spdlog

* (fix): do not recreate the stateful hashmap at every it

* (misc): update dependency in trtllm dockerfile

* (misc): update dependency in trtllm dockerfile

* (misc): disable logging in release mode

* (misc): improve trtllm download script robustness

* (fix): ore fixes for Dockerfile

* misc(cuda): require 12.6

* chore(cmake): use correct policy for download_timestamp

* feat(looper): check engine and executorWorker paths exist before creating the backend

* chore(cmake): download timestamp should be before URL

* feat(looper): minor optimizations to avoid growing too much the containers

* chore(trtllm): move dockerfile to right place

* chore(trtllm): disable tokenizer parallelism by default

* chore(trtllm): fmt

* chore(trtllm): post-rebase commit

* chore(trtllm): remove unused method

* feat(trtllm): cache maxNumTokens to avoid calling JSON everytime

* misc(router): remove SchedulingError

* feat(trtllm): do not tokenize twice

* Revert "chore(trtllm): remove unused method"

This reverts commit 31747163

* chore(rebase): fix invalid references

* chore(router): add python dependency

* Lint.

* Fix bad rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-25 07:17:14 +02:00
Nicolas Patry
ed87b464b4
Fixing "deadlock" when python prompts for trust_remote_code by always (#2664)
specifiying a value.
2024-10-25 06:39:21 +02:00
Daniël de Kok
eab07f746c
Add support for FP8 KV cache scales (#2628)
* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration calibration data and stored
in the checkpoint.

This change adds support for for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with an `float8_e4m3fn` cache.

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation, but also
scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
2024-10-24 16:36:18 +02:00
Daniël de Kok
14a0df3a38
Fix Phi 3.5 MoE tests (#2684)
PR #2682 also fixed in issue in Phi MoE, but it changes the test
outputs a bit. Fix this.
2024-10-24 15:21:50 +02:00
Daniël de Kok
1b914f37e7
flashinfer: reminder to remove contiguous call in the future (#2685) 2024-10-24 14:59:56 +02:00
OlivierDehaene
41c2623735
feat: allow any supported payload on /invocations (#2683)
* feat: allow any supported payload on /invocations

* update openAPI

* update doc
2024-10-23 11:26:01 +00:00
OlivierDehaene
27ff1871b5
hotfix: fix flashllama 2024-10-23 13:22:31 +02:00
OlivierDehaene
03c9388bf7
feat: natively support Granite models (#2682)
* feat: natively support Granite models

* Update doc
2024-10-23 10:04:05 +00:00
Daniël de Kok
f58eb70ebf
Make moe-kernels and marlin-kernels mandatory in CUDA installs (#2632) 2024-10-23 11:07:31 +02:00
Daniël de Kok
9c9ef37c56
Add impureWithCuda dev shell (#2677)
* Add `impureWithCuda` dev shell

This shell is handy when developing some kernels jointly with TGI - it
adds nvcc and a bunch of commonly-used CUDA libraries to the environment.

We don't add this to the normal impure shell to keep the development
environment as clean as possible (avoid accidental dependencies, etc.).

* Add cuDNN
2024-10-22 11:02:55 +02:00
Wang, Yi
058d3061f7
break when there's nothing to read (#2582)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-21 15:22:48 +02:00
Daniël de Kok
7f54b7336a
Test Marlin MoE with desc_act=true (#2622)
Update the Mixtral GPTQ test to use a model with `desc_act=true` and
`group_size!=-1` to ensure that we are checking activation
sorting/non-full K (with tensor parallelism). The `desc_act=false` case
is already checked by the Mixtral AWQ test.
2024-10-21 12:50:35 +02:00
Daniël de Kok
5e0fb46821
Make handling of FP8 scales more consisent (#2666)
Change `fp8_quantize` so that we can pass around reciprocals everywhere,
so scales are always passed around in the checkpoint format.

I also noticed that we ignore any input scales that we might have when
fbgemm is available. Skip this path if we already have a scale.
2024-10-19 09:05:01 +02:00
Nicolas Patry
153ff3740b
CI job. Gpt awq 4 (#2665)
* add gptq and awq int4 support in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set kv cache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the code according to the review command

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Simplifying conditionals + reverting integration tests values.

* Unused import

* Fix redundant import.

* Revert change after rebase.

* Upgrading the tests (TP>1 fix changes to use different kernels.)

* Update server/text_generation_server/layers/gptq/__init__.py

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-18 17:55:53 +02:00
Daniël de Kok
8ec57558cd
Break cycle between the attention implementations and KV cache (#2627) 2024-10-17 14:54:22 +02:00
drbh
5f32dea1e2
fix: prefer inplace softmax to avoid copy (#2661)
* fix: prefer inplace softmax to avoid copy

* Update server/text_generation_server/models/flash_causal_lm.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-17 08:49:02 -04:00
oOraph
1b97e084bf
fix tgi-entrypoint wrapper in docker file: exec instead of spawning a child process (#2663)
tgi-entrypoint: exec instead of spawning a child process

reason: otherwise parent will receive the signals when we'd like tgi to receive them
keeping the parent/child is not necessary and would require the parent to handle signals to forward them properly to the child

Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
Co-authored-by: Raphael Glon <oOraph@users.noreply.github.com>
2024-10-17 11:15:26 +02:00
Daniël de Kok
59ea38cbca
Simplify the attention function (#2609)
* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KVCACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
2024-10-17 10:42:52 +02:00
Daniël de Kok
5bbe1ce028
Support e4m3fn KV cache (#2655)
* Support `e4m3fn` KV cache

* Make check more obvious
2024-10-17 10:42:16 +02:00
OlivierDehaene
a6a0c97ed9
feat: prefill chunking (#2600)
* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-16 12:49:33 +02:00
Mohit Sharma
704a58c807
Fp8 e4m3_fnuz support for rocm (#2588)
* (feat) fp8 fnuz support for rocm

* (review comments) Fix compression_config load, type hints

* (bug) update all has_tensor

* (review_comments) fix typo and added comments

* (nit) improved comment
2024-10-16 09:54:50 +02:00
Alvaro Bartolome
ffe05ccd05
Rollback to ChatRequest for Vertex AI Chat instead of VertexChat (#2651)
As spotted by @philschmid, the payload was compliant with Vertex AI, but
just partially, since ideally the most compliant version would be with
the generation kwargs flattened to be on the same level as the
`messages`; meaning that Vertex AI would still expect a list of
instances, but each instance would be an OpenAI-compatible instance,
which is more clear; and more aligned with the SageMaker integration
too, so kudos to him for spotting that; and sorry from my end for any
inconvenience @Narsil.
2024-10-15 18:11:59 +02:00
Daniël de Kok
ce7e356561 Use flashinfer for Gemma 2. 2024-10-15 13:49:32 +00:00
Nicolas Patry
cf04a43fb1
Fixing linters. (#2650) 2024-10-15 12:43:49 +02:00
Dmitry Rogozhkin
58848cb471
feat: enable pytorch xpu support for non-attention models (#2561)
XPU backend is available natively (without IPEX) in pytorch starting
from pytorch 2.4. This commit extends TGI to cover the case when user
has XPU support thru pytorch 2.4, but does not have IPEX installed.
Models which don't require attention can work. For attention required
models more work is needed to provide attention implementation.

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
2024-10-14 18:28:49 +02:00
Wang, Yi
7a82ddcbd0
update ipex to fix incorrect output of mllama in cpu (#2640)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-14 16:32:33 +02:00
Omar Sanseviero
51f5401893
Clarify gated description and quicktour (#2631)
Update quicktour.md
2024-10-14 16:31:37 +02:00
Nicolas Patry
3ea82d008c
Cpu perf (#2596)
* break when there's nothing to read

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Different approach, only listen on stdin when `LOG_LEVEL=debug` (which
is where dropping to a debugger is important).

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-14 15:34:08 +02:00
Omar Sanseviero
ce28ee88d5
Small fixes for supported models (#2471)
* Small improvements for docs

* Update _toctree.yml

* Updating the doc (we keep the list actually).

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-14 15:26:39 +02:00
Nicolas Patry
0c478846c5
Fixing intel Supports windowing. (#2637) 2024-10-11 21:47:03 +02:00
Nicolas Patry
3dbdf63ec5
Intel ci (#2630)
* Intel CI ?

* Let's try non sharded gemma.

* Snapshot rename

* Apparently container can be gone already.
2024-10-10 16:51:57 +02:00
vb
d912f0bf55
Update documentation to most recent stable version of TGI. (#2625)
Update to most recent stable version of TGI.
2024-10-10 16:00:25 +02:00
drbh
e36dfaa8de
feat: allow tool calling to respond without a tool (#2614)
* feat: process token stream before returning to client

* fix: expect content in test

* fix: improve comparison via ruff lint

* fix: return event in all cases

* fix: always send event on error, avoid unwraps, refactor and improve tests

* fix: prefer no_tool over notify_error to improve reponse

* fix: adjust chat input test for no_tool

* fix: adjust test expected content

---------

Co-authored-by: System administrator <root@ip-10-90-0-186.ec2.internal>
2024-10-10 09:28:25 -04:00
Nicolas Patry
43f39f6894
AMD CI (#2589)
* Only run 1 valid test.

* TRying the tailscale action quickly.

* ?

* bash spaces.

* Remove tailscale.

* More quotes.

* mnt2 ?

* Othername to avoid recursive directories.

* Good old tmate.

* Remove tmate.

* Trying a few things.

* Remove some stuff.

* Sleep ?

* Tmp

* busybox

* Launcher tgi

* Starting hello

* Busybox in python

* No device.

* Removing all variables ?

* A un moment donné.

* Tmp

* Tmp2

* DEvice request, no container name

* No device requests

* Without pytest.

* No pytest.

* from env

* Start with devices

* Attemp #1

* Remove stdin messing

* Only 1 test, no container name

* Raw tgi

* Sending args.

* Show pip freeze.

* Start downloading with token

* Giving HIP devices.

* Mount volume + port forward

* Without pytest.

* No token

* Repeated arguments

* Wrong kwarg.

* On 2 GPUs

* Fallback to single shard CI test.

* Testing

* yaml

* Common cache ?

* Trailing slash ?

* Docker volume split.

* Fix docker volume

* Fixing ?

* ?

* Try no devices ?

* Flash llama on intel CPU ?

* Fix nvidia ?

* Temp deactivate intel, activate nvidia ?
2024-10-09 17:50:49 +02:00
Daniël de Kok
9ed0c85fe1
nix: add black and isort to the closure (#2619)
To make sure that everything is formatted with the same black version
as CI.

I sometimes use isort for new files to get nicely ordered imports,
so add it as well. Also set the isort configuration to format in a
way that is compatible with black.
2024-10-09 11:08:02 +02:00
drbh
8ad20daf33
CI (2599): Update ToolType input schema (#2601)
* Update ToolType input schema

* lint

* fix: run formatter

* fix: allow tool choide to be null

---------

Co-authored-by: Wauplin <lucainp@gmail.com>
2024-10-08 12:35:48 -04:00
Daniël de Kok
6db3bcb700
nix: move back to the tgi-nix main branch (#2620) 2024-10-08 12:55:05 +02:00
Daniël de Kok
64142489b6
Add support for fused MoE Marlin for AWQ (#2616)
* Add support for fused MoE Marlin for AWQ

This uses the updated MoE Marlin kernels from vLLM.

* Add integration test for AWQ MoE
2024-10-08 11:56:41 +02:00
Nicolas Patry
8b295aa498
Upgrade minor rust version (Fixes rust build compilation cache) (#2617)
* Upgrade minor rust version (Fixes rust build compilation cache)

* Black
2024-10-08 09:42:50 +02:00
Wang, Yi
57f9685dc3
enable mllama in intel platform (#2610)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-10-07 21:15:09 +02:00
Florian Zimmermeister
0da4df4b96
Fix FP8 KV-cache condition (#2611)
Update kv_cache.py
2024-10-07 09:34:19 +02:00
Daniël de Kok
2358c2bb54
Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
2024-10-04 17:51:48 +02:00
Daniël de Kok
68103079f4
nix: example of local package overrides during development (#2607) 2024-10-04 16:52:42 +02:00
drbh
3011639ff7
Revert "Unroll notify error into generate response" (#2605)
Revert "Unroll notify error into generate response (#2597)"

This reverts commit d22b0c1fbe.
2024-10-03 17:56:40 -04:00
Nicolas Patry
f6e2f05b16
New release 2.3.1 (#2604)
* New release 2.3.1

* Update doc number
2024-10-03 14:43:49 +02:00
drbh
d22b0c1fbe
Unroll notify error into generate response (#2597)
* feat: unroll notify_error if no tool is choosen

* fix: expect simple message when no tool is selected

* fix: improve test to avoid notify_error

* fix: improve docs and indicate change in expected response

* fix: adjust linting in test file
2024-10-02 11:34:57 -04:00
drbh
2335459556
CI (2592): Allow LoRA adapter revision in server launcher (#2602)
allow revision for lora adapters from launcher

Co-authored-by: Sida <sida@kulamind.com>
Co-authored-by: teamclouday <teamclouday@gmail.com>
2024-10-02 10:51:04 -04:00
Nicolas Patry
0204946d26
Max token capacity metric (#2595)
* adding max_token_capacity_metric

* added tgi to name of metric

* Adding max capacity metric.

* Add description for the metrics

---------

Co-authored-by: Edwinhr716 <Edandres249@gmail.com>
2024-10-02 16:32:36 +02:00
Nicolas Patry
d18ed5cfc5
Mllama flash version (#2585)
* Working loading state.

* Preprocessing.

* Working state ? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Ugrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integrations tests for mllama (cutting to 10 tokens because there seems'
to be instability after (meaning size of the batch matters.

* Updating model link.

* Earlier assert.

* Fix vlm ?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
2024-10-02 11:22:13 +02:00
Daniël de Kok
584b4d7a68
nix: experimental support for building a Docker container (#2470)
* nix: experimental support for building a Docker image

Run using something like:

```
docker run \
  --device nvidia.com/gpu=all \
  -it --rm -p 8080:80 \
  -v $PWD/data:/data \
  -v $PWD/tmp:/tmp \
  tgi-docker:latest \
  --model-id <model_id>
```

* Example of building the Docker image using Nix inside Docker

* Stream to make the builder image smaller

This avoids storing a Docker image tarball in the image. Instead,
stream the layers while doing `docker run`.

* Don't spam journalctl on Linux

* Other dockerfile.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-10-01 18:02:06 +02:00
Daniël de Kok
1c84a30fe6
MoE Marlin: support desc_act for groupsize != -1 (#2590)
This change uses the updated Marlin MoE kernel from vLLM to support
MoE with activation sorting and groups.
2024-09-30 19:40:25 +02:00
Daniël de Kok
d1f257ac56
Move flake back to tgi-nix main (#2586) 2024-09-30 11:39:41 +02:00
drbh
93a7042d7e
feat: support phi3.5 moe (#2479)
* feat: support phi3.5 moe model loading

* fix: prefer llama base model and improve rotary logic

* feat: return reasonable generation and add integration test

* fix: run lint and update docs

* fix: rerun lint for openapi docs

* fix: prefer do_sample false unless temp is set by user, and update chat tests

* fix: small typo adjustments

* fix: consolidate long rope paths

* fix: revert greedy by default and test changes

* Vendor configuration so that we don't have to `trust_remote_code`

* Use SparseMoELayer

* Add support for dense MoE

* Some type annotations

* Add the usual model tests

* Ruff.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-09-30 11:15:09 +02:00
Daniël de Kok
90a1d04a2f
Add support for GPTQ-quantized MoE models using MoE Marlin (#2557)
This change add support for MoE models that use GPTQ quantization.
Currently only models with the following properties are supported:

- No `desc_act` with tensor parallelism, unless `group_size=-1`.
- No asymmetric quantization.
- No AWQ.
2024-09-30 11:14:32 +02:00
Mohit Sharma
f9e561eced
Update ROCM libs and improvements (#2579)
* style

* update torch

* ix issues

* fix clone

* revert mkl

* added custom PA

* style

* fix style

* style

* hide env vart

* fix mixtral model

* add skinny kernel and merge fixes

* fixed style

* fix issue for sliding window models

* addressed review comments

* fix import

* improved error messag

* updated default value

* remove import

* fix imports after rebase

* float16 dep

* improve dockerfile

* cleaned dockerfile
2024-09-30 10:54:32 +02:00
Ikram Ul Haq
e790cfc0e4
Update architecture.md (#2577) 2024-09-30 08:56:20 +02:00
Daniël de Kok
afc7ded84f
Remove compute capability lazy cell (#2580)
Remove compute capability lock

We are only calling the `get_cuda_capability` function once, so avoiding
the cost of multiple calls is not really necessary yet.
2024-09-30 08:48:47 +02:00
Daniël de Kok
1028996fb3
flashinfer: pass window size and dtype (#2574) 2024-09-28 18:41:41 +02:00
Daniël de Kok
5b6b74e21d
Improve support for GPUs with capability < 8 (#2575)
* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, use flash-attn v1 + paged
  attention for models with a compute capability older than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
2024-09-27 16:19:42 +02:00
Alvaro Bartolome
0aa66d693a
Fix build with --features google (#2566)
* Fix `cargo build --features google`

* Add `cargo test --features google`
2024-09-26 11:41:38 +02:00
Alvaro Bartolome
0b7df77178
Add LoRA adapters support for Gemma2 (#2567)
* Add LoRA adapters support for Gemma2

* Make `black` formatting happy
2024-09-26 10:54:08 +02:00
Nicholas Broad
7efcb5e0ed
remove LORA_ADAPTERS_PATH (#2563)
specify how to call local adapters
2024-09-25 01:20:15 +02:00
Nicolas Patry
dd8691b7c5
More tensor cores. (#2558)
* More tensor cores.

* Fixing the logic.

* Gemma is modified by this.
2024-09-24 23:57:26 +02:00
Nicolas Patry
c032280b17
Cleanup Vertex + Chat (#2553)
* Cleanup Vertex + Chat

* logprobs defaults to false.

* Parameters are optional

* Fix  docs.

* Changing back this logprobs default.

* Fixup doc.

* Let's debug that.

* Not unstable.

* Updating Cargo ?

* Wat?

* Dummy change.

* Trying some other install.

* Trying smething.

* Revert everything.

* Update Cargo lock.

* Fixing the pre-commit after rebase.
2024-09-24 23:37:17 +02:00
Nicolas Patry
75c8c54ac9
Hotfixing main. (#2562) 2024-09-24 23:00:43 +02:00
Aritra Roy Gosthipaty
e6d29656b5
Adding note for private models in quick-tour document (#2548)
* chore: adding note for private models in quicktour doc

* Update docs/source/quicktour.md

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Update docs/source/quicktour.md

Co-authored-by: vb <vaibhavs10@gmail.com>

* Update docs/source/quicktour.md

Co-authored-by: vb <vaibhavs10@gmail.com>

---------

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
Co-authored-by: vb <vaibhavs10@gmail.com>
2024-09-24 15:06:53 +02:00
Orhun Parmaksız
8024ded58f
Simplify crossterm imports (#2545) 2024-09-24 14:57:20 +02:00
Orhun Parmaksız
03263f5e88
Update the link to the Ratatui organization (#2546) 2024-09-24 14:51:48 +02:00
Daniël de Kok
3f14cd1420
Add DenseMoELayer and wire it up in Mixtral/Deepseek V2 (#2537)
This replaces the custom layers in both models.
2024-09-24 14:27:06 +02:00
Daniël de Kok
c29dc89c18
Add support for scalar FP8 weight scales (#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
However, we wouldn't pick up fp8 quantization for models quantized with
LLM compressor. This change adds enough parsing to detect if models have
FP8-quantized weights.

* Remove stray debug print
2024-09-24 13:57:40 +02:00
Nicolas Patry
0ff6ff60ad
Hotfixing main (#2556) 2024-09-24 11:51:14 +02:00
Nicolas Patry
74d3ce106e
Micro cleanup. (#2555) 2024-09-24 11:19:24 +02:00
Alvaro Bartolome
d31a6f75cc
Remove duplicated RUN in Dockerfile (#2547) 2024-09-24 10:19:13 +02:00
OlivierDehaene
10e6f29295
chore: Add old V2 backend (#2551)
* wip

* added v2
2024-09-24 08:38:17 +02:00
Daniël de Kok
9263817c71
nix: remove unused _server.nix file (#2538) 2024-09-23 09:43:23 +02:00
Nicolas Patry
169178b937
Preparing for release. (#2540)
* Preparing for release.

* Upgrade version in docs.
2024-09-20 17:42:04 +02:00
OlivierDehaene
7e2d18877e
fix: wrap python basic logs in debug assertion in launcher (#2539)
* fix: wrap python basic logs in debug assertion in launcher

* use level filters instead
2024-09-20 14:59:31 +00:00
Wang, Yi
f478aa77ad
hotfix: ipex fails since cuda moe kernel is not supported (#2532)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-20 10:02:55 +02:00
Daniël de Kok
abd24dd385
doc: clarify that --quantize is not needed for pre-quantized models (#2536) 2024-09-19 22:17:15 +02:00
Daniël de Kok
c103760172
Update to moe-kenels 0.3.1 (#2535)
* Update to moe-kenels 0.3.1

* Attempt to fix apt failure
2024-09-19 22:16:32 +02:00
Nicolas Patry
f512021e77
Stream options. (#2533)
* Stream options.

* Fetch stuff from nix integration test for easier testing.

* Adding the assert.

* Only send the usage when asked for.

* Update the docs.

* Impure test because we need network.

* develop.

* Optional usage.

* Fixes.

* Workflow
2024-09-19 20:50:37 +02:00
Daniël de Kok
ce85efa968
Move to moe-kernels package and switch to common MoE layer (#2511)
* Move to moe-kernels package and switch to common MoE layer

This change introduces the new `moe-kernels` package:

- Add `moe-kernels` as a dependency.
- Introduce a `SparseMoELayer` module that can be used by MoE
  models.
- Port over Mixtral and Deepseek.

* Make `cargo check` pass

* Update runner
2024-09-17 18:08:58 +02:00
OlivierDehaene
86984e3236
fix: metrics unbounded memory (#2528) 2024-09-17 16:01:28 +00:00
Daniël de Kok
71e4268600
nix: pure Rust check/fmt/clippy/test (#2525)
Runs the tests in a Nix build sandbox.
2024-09-17 12:14:30 +02:00
Nicolas Patry
38fcafcf96
Adding a test for FD. (#2516)
* Adding a test for FD.

* Fixing flashdecoding (empty batch doesn't work).

* Fixing the invalid popping.

* Fixing radix with block_size > 1

* Last reference.

* Use an actual hash.

* Update hash for slice.len() == 1

* Update the locks.

* Increasing docker timeout.
2024-09-16 17:00:54 +02:00
Daniël de Kok
7774655297
Add tests for Mixtral (#2520)
Disable by default because CI runners do not have enough GPUs.
2024-09-16 12:39:18 +02:00
Alex Strick van Linschoten
9cca3e0b03
Use ratatui not (deprecated) tui (#2521)
* use ratatui not archived tui

* bump ratatui all the way with options
2024-09-13 18:45:28 +02:00
Wang, Yi
3ac7df2b6d
hotfix : enable intel ipex cpu and xpu in python3.11 (#2517)
enable intel ipex cpu and xpu in python3.11

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-12 17:23:49 +02:00
drbh
628334d336
fix: pass missing revision arg for lora adapter when loading multiple… (#2510)
fix: pass missing revision arg for lora adapter when loading multiple adapters
2024-09-12 17:04:52 +02:00
Nicolas Patry
d95c670ada
Add nix test. (#2513)
* Add nix test.

* Modifying yourself means you need to rerun.

* Fixing the test + adding click (needed for pre-commit hooks).

* Try thuis.

* Our runner + pure test (not written)

* Reemove server.

* Root user.

* Different user ?

* Add the actual test target.

* Forgot this modification.

* Add a formatter.

* Add the secrets.

* Fixed the auth token ?

* Adding the other tests.

* Missing pre-commit.

* Test requires cargo for cargo fmt.

* Update it a bit.

* Up.

* Attempting to use a cache location for the models.

* Ignore the cache for now.
2024-09-12 14:54:56 +02:00
Daniël de Kok
94304649f1
nix: support Python tokenizer conversion in the router (#2515)
Ideally we wouldn't have the router wrapper that this change adds,
but when I give PyO3 a Python interpreter with packages, it ends
up linking libpython from the Python interpreter rather than the
constructed environment and cannot pick up the Python modules as
a result.
2024-09-12 10:44:01 +02:00
Nicolas Patry
69e3be20fb
Fix truffle (#2514)
* Attempting to discard the trufflehog warning.

* Attempt to fix trufflehog.
2024-09-11 22:45:19 +02:00
Nicolas Patry
dae3bf1d87
Fix tokenization yi (#2507)
* Fixing odd tokenization self modifications on the Rust side (load and
resave in Python).

* Fixing the builds ?

* Fix the gh action?

* Fixing the location ?

* Validation is odd.

* Try a faster runner

* Upgrade python version.

* Remove sccache

* No sccache.

* Getting libpython maybe ?

* List stuff.

* Monkey it up.

* have no idea at this point

* Tmp.

* Shot in the dark.

* Tmate the hell out of this.

* Desperation.

* WTF.

* -y.

* Apparently 3.10 is not available anymore.

* Updating the dockerfile to make libpython discoverable at runtime too.

* Put back rust tests.

* Why do we want mkl on AMD ?

* Forcing 3.11 ?
2024-09-11 22:41:56 +02:00
Nicolas Patry
a4e3e8c608
Prefix test - Different kind of load test to trigger prefix test bugs. (#2490)
* Adding prefix test.

* [WIP] tmp dump of integration load tests.

* Remove other tensor creation.

* Fixed the radix tree.

Used a slice everywhere in radix.rs to keep the cheap Arc cloning
instead of recomputing the input_ids.

* Fix parsing

* Is it really flashinfer version ?

* Remove some comments.

* Revert the max prefix hit.

* Adding numpy to diff.

* Upgraded flashinfer.

* Upgrading some stuff.

* Are we done yet ?

* Minor fixup

* Remove 1 log and put back the other.

* Add comment for why slot 0 is OK.

* Mounting on the job.

* Get me a debug branch

* Debugging CIs is fun.

* Attempt #28

* wip

* Tmate.

* Praying.

* Updating VLM causal model with updated context.

* Important line got squashed.

* Tmate again.

* Fingers crossed.

* We want only 1 run of integration tests.....

---------

Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com>
2024-09-11 18:10:40 +02:00
Vallepu Vamsi Krishna
eabbbbda23
Add Directory Check to Prevent Redundant Cloning in Build Process (#2486)
Update Makefile-fbgemm

Added Directory check for FBGEMM repository cloning.
2024-09-07 13:19:43 +02:00
Nicolas Patry
c1fe28d694
Fixing more correctly the invalid drop of the batch. (#2498) 2024-09-06 17:35:49 +02:00
Martin Iglesias Goyanes
aaea212d0f
Add links to Adyen blogpost (#2500)
* Add links to Adyen blogpost

* Adding to toctree.

* Update external.md

* Update _toctree.yml

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-09-06 17:00:54 +02:00
Daniël de Kok
a3c9c62dc0
hotfix: add syrupy to the right subproject (#2499) 2024-09-06 12:47:06 +02:00
Daniël de Kok
379472c4c2
radix trie: add assertions (#2491)
These should all be cheap assertions.

Also:

* Fixup some comments.
* Delete a `remove` that was done unnecessarily twice.
2024-09-06 11:55:23 +02:00
Daniël de Kok
2eb57a15ec
Fix incompatibility with latest syrupy and update in Poetry (#2497) 2024-09-06 11:00:52 +02:00
Daniël de Kok
0424e27f65
nix: add pyright/ruff for proper LSP in the impure devshell (#2496)
We need this to ensure that pyright/ruff are part of the same
interpreter/venv.
2024-09-06 10:19:04 +02:00
Wang, Yi
5cd8025f18
hotfix: fix regression of attention api change in intel platform (#2439)
fix regression caused by attention api change. ipex.varlen_attention does not support paged-cache
format kv input now.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-09-05 17:41:39 +02:00
Daniël de Kok
e279b38aca
Add two handy gitignores for Nix environments (#2484) 2024-09-05 17:06:54 +02:00
Nicolas Patry
8b96a18265
Adding links to Adyen blogpost. (#2492) 2024-09-05 16:11:52 +02:00
Daniël de Kok
deec30f893
hotfix: avoid non-prefilled block use when using prefix caching (#2489)
The minimum batch size logic could cause prefix blocks to be
deallocated without prefill. The next allocation of the same
prefix would then use garbage blocks.
2024-09-05 15:09:29 +02:00
drbh
6cb42f49ae
feat: support lora revisions and qkv_proj weights (#2482)
* feat: support lora revisions and qkv_proj weights

* fix: add qkv_proj weights to weight test
2024-09-02 13:09:06 -04:00
drbh
47d7e34458
fix: enable chat requests in vertex endpoint (#2481)
* fix: enable chat requests in vertex endpoint

* feat: avoid unwrap and pre allocate future vec
2024-09-02 10:00:52 -04:00
Daniël de Kok
de2cdeca53
nix: add punica-kernels (#2477)
Enables LoRA support.
2024-09-02 11:31:36 +02:00
Daniël de Kok
e4ab855480
nix: improve impure devshell (#2478)
- Add some test dependencies.
- Install server in venv.
- Install Python client in venv.
2024-09-02 09:27:10 +02:00
Nicolas Patry
d9fbbaafb0
Tied embeddings in MLP speculator. (#2473)
* Tied embeddings in MLP speculator.

* Fixing the scale_weight when users decide to not use the speculation as
much as defined in the config.

* Adding scaling support + optimize some ops.
2024-08-29 17:44:54 +02:00
Wang, Yi
9883f3b40e
update doc with intel cpu part (#2420)
* update doc with intel cpu part

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

we do not use latest ever in documentation, it causes too many issues for users. Release number get update on every release.

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-08-29 17:42:02 +02:00
drbh
d5202c46f7
feat: add /v1/models endpoint (#2433)
* feat: add /v1/models endpoint

* feat: add /v1/models endpoint

* fix: remove unused type import

* fix: revert route typo

* fix: update docs with new endpoint

* fix: add to redocly ignore and lint
2024-08-29 16:32:38 +02:00
Nicolas Patry
e415b690a6
Lots of improvements (Still 2 allocators) (#2449)
* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py

Co-authored-by: OlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2024-08-29 16:29:01 +02:00
Daniël de Kok
4e821c003a
nix: build Torch against MKL and various other improvements (#2469)
Updates tgi-nix input:

- Move Torch closer to upstream by building against MKL.
- Remove compute capability 8.7 from Torch (Jetson).
- Sync nixpkgs cumpute capabilities with Torch (avoids
  compiling too mana capabilities for MAGMA).
- Use nixpkgs configuration passed through by `tgi-nix`.
2024-08-29 16:25:25 +02:00
drbh
8f99f165ce
fix: improve regex expression (#2468) 2024-08-28 13:44:44 -04:00
drbh
21187c27c9
fix: bump minijinja version and add test for llama 3.1 tools (#2463)
* fix: support tojson and avoid message indexing issue in template

* fix: prefer minijinja native methods and prefer workspace level dependency

* fix: adjust comment typo
2024-08-27 13:31:08 -04:00
Nicolas Patry
2788d41a76
Fixing CI. (#2462) 2024-08-27 15:33:02 +02:00
drbh
cfa73b5c99
Pr 2451 ci branch (#2454)
* fix[router]: Fix tools not passed in chat template

Signed-off-by: GitHub <noreply@github.com>

* feat: improve default tool serialization and lints

* feat: refactor tool logic to include notify_error in prompt and adjust typing

* fix: adjust non tool template apply

* fix: simplify tool grammar logic and improve schema

* feat: avoid skip tool test and avoid empty tool prompts

* fix: increase test client timeout for grammar compilation tests

---------

Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: Simone Rossi <simone.rossi.93@gmail.com>
2024-08-26 20:19:38 -04:00
drbh
30be188400
Fix: don't apply post layernorm in SiglipVisionTransformer (#2459)
* Fix: don't apply post layernorm in SiglipVisionTransformer

This fixes a bug with LLaVA Next when using Siglip as the vision model. LLaVA Next expects the output of the vision model to be the encoder outputs before layernorm (see original transformers implementation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L813).

This also makes Siglip consistent with the existing Clip implementation:

https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/custom_modeling/clip.py#L613

* fix: adjust pali gemma for post layer norm and small refactors

---------

Co-authored-by: Travis Addair <tgaddair@gmail.com>
2024-08-26 17:04:46 -04:00
Daniël de Kok
f3c5d7d92f
nix: add default package (#2453)
The default package wraps the launcher and puts the server/router in the
path.

As a result, TGI can be started using something like:

```
nix run .# -- \
  --model-id hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \
  --port 8080
```
2024-08-23 22:06:22 +02:00
Daniël de Kok
358ceb67dd
nix: add awq-inference-engine as server dependency (#2442) 2024-08-21 22:20:03 +02:00
Nicolas Patry
310778e02a
Adding eetq to flake. (#2438) 2024-08-21 09:06:33 +02:00
Daniël de Kok
9474415095
nix: add text-generation-benchmark to pure devshell (#2431)
nix: add text-generation-benchmark to pure devshell
2024-08-21 07:48:13 +02:00
Daniël de Kok
f5f11b797e
nix: add pure server to flake, add both pure and impure devshells (#2430)
* nix: pure server and support both pure and impure devShells

* nix: remove unused poetry2nix input

It is not wired up and we now have a pure server.

* nix: add ipdb to impure devshell
2024-08-20 22:07:33 +02:00
Nicolas Patry
b70ae0969f
Prefix caching (#2402)
* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-08-20 11:15:30 +02:00
Daniël de Kok
38773453ae
nix: update to CUDA 12.4 (#2429)
* Update to CUDA 12.4

* poetry2nix: follow tgi-nix nixpkgs
2024-08-19 09:28:38 +02:00
Nicolas Patry
e4201f44cf
All integration tests back everywhere (too many failed CI). (#2428)
* All integration tests back everywhere (too many failed CI).

* Upgrade integration tests after 12.4

* Attempt to remove the specifed compute cap.

* Common arch list.

* Punica uses raw ASM which is not valid on 9.0 apparently.
2024-08-16 21:19:46 +02:00
Hugo Larcher
53729b74ac
doc: Add metrics documentation and add a 'Reference' section (#2230)
* doc: Add metrics documentation and add a 'Reference' section

* doc: Add API reference

* doc: Refactor API reference

* fix: Message API link

* Bad rebase

* Moving the docs.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-08-16 19:43:30 +02:00
Nicolas Patry
cb0a29484d
FIxing the CI. 2024-08-16 14:21:29 +02:00
Nicolas Patry
c7ab1810d4
Further fixes. (#2426)
* Further fixes.

* Update the conftest to allow NaN (first logprob).

* Fix the condition.
2024-08-16 13:21:44 +02:00
Vaibhav Srivastav
99b662f8c2
Improve the Consuming TGI + Streaming docs. (#2412)
* Improve the Consuming TGI docs.

* Fix erronous update to .

* add info about Open AI client.

* More updates.

* Apply suggestions from code review

Co-authored-by: Erik Kaunismäki <erik.kaum@gmail.com>

* Suggestions from Lucain.

* Update Gradio snippet.

* Up.

* Apply suggestions from code review

Co-authored-by: Lucain <lucainp@gmail.com>

* Update docs/source/basic_tutorials/consuming_tgi.md

Co-authored-by: Lucain <lucainp@gmail.com>

* Up.

* Apply suggestions from code review

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Up.

* Up.

* Doc review from Nico.

* Doc review from Nico. x2

* Last nit

---------

Co-authored-by: Erik Kaunismäki <erik.kaum@gmail.com>
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
2024-08-16 12:43:08 +02:00
Daniël de Kok
1411bfb989
nix: try to reduce the number of Rust rebuilds (#2424)
Try to reduce the number of router/launcher rebuilds by filtering
sources. In this way, recompiles should only be triggered by changes
in Cargo or Rust files.
2024-08-16 10:01:01 +02:00
Nicolas Patry
1b0aa06204
Upgrading the tests to match the current workings. (#2423) 2024-08-15 13:28:42 +02:00
Nicolas Patry
57b3495823
Fixing exl2 and other quanize tests again. (#2419)
* Fixing exl2 and other quanize tests again.

* Mark exl2 as non release (so CI tests them, needs to be removed latet).

* Fixing exl2 (by disabling cuda graphs)

* Fix quantization defaults without cuda graphs on exl2 (linked to new
issues with it).

* Removing serde override.

* Go back to released exl2 and remove log.

* Adding warnings for deprecated bitsandbytes + upgrade info to warn.
2024-08-15 11:12:51 +02:00
Daniël de Kok
9aaa12e7ac
nix: build router incrementally (#2422) 2024-08-15 10:21:51 +02:00
Funtowicz Morgan
3f385991b0
More fixes trtllm (#2342)
* (backend) use parking_lot crate for RwLock fairness

* (docker) let's put rust in the TRTLLM folder when building

* (docker) build ompi with SLURM support

* (launcher) default new server::run parameters to false for now

* (chore) fmt ... why?
2024-08-14 12:02:05 +02:00
Nicolas Patry
f3b5c69441
Upgrading exl2. (#2415)
* Upgrading exl2.

* Fixing the other pathways.

* Fix idefics.
2024-08-14 11:58:08 +02:00
Daniël de Kok
c5fff92b48
nix: partial incremental build of the router (#2416)
This is less incremental than crate2nix, but does build all dependencies
separately, so avoids full rebuilds.
2024-08-14 11:06:28 +02:00
drbh
1cebccc72b
fix: adds causal to attention params (#2408)
fix: adds causal to attention params to check when using flash attn v1
2024-08-13 16:19:46 +02:00
Wang, Yi
59922f9bc1
add numa to improve cpu inference perf (#2330)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-13 15:33:55 +02:00
Nicolas Patry
cd9b15d17f
Adding more kernels to flake. (#2411) 2024-08-13 10:49:18 +02:00
Daniël de Kok
6f4bb4f26f
nix: incremental build of the launcher (#2410) 2024-08-13 10:44:15 +02:00
drbh
8a7749b8fb
fix: include create_exllama_buffers and set_device for exllama (#2407) 2024-08-12 17:59:37 -04:00
drbh
9a7830bd28
Pr 2395 ci run (#2406)
* fix(router): Fix appending to message content

* feat: add message and chat template test

---------

Co-authored-by: Simone Rossi <simone.rossi.93@gmail.com>
2024-08-12 14:38:59 -04:00
Nicolas Patry
19ea85f8dc
Updating the flake. (#2404) 2024-08-12 18:09:16 +02:00
drbh
30395b09f4
fix: improve completions to send a final chunk with usage details (#2336)
* fix: improve completions to send a final chunk with usage details

* fix: include finish reason string

* fix: remove dev debug trait and unneeded mut

* fix: update openapi schema
2024-08-12 17:26:11 +02:00
drbh
4c3f8a70a1
fix: allocate tmp based on sgmv kernel if available (#2345)
* fix: allocate tmp based on sgmv kernel if available

* fix: re add copy build artifacts step for punica kernels
2024-08-12 17:24:32 +02:00
drbh
155f9c98e2
feat: validate template variables before apply and improve sliding wi… (#2403)
* feat: validate template variables before apply and improve sliding window check

* fix: improve missing template var test
2024-08-12 10:58:40 -04:00
Nicolas Patry
136bcc8128
Keeping the benchmark somewhere (#2401)
Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-08-12 15:22:02 +02:00
Daniël de Kok
8deeaca4ff
Add support for prefix caching to the v3 router (#2392)
This change adds support for prefix caching to the v3 router. This
is broken up from the backend support to ease reviewing.

For now prefix caching is only enabled with `USE_PREFIX_CACHING=1`
in this case, the router will switch to `RadixAllocator`. This
allocator uses a radix trie to keep track of prefills that were
seen prior. If a new prefill is a prefix of a previously-seen
prefil, the router will send a request with `prefix_len>0`, which
can be used by the backend to decide to reuse KV blocks from the
cache, rather than recomputing them.

Even though backend support is not added in this PR, the backend
will still work with prefix caching enabled. The prefix lengths
are just ignored and not used.
2024-08-12 14:59:17 +02:00
Wang, Yi
b6bb1d5160
Cpu dockerimage (#2367)
add intel-cpu docker image

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-12 14:10:30 +02:00
Nicolas Patry
84bc3d7b7d
Fixing import exl2 (#2399) 2024-08-12 14:08:59 +02:00
Nicolas Patry
730fa00e20
Adding launcher to build. (#2397) 2024-08-12 14:08:46 +02:00
Nicolas Patry
9c739651cd
Upgrade fbgemm (#2398)
* Upgrade fbgemm

* Fix fbgemm version
2024-08-12 14:08:38 +02:00
Daniël de Kok
01a515dea2
nix: add router to the devshell (#2396) 2024-08-12 09:28:38 +02:00
Daniël de Kok
8dcc7d3f6b
Update flake for 9.0a capability in Torch (#2394) 2024-08-09 22:36:51 +02:00
drbh
0d06aed02d
feat: add guideline to chat request and template (#2391)
* feat: add guideline to chat request and template

* fix: add template test and update docs
2024-08-09 10:56:45 -04:00
Nicolas Patry
7a48a84784
Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385)
* Using an enum for flash backens (paged/flashdecoding/flashinfer)

* Early exit on server too.

* Clippy.

* Fix clippy and fmt.
2024-08-09 16:41:17 +02:00
Daniël de Kok
6e127dcc96
flake: use rust-overlay (#2390) 2024-08-09 15:24:21 +02:00
Vaibhav Srivastav
b2b9c42724
Update documentation for Supported models (#2386)
* Minor doc fixes

* up.

* Other minor updates.
2024-08-09 15:01:34 +02:00
Daniël de Kok
977534bcb8
flake: add fmt and clippy (#2389) 2024-08-09 14:56:20 +02:00
Nicolas Patry
952b450a3b
Using HF_HOME instead of CACHE to get token read in addition to models. (#2288) 2024-08-09 14:25:44 +02:00
Daniël de Kok
c6d5039cd7
Add experimental flake (#2384)
Add flake.nix
2024-08-09 12:32:37 +02:00
Daniël de Kok
7830de1566
Add FlashInfer support (#2354)
This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:

* A wrapper class needs to be contstructed (which we just call *state*).
  Since this is fairly expensive (due to pinned host memory allocation),
  we only do this once in a FlashCausalLM instance or for each CUDA
  Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
  `end_forward`. This sets up data structures that can be reused for all
  calls to attention for that forward call.

When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.

Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since e.g. CUDA
Graphs of different sizes each have their own state.
2024-08-09 11:42:00 +02:00
drbh
6d06473cf4
Pr 2352 ci branch (#2382)
* Fix unsigned integer underflow

Passing --max-batch-size to the launcher actually had no effect
because after a few requests the max_size passed to State::next_batch
would underflow becoming a largo positive number.

In the scheduler, as soon as the cached batch size reached the
max_batch_size the max_size passed to next_batch becomes 0.
Since the only check in that funcion is
```
if Some(batch_requests.len()) == max_size {
    break;
}
```
and it's called after the `batch_requests.len()` has
become 1, it doesn't do anything to prevent more than 0
requests from being batched.

Now we have cached batch in the server that is large than
max_batch_size and `max_size - batch_size as usize`
underflows.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>

* fix: update v3 scheduler and ensure max_batch_size > 0

---------

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
2024-08-09 10:54:32 +02:00
Vaibhav Srivastav
cb3ae30284
Update Quantization docs and minor doc fix. (#2368)
* Update Quantization docs and minor doc fix.

* update readme with latest quants info

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2024-08-08 16:06:57 -04:00
drbh
f852190060
fix: prefer hidden_activation over hidden_act in gemma2 (#2381) 2024-08-08 14:08:56 -04:00
drbh
2ca5980634
Pr 2337 ci branch (#2379)
* hotfix: fix xpu crash brought by code refine. torch.xpu rely on import ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* reable gemma2 in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix in regression in ipex flashattention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-08 12:30:29 -04:00
Wang, Yi
689b1abbf6
fix EleutherAI/gpt-neox-20b does not work in tgi (#2346)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-08 12:08:52 -04:00
drbh
82d19d7723
Pr 2374 ci branch (#2378)
* Update __init__.py

Fix issue with NoneType comparison for max_input_tokens and sliding_window

- Add default values for max_input_tokens and sliding_window to handle None cases.
- Ensure the comparison between max_input_tokens and sliding_window is handled correctly to prevent TypeError.
- This change addresses the error: TypeError: '<=' not supported between instances of 'int' and 'NoneType'.

* Update __init__.py

Handle NoneType in sliding_window comparison to fix TypeError in __init__.py by ensuring the comparison logic accounts for NoneType values, preventing errors and improving code robustness.

* fix: syntax/style tweak

---------

Co-authored-by: Praz <prazanth2006@gmail.com>
2024-08-08 11:14:06 -04:00
drbh
a379d5536b
Fix the prefix for OPT model in opt_modelling.py #2370 (CI RUN) (#2371)
* Fix the bug

* fix: run lints

* fix: small syntax tweak

---------

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
2024-08-07 23:14:02 -04:00
drbh
21267f3ca3
add gptj modeling in TGI #2366 (CI RUN) (#2372)
* add gptj modeling

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix: update docs for model addition

* fix: adjust syntax typo

* fix: adjust syntax typo again

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-07 21:32:37 -04:00
almersawi
8094ecfc9e
fix: fix num_ln_in_parallel_attn attribute name typo in RWConfig (#2350)
Co-authored-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
2024-08-07 19:45:23 -04:00
drbh
133015f408
fix: prefer original layernorm names for 180B (#2365) 2024-08-06 15:25:30 -04:00
drbh
a64d407d64
fix: default num_ln_in_parallel_attn to one if not supplied (#2364) 2024-08-06 13:33:22 -04:00
drbh
1768c00b9f
feat: return the generated text when parsing fails (#2353) 2024-08-06 13:10:19 -04:00
drbh
f8a5b381fe
feat: prefer stop over eos_token to align with openai finish_reason (#2344) 2024-08-06 13:09:50 -04:00
drbh
e11f5f1c38
feat: implement a templated endpoint for visibility into chat requests (#2333)
* feat: implement a templated endpoint for visibility into chat requests

* feat: improve to tokenize too

* fix: adjust return type

* feat: simplify prepare_chat_input logic and adjust start stop chars
2024-08-06 13:51:32 +02:00
drbh
29b8d19cdf
fix: return the out tensor rather then the functions return value (#2361) 2024-08-06 13:49:53 +02:00
drbh
dd47a3dac4
feat: include local lora adapter loading docs (#2359) 2024-08-05 12:36:44 -04:00
drbh
215ed3ad52
fix: attempt forward on flash attn2 to check hardware support (#2335)
* fix: attempt forward on flash attn2 to check hardware support

* fix: warn window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve condtional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
2024-08-05 09:11:40 -04:00
Daniël de Kok
47447ef017
Unify attention output handling (#2343)
- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention`
  functions.

This removes the difference between how the output is handled between
attention (output parameter) and paged attention (return value). This
also removes the assumption that the attention implementation can
write to an output tensor (in preparation of FlashInfer).
2024-08-01 17:03:28 +02:00
Daniël de Kok
22fb1be588
Fix cache block size for flash decoding (#2351)
* Fix cache block size for flash decoding

This seems to have been accidentally dropped during the TRT-LLM
PR rebase.

* Also run CI on changes to `backends`
2024-08-01 15:38:57 +02:00
Wang, Yi
9ab9937414
enable HuggingFaceM4/idefics-9b in intel gpu (#2338)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-08-01 11:08:36 +02:00
Erik Kaunismäki
7451041ecd
refactor usage stats (#2339)
* refactor usage stats

* Update docs/source/usage_statistics.md

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* Update router/src/server.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* changes based on feedback

* run python3 udpate_doc.py

* fix pre-commit

* Update router/src/server.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* delete option around usage stats arg

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-07-31 16:29:07 +02:00
drbh
f7f61876cf
Pr 2290 ci run (#2329)
* MODEL_ID propagation fix

* fix: remove global model id

---------

Co-authored-by: root <root@tw031.pit.tensorwave.lan>
2024-07-31 10:27:15 -04:00
Daniël de Kok
34f7dcfd80
Handle GPTQ-Marlin loading in GPTQMarlinWeightLoader (#2300)
The `GPTWeightLoader` was structured like this in pseudocode:

if marlin:
  Set up tensors in a way that GPTQ-Marlin expects
else:
  Set up tensors in a way that ExLlama/GPTQ/AWQ expect

However, the GPT-Marlin implementation details should really be in the
`marlin` module. So move the former part out to a separate
`GPTQMarlinWeightsLoader`.
2024-07-31 13:08:41 +02:00
Nicolas Patry
2b19d671b4
Rebase TRT-llm (#2331)
* wip

wip

refacto

refacto

Initial setup for CXX binding to TRTLLM

Working FFI call for TGI and TRTLLM backend

Remove unused parameters annd force tokenizer name to be set

Overall build TRTLLM and deps through CMake build system

Enable end to end CMake build

First version loading engines and making it ready for inference

Remembering to check how we can detect support for chunked context

Move to latest TensorRT-LLM version

Specify which default log level to use depending on CMake build type

make leader executor mode working

unconditionally call InitializeBackend on the FFI layer

bind to CUDA::nvml to retrieve compute capabilities at runtime

updated logic and comment to detect cuda compute capabilities

implement the Stream method to send new tokens through a callback

use spdlog release 1.14.1 moving forward

update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c

correctly tell cmake to build dependent tensorrt-llm required libraries

create cmake install target to put everything relevant in installation folder

add auth_token CLI argument to provide hf hub authentification token

allow converting huggingface::tokenizers error to TensorRtLlmBackendError

use correct include for spdlog

include guard to build example in cmakelists

working setup of the ffi layer

remove fmt import

use external fmt lib

end to end ffi flow working

make sure to track include/ffi.h to trigger rebuild from cargo

impl the rust backend which currently cannot move the actual computation in background thread

expose shutdown function at ffi layer

impl RwLock scenario for TensorRtLllmBackend

oops missing c++ backend definitions

compute the number of maximum new tokens for each request independently

make sure the context is not dropped in the middle of the async decoding.

remove unnecessary log

add all the necessary plumbery to return the generated content

update invalid doc in cpp file

correctly forward back the log probabilities

remove unneeded scope variable for now

refactor Stream impl for Generation to factorise code

expose the internal missing start/queue timestamp

forward tgi parameters rep/freq penalty

add some more validation about grammar not supported

define a shared struct to hold the result of a decoding step

expose information about potential error happening while decoding

remove logging

add logging in case of decoding error

make sure executor_worker is provided

add initial Dockerfile for TRTLLM backend

add some more information in CMakeLists.txt to correctly install executorWorker

add some more information in CMakeLists.txt to correctly find and install nvrtc wrapper

simplify prebuilt trtllm libraries name definition

do the same name definition stuff for tensorrt_llm_executor_static

leverage pkg-config to probe libraries paths and reuse new install structure from cmake

fix bad copy/past missing nvinfer linkage direction

align all the linker search dependency

add missing pkgconfig folder for MPI in Dockerfile

correctly setup linking search path for runtime layer

fix missing / before tgi lib path

adding missing ld_library_path for cuda stubs in Dockerfile

update tgi entrypoint

commenting out Python part for TensorRT installation

refactored docker image

move to TensorRT-LLM v0.11.0

make docker linter happy with same capitalization rule

fix typo

refactor the compute capabilities detection along with num gpus

update TensorRT-LLM to latest version

update TensorRT install script to latest

update build.rs to link to cuda 12.5

add missing dependant libraries for linking

clean up a bit

install to decoder_attention target

add some custom stuff for nccl linkage

fix envvar CARGO_CFG_TARGET_ARCH set at runtime vs compile time

use std::env::const::ARCH

make sure variable live long enough...

look for cuda 12.5

add some more basic info in README.md

* Rebase.

* Fix autodocs.

* Let's try to enable trtllm backend.

* Ignore backends/v3 by default.

* Fixing client.

* Fix makefile + autodocs.

* Updating the schema thing + redocly.

* Fix trtllm lint.

* Adding pb files ?

* Remove cargo fmt temporarily.

* ?

* Tmp.

* Remove both check + clippy  ?

* Backporting telemetry.

* Backporting 457fb0a1

* Remove PB from git.

* Fixing PB with default member backends/client

* update TensorRT-LLM to latest version

* provided None for api_key

* link against libtensorrt_llm and not libtensorrt-llm

---------

Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: Morgan Funtowicz <morgan@huggingface.co>
2024-07-31 10:33:10 +02:00
Daniël de Kok
53aec27328
server quantize: store quantizer config in standard format (#2299)
- Create `quantization_config` option in the model config.
- Don't store the quantizer config in tensors anymore.
2024-07-30 15:16:20 +02:00
drbh
0b95693fb8
fix: adjust test snapshots and small refactors (#2323)
* fix: adjust test snapshots and small refactors

* fix: revert non snapshot changes
2024-07-29 11:38:38 -04:00
Erik Kaunismäki
3d7f4f41bb
patch-error-on-invalid-grammar (#2282)
* quick fix

* allow silent failure

* explicit todo that this is only short term
2024-07-29 10:09:25 -04:00
drbh
f15e808d4c
fix: reject grammars without properties (#2309) 2024-07-29 10:07:25 -04:00
Daniël de Kok
922732b255
Install Marlin from standalone package (#2320) 2024-07-29 15:37:10 +02:00
Erik Kaunismäki
583d37a2f8
Run ci api key (#2315)
* Add API_Key for Auth and conditionally add authorisation for non info/health endpoints.

* change name to info routes

* Fix comment

* convert strings to lowercase for case insensitive comparison

* convert header to string

* fixes and update docs

* update docs again

* revert wrong update

---------

Co-authored-by: Kevin Duffy <kevin.duffy94@gmail.com>
2024-07-29 11:14:17 +02:00
Adrien
fd2e06316d
fix: fix buildkit config in ci
Signed-off-by: Adrien <adrien@huggingface.co>
2024-07-29 09:25:56 +02:00
drbh
bab02ff2bc
feat: add ruff and resolve issue (#2262)
* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
2024-07-26 10:29:09 -04:00
Daniël de Kok
4b49c50f4c
Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313) 2024-07-26 14:57:24 +02:00
Adrien
3905f854ed
Fix registry name (#2307) 2024-07-25 16:06:00 +02:00
Nicolas Patry
17ed42be3a
Fixing idefics on g6 tests. (#2306) 2024-07-25 14:44:21 +02:00
Daniël de Kok
9256d7c38c
Some small fixes for the Torch 2.4.0 update (#2304)
* Fix GPTQ autotune data type to be compatible with Torch 2.4.0

* Update poetry lock file

* Fix small PaliGemma logprob differences after the torch update
2024-07-25 13:34:44 +02:00
Nicolas Patry
26614057a7
Using g6 instead of g5. (#2281)
* Using g6 instead of g5.

* Update the idefics2 snapshot.
2024-07-25 11:21:17 +02:00
drbh
5d85a958c9
fix: refactor adapter weight loading and mapping (#2193)
* fix: refactor adapter weight loading and mapping

* feat: enable lora load from directory

* fix: adjust launcher for local lora adapters

* feat: improve weight loading and add tests

* fix: improve logging and rebase syntax issue

* fix: impove adapter merge comments and remove unused conditional

* fix: improve get_model_with_lora_adapters naming

* fix: comment typo
2024-07-24 15:32:14 -04:00
Daniël de Kok
93d2b9fe9c
Split up layers.marlin into several files (#2292)
The marlin.py file was getting large, split it up.
2024-07-24 16:33:26 +02:00
Wang, Yi
8642250602
fix of use of unquantized weights in cohere GQA loading, also enable … (#2291)
fix of use of unquantized weights in cohere GQA loading, also enable the model in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-07-24 10:44:02 +02:00
Wang, Yi
5ad39dd3c3
fix crash in multi-modal (#2245)
* fix crash in multi-modal

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update according to review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix llava_next regression in latest main

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-07-24 10:39:08 +02:00
OlivierDehaene
a895029424
hotfix: update nccl 2024-07-23 23:31:28 +02:00
OlivierDehaene
e7e3aa6cac
chore: update to torch 2.4 (#2259)
* chore: update to torch 2.4

* remove un-necessary patch

* fix
2024-07-23 20:39:43 +00:00
761 changed files with 118987 additions and 98006 deletions

View File

@ -2,3 +2,6 @@ aml
target target
server/transformers server/transformers
server/flash-attention server/flash-attention
cmake-build-debug/
cmake-build-release/
Dockerfile*

View File

@ -28,7 +28,7 @@ jobs:
- name: Install router - name: Install router
id: install-router id: install-router
run: cargo install --path router/ run: cargo install --path backends/v3/
- uses: actions/setup-node@v4 - uses: actions/setup-node@v4
with: with:
@ -41,5 +41,5 @@ jobs:
- name: Check that documentation is up-to-date - name: Check that documentation is up-to-date
run: | run: |
npm install -g swagger-cli npm install -g @redocly/cli
python update_doc.py --check python update_doc.py --check

View File

@ -6,10 +6,11 @@ on:
hardware: hardware:
type: string type: string
description: Hardware description: Hardware
# options: # options:
# - cuda # - cuda
# - rocm # - cuda-trtllm
# - intel # - rocm
# - intel
required: true required: true
release-tests: release-tests:
description: "Run release integration tests" description: "Run release integration tests"
@ -21,68 +22,141 @@ jobs:
build-and-push: build-and-push:
outputs: outputs:
docker_image: ${{ steps.final.outputs.docker_image }} docker_image: ${{ steps.final.outputs.docker_image }}
docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }} docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }} runs_on: ${{ steps.final.outputs.runs_on }}
label: ${{ steps.final.outputs.label }} label_extension: ${{ steps.final.outputs.label_extension }}
extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency: concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: runs-on:
group: aws-highmemory-32-plus-priv group: aws-highmemory-64-plus-priv
permissions: permissions:
contents: write contents: write
packages: write packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write id-token: write
security-events: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Inject slug/short variables - name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1 uses: rlespinasse/github-slug-action@v4.4.1
- name: Construct harware variables - name: Inject required variables for sccache to interact with Github Actions Cache
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
run: |
echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
- name: Construct hardware variables
shell: bash shell: bash
run: | run: |
case ${{ inputs.hardware }} in case ${{ inputs.hardware }} in
cuda) cuda)
export dockerfile="Dockerfile" export dockerfile="Dockerfile"
export label_extension="" export label_extension=""
export docker_volume="/mnt/cache"
export docker_devices="" export docker_devices=""
export runs_on="aws-g5-12xlarge-plus" export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
export extra_pytest=""
export target=""
;;
cuda-trtllm)
export dockerfile="Dockerfile_trtllm"
export label_extension="-trtllm"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="ubuntu-latest"
export platform=""
export extra_pytest=""
if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
export build_type="release";
export target="";
else
export build_type="dev";
export target="ci-runtime";
fi
;; ;;
rocm) rocm)
export dockerfile="Dockerfile_amd" export dockerfile="Dockerfile_amd"
export label_extension="-rocm" export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri" export docker_devices="/dev/kfd,/dev/dri"
# TODO Re-enable when they pass. export docker_volume="/mnt"
# export runs_on="amd-gpu-tgi" # This runner was deactivated.
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load"
export target=""
;; ;;
intel) intel-xpu)
export dockerfile="Dockerfile_intel" export dockerfile="Dockerfile_intel"
export label_extension="-intel" export label_extension="-intel-xpu"
export docker_devices=""
export docker_volume="/mnt/cache"
export runs_on="ubuntu-latest"
export platform="xpu"
export extra_pytest=""
export target=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-cpu"
export docker_devices="none"
export docker_volume="/mnt/cache"
# export runs_on="ubuntu-latest"
export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple"
export target=""
;;
neuron)
export dockerfile="Dockerfile.neuron"
export label_extension="-neuron"
export docker_devices="/dev/neuron0"
export docker_volume="/mnt/cache"
export runs_on="aws-inf2-8xlarge"
export platform="cpu"
export extra_pytest="--neuron"
export target=""
;;
gaudi)
export dockerfile="Dockerfile_gaudi"
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices="" export docker_devices=""
export runs_on="ubuntu-latest" export runs_on="ubuntu-latest"
;; export platform=""
export extra_pytest=""
export target=""
esac esac
echo $dockerfile echo $dockerfile
echo "Dockerfile=${dockerfile}" echo "Dockerfile=${dockerfile}"
echo $label_extension echo $label_extension
echo $docker_devices echo $docker_devices
echo $runs_on echo $runs_on
echo $platform
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
echo "LABEL=${label_extension}" >> $GITHUB_ENV echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
echo "TARGET=${target}" >> $GITHUB_ENV
echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx - name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
with: with:
install: true install: true
buildkitd-config-inline: | buildkitd-config: /tmp/buildkitd.toml
[registry."docker.io"]
mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"]
- name: Login to internal Container Registry - name: Login to internal Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
username: ${{ secrets.REGISTRY_USERNAME }} username: ${{ secrets.REGISTRY_USERNAME }}
@ -95,6 +169,12 @@ jobs:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to Docker Hub Container Registry
uses: docker/login-action@v3
with:
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Login to Azure Container Registry - name: Login to Azure Container Registry
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v3 uses: docker/login-action@v3
@ -109,10 +189,9 @@ jobs:
uses: docker/metadata-action@v5 uses: docker/metadata-action@v5
with: with:
images: | images: |
registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference docker.io/huggingface/text-generation-inference-ci
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: | tags: |
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
# If main, release or tag # If main, release or tag
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }} if: ${{ github.event_name != 'pull_request' }}
@ -120,17 +199,16 @@ jobs:
uses: docker/metadata-action@v4.3.0 uses: docker/metadata-action@v4.3.0
with: with:
flavor: | flavor: |
latest=auto latest=false
images: | images: |
registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference registry.internal.huggingface.tech/api-inference/community/text-generation-inference
registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca
ghcr.io/huggingface/text-generation-inference ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: | tags: |
type=semver,pattern={{version}}${{ env.LABEL }} type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }} type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
- name: Build and push Docker image - name: Build and push Docker image
id: build-and-push id: build-and-push
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
@ -141,28 +219,41 @@ jobs:
platforms: 'linux/amd64' platforms: 'linux/amd64'
build-args: | build-args: |
GIT_SHA=${{ env.GITHUB_SHA }} GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- name: Final - name: Final
id: final id: final
run: | run: |
echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
else
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
fi
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
integration_tests: echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
precompile_neuron_models:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
needs: build-and-push needs: build-and-push
if: needs.build-and-push.outputs.label_extension == '-neuron'
runs-on: runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }} group: ${{ needs.build-and-push.outputs.runs_on }}
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env: env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
@ -171,15 +262,66 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install
run: |
make install-integration-tests
- name: Export neuron models
run: |
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE
export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
python integration-tests/fixtures/neuron/export_models.py
integration_tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: [precompile_neuron_models, build-and-push]
if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install - name: Install
run: | run: |
make install-integration-tests make install-integration-tests
- name: Run tests - name: Run tests
run: | run: |
export DOCKER_VOLUME=/mnt/cache export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }} export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
backend_trtllm_cxx_tests:
needs: build-and-push
if: needs.build-and-push.outputs.label_extension == '-trtllm'
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g6-12xl-plus-priv-cache
container:
image: ${{ needs.build-and-push.outputs.docker_image }}
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
options: --gpus all --shm-size=8g
steps:
- name: Run C++/CUDA tests
if: ${{ env.LABEL_EXTENSION == 'ci-runtime' }}
run: /usr/local/tgi/bin/tgi_trtllm_backend_tests

View File

@ -11,7 +11,7 @@ concurrency:
jobs: jobs:
build: build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yaml@main uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with: with:
commit_sha: ${{ github.event.pull_request.head.sha }} commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }} pr_number: ${{ github.event.number }}

View File

@ -10,6 +10,7 @@ on:
paths: paths:
- ".github/workflows/build.yaml" - ".github/workflows/build.yaml"
- "integration-tests/**" - "integration-tests/**"
- "backends/**"
- "server/**" - "server/**"
- "proto/**" - "proto/**"
- "router/**" - "router/**"
@ -19,6 +20,8 @@ on:
- "Dockerfile" - "Dockerfile"
- "Dockerfile_amd" - "Dockerfile_amd"
- "Dockerfile_intel" - "Dockerfile_intel"
- "Dockerfile.neuron"
- "Dockerfile_gaudi"
branches: branches:
- "main" - "main"
workflow_dispatch: workflow_dispatch:
@ -36,8 +39,12 @@ jobs:
# fail-fast is true by default # fail-fast is true by default
fail-fast: false fail-fast: false
matrix: matrix:
hardware: ["cuda", "rocm", "intel"] hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
uses: ./.github/workflows/build.yaml # calls the one above ^ uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write
packages: write
id-token: write
with: with:
hardware: ${{ matrix.hardware }} hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206 # https://github.com/actions/runner/issues/2206

View File

@ -3,12 +3,17 @@ name: Nightly load test
on: on:
schedule: schedule:
- cron: '0 0 * * 1-5' - cron: '0 0 * * 1-5'
workflow_call:
workflow_dispatch:
pull_request: pull_request:
paths: paths:
- ".github/workflows/load_test.yaml" - ".github/workflows/load_test.yaml"
branches:
- 'main' env:
AWS_DEFAULT_REGION: us-east-1
AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
jobs: jobs:
load-tests: load-tests:
@ -16,28 +21,30 @@ jobs:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: runs-on:
group: aws-g5-12xlarge group: aws-g6-12xl-plus-priv-cache
env: env:
DOCKER_VOLUME: /cache DOCKER_VOLUME: /cache
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install k6 - name: Install Python 3.11
run: | uses: actions/setup-python@v2
curl https://github.com/grafana/k6/releases/download/v0.44.0/k6-v0.44.0-linux-amd64.tar.gz -L | tar xvz --strip-components 1 with:
python-version: 3.11
- name: Start starcoder - name: Install poetry
run: | run: |
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768 curl -sSL https://install.python-poetry.org | python3 -
sleep 10 export PATH="$HOME/.local/bin:$PATH"
wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health poetry --version
- name: Run k6 - name: Run bench test
run: | run: |
./k6 run load_tests/starcoder_load.js export PATH="$HOME/.local/bin:$PATH"
cd load_tests
- name: Stop starcoder poetry install
if: ${{ always() }} poetry run python benchmarks.py --sha ${{ github.sha }} --results-file "s3://text-generation-inference-ci/benchmarks/ci/${{ github.sha }}.parquet"
run: | shell: bash
docker stop tgi-starcoder || true env:
HF_TOKEN: ${{ secrets.HF_TOKEN_BENCHMARK }}

53
.github/workflows/nix_build.yaml vendored Normal file
View File

@ -0,0 +1,53 @@
name: "Nix Build Docker image"
on:
pull_request:
push:
branches:
- 'main'
tags:
- 'v*'
concurrency:
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build_nix_image:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix build .#dockerImage
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Login to internal Container Registry
# if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Push to docker
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
else
export TAG=${{ github.ref_name }}-nix
fi
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"

34
.github/workflows/nix_cache.yaml vendored Normal file
View File

@ -0,0 +1,34 @@
name: "Cache devshells"
on:
pull_request:
paths:
- "flake.nix"
- "flake.lock"
- "nix/**"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
env:
USER: github_runner
- name: Build impure devshell
run: nix build .\#devShells.x86_64-linux.impure
- name: Build impure devshell (CUDA dev)
run: nix build .\#devShells.x86_64-linux.impureWithCuda
# Pure shell dependencies are covered by Nix tests.
# - name: Build pure devshell
# run: nix build .\#devShells.x86_64-linux.pure

42
.github/workflows/nix_tests.yaml vendored Normal file
View File

@ -0,0 +1,42 @@
name: "Nix Tests"
on:
pull_request:
paths:
- ".github/workflows/nix_tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix develop .#test --command echo "Ok"
- name: Pre-commit tests.
run: nix develop .#test --command pre-commit run --all-files
- name: Python tests.
run: nix develop .#test --command python -m pytest server/tests/
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Rust tests.
run: nix develop .#test --command cargo test

View File

@ -8,6 +8,7 @@ on:
- "proto/**" - "proto/**"
- "router/**" - "router/**"
- "launcher/**" - "launcher/**"
- "backends/**"
- "Cargo.lock" - "Cargo.lock"
- "rust-toolchain.toml" - "rust-toolchain.toml"
@ -17,26 +18,17 @@ concurrency:
jobs: jobs:
run_tests: run_tests:
runs-on: ubuntu-latest runs-on:
group: aws-highmemory-32-plus-priv
env:
SCCACHE_GHA_ENABLED: "on"
RUSTC_WRAPPER: /usr/local/bin/sccache
SCCACHE: 0.3.3
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v1 uses: actions/setup-python@v4
id: python
with: with:
python-version: 3.9 python-version: 3.11
- name: Install Rust - uses: dtolnay/rust-toolchain@1.85.0
uses: actions-rs/toolchain@v1
with: with:
# Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/
toolchain: 1.79.0
override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: Install Protoc - name: Install Protoc
uses: arduino/setup-protoc@v1 uses: arduino/setup-protoc@v1
@ -44,34 +36,22 @@ jobs:
run: | run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install sccache
run: |
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
chmod +x /usr/local/bin/sccache
- name: configure sccache
uses: actions/github-script@v6
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
- name: cargo registry cache
uses: actions/cache@v3
with:
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
restore-keys: |
cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
cargo-${{ runner.os }}-
path: |
~/.cargo/registry
~/.cargo/git
- name: Install - name: Install
run: | run: |
sudo apt update
sudo apt install python3.11-dev -y
pip install -U pip uv
uv venv
source ./.venv/bin/activate
make install-cpu make install-cpu
- name: Download locked kernels
run: |
source ./.venv/bin/activate
kernels download server
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest source ./.venv/bin/activate
uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Pre-commit checks - name: Pre-commit checks
@ -82,6 +62,6 @@ jobs:
- name: Run Rust tests - name: Run Rust tests
run: | run: |
cargo test cargo test
- name: sccache stats - name: Run Rust tests with google feature
run: | run: |
/usr/local/bin/sccache --show-stats cargo test --features google

View File

@ -10,9 +10,12 @@ jobs:
trufflehog: trufflehog:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Secret Scanning - name: Secret Scanning
uses: trufflesecurity/trufflehog@main uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
with:
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
extra_args: --results=verified,unknown --exclude-detectors=postgres

17
.gitignore vendored
View File

@ -3,9 +3,14 @@ target
router/tokenizer.json router/tokenizer.json
*__pycache__* *__pycache__*
backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files # ROCm auto-generated files
*.hip *.hip
server/exllamav2_kernels/exllamav2_kernels/hip/ server/exllamav2
server/exllama_kernels/exllama_kernels/hip/ server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh *_hip.cuh
@ -14,3 +19,13 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/ data/
load_tests/*.json load_tests/*.json
server/fbgemmm
.direnv/
.venv/
# Gaudi auto-generated files
hl-smi_log*.txt
.graph_dumps
out
hqt_output

View File

@ -4,8 +4,9 @@ repos:
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
exclude: crate-hashes.json
- id: trailing-whitespace - id: trailing-whitespace
exclude: docs/source/basic_tutorials/launcher.md exclude: docs/source/reference/launcher.md
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 24.2.0 rev: 24.2.0
hooks: hooks:
@ -13,6 +14,11 @@ repos:
- repo: https://github.com/doublify/pre-commit-rust - repo: https://github.com/doublify/pre-commit-rust
rev: v1.0 rev: v1.0
hooks: hooks:
- id: fmt
- id: cargo-check - id: cargo-check
- id: fmt
- id: clippy - id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

82
.redocly.lint-ignore.yaml Normal file
View File

@ -0,0 +1,82 @@
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
docs/openapi.json:
no-empty-servers:
- '#/openapi'
spec:
- >-
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
- >-
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
- >-
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'
- '#/paths/~1/post/responses/424/content/application~1json/example'
- '#/paths/~1/post/responses/429/content/application~1json/example'
- '#/paths/~1/post/responses/500/content/application~1json/example'
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
- >-
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
- >-
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
operation-4xx-response:
- '#/paths/~1health/get/responses'
- '#/paths/~1info/get/responses'
- '#/paths/~1metrics/get/responses'
no-unused-components:
- '#/components/schemas/Completion'
security-defined:
- '#/paths/~1/post'
- '#/paths/~1generate/post'
- '#/paths/~1generate_stream/post'
- '#/paths/~1health/get'
- '#/paths/~1info/get'
- '#/paths/~1metrics/get'
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'
- '#/paths/~1v1~1models/get'

2982
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,23 +1,40 @@
[workspace] [workspace]
members = [ members = [
"benchmark", "benchmark",
"router", "backends/v2",
"router/client", "backends/v3",
"router/grpc-metadata", "backends/grpc-metadata",
"launcher" "backends/trtllm",
"backends/llamacpp",
"launcher",
"router"
]
default-members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"launcher",
"router"
] ]
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "2.2.0" version = "3.2.3-dev0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies] [workspace.dependencies]
base64 = "0.22.0" base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] } tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] } hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[profile.release] [profile.release]
incremental = true incremental = true

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -11,11 +11,15 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -28,32 +32,29 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA ARG GIT_SHA
ARG DOCKER_LABEL ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt --frozen
# Python builder # Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
WORKDIR /usr/src/
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.3.0 ARG PYTORCH_VERSION=2.6
ARG PYTHON_VERSION=3.11
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
ca-certificates \ ca-certificates \
@ -61,31 +62,18 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \ curl \
git && \ git && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
# Install conda ENV PATH="$PATH:/root/.local/bin"
# translating Docker's TARGETPLATFORM into mamba arches RUN uv python install ${PYTHON_VERSION}
RUN case ${TARGETPLATFORM} in \ RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
"linux/arm64") MAMBA_ARCH=aarch64 ;; \ ENV VIRTUAL_ENV=/usr/src/.venv/
*) MAMBA_ARCH=x86_64 ;; \ ENV PATH="$PATH:/usr/src/.venv/bin/"
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# CUDA kernels builder image # CUDA kernels builder image
FROM pytorch-install AS kernel-builder FROM pytorch-install AS kernel-builder
ARG MAX_JOBS=8 ARG MAX_JOBS=8
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build cmake \ ninja-build cmake \
@ -99,7 +87,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att Makefile COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention # Build specific version of flash attention
RUN make build-flash-attention RUN . .venv/bin/activate && make build-flash-attention
# Build Flash Attention v2 CUDA kernels # Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder FROM kernel-builder AS flash-att-v2-builder
@ -109,96 +97,61 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2 # Build specific version of flash attention v2
RUN make build-flash-attention-v2-cuda RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
# Build Transformers exllama kernels # Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllama_kernels/ . COPY server/exllama_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build RUN . .venv/bin/activate && python setup.py build
# Build Transformers exllama kernels # Build Transformers exllama kernels
FROM kernel-builder AS exllamav2-kernels-builder FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllamav2_kernels/ . COPY server/Makefile-exllamav2/ Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build RUN . .venv/bin/activate && make build-exllamav2
# Build Transformers awq kernels # Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-awq Makefile COPY server/Makefile-awq Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq RUN . .venv/bin/activate && make build-awq
# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels
FROM kernel-builder AS marlin-kernels-builder
WORKDIR /usr/src
COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels # Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/custom_kernels/ . COPY server/custom_kernels/ .
# Build specific version of transformers # Build specific version of transformers
RUN python setup.py build RUN . .venv/bin/activate && python setup.py build
# Build FBGEMM CUDA kernels
FROM kernel-builder AS fbgemm-builder
WORKDIR /usr/src
COPY server/Makefile-fbgemm Makefile
COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
COPY server/fix_torch90a.sh fix_torch90a.sh
RUN make build-fbgemm
# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda
# Build mamba kernels # Build mamba kernels
FROM kernel-builder AS mamba-builder FROM kernel-builder AS mamba-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile COPY server/Makefile-selective-scan Makefile
RUN make build-all RUN . .venv/bin/activate && make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN . .venv/bin/activate && make install-flashinfer
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
# Conda env
ENV PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
@ -212,52 +165,64 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \ git \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed # RUN curl -LsSf https://astral.sh/uv/install.sh | sh
COPY --from=pytorch-install /opt/conda /opt/conda # ENV PATH="$PATH:/root/.local/bin"
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from fbgemm builder
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir # RUN pip install einops --no-cache-dir
# Copy env with PyTorch installed
COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
ENV PYTHON_VERSION=3.11
RUN uv python install ${PYTHON_VERSION}
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \ RUN cd server && \
make gen-server && \ uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
pip install -r requirements_cuda.txt && \ make gen-server-raw && \
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ kernels download .
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
uv pip install nvidia-nccl-cu12==2.25.1 && \
pwd && \
text-generation-server --help
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
# Deps before the binaries # Deps before the binaries
# The binaries change on every build given we burn the SHA into them # The binaries change on every build given we burn the SHA into them
@ -289,5 +254,6 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
ENTRYPOINT ["/tgi-entrypoint.sh"] ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"] # CMD ["--json-output"]

167
Dockerfile.neuron Normal file
View File

@ -0,0 +1,167 @@
# Fetch and extract the TGI sources
FROM alpine AS tgi
RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
# Note: we cannot use the cargo-chef base image as it uses python 3.11
FROM ubuntu:22.04 AS chef
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
curl ca-certificates build-essential \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked
WORKDIR /usr/src
FROM chef AS planner
COPY backends/neuron/Cargo.toml Cargo.toml
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
unzip python3-dev libssl-dev pkg-config \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY backends/neuron/Cargo.toml Cargo.toml
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --release
# Python base image
FROM ubuntu:22.04 AS base
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
python3-pip \
python3-setuptools \
python-is-python3 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN pip3 --no-cache-dir install --upgrade pip
# Python server build image
FROM base AS pyserver
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
make \
python3-venv \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN install -d /pyserver
WORKDIR /pyserver
COPY backends/neuron/server server
COPY proto proto
RUN pip3 install -r server/build-requirements.txt
RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package
# Neuron base image (used for deployment)
FROM base AS neuron
# Install system prerequisites
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
gnupg2 \
wget \
python3-dev \
libexpat1 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.19.64.0 \
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
aws-neuronx-tools=2.20.204.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
torch==2.5.1 \
torchvision==0.20.1 \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.16.372.0 \
torch-neuronx==2.5.1.2.4.0 \
transformers-neuronx==0.13.322 \
neuronx-distributed==0.10.1 \
libneuronxla==2.1.681.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages
RUN pip3 install \
hf_transfer huggingface_hub
# Install optimum-neuron
COPY --from=optimum-neuron /optimum-neuron optimum-neuron
RUN pip3 install ./optimum-neuron
# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Disable color logs as they are not supported by CloudWatch
ENV LOGURU_COLORIZE=NO
ENV LOG_COLORIZE=0
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz
# Final image
FROM neuron
COPY backends/neuron/tgi_env.py /tgi_env.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]

24
Dockerfile.nix Normal file
View File

@ -0,0 +1,24 @@
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
FROM ubuntu:24.04
WORKDIR /app
# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]

View File

@ -1,5 +1,5 @@
# Rust builder # Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -11,11 +11,14 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -28,170 +31,247 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA ARG GIT_SHA
ARG DOCKER_LABEL ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for RoCm FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ARG PYTHON_VERSION=3.11
RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
ca-certificates \ ca-certificates \
ccache \ ccache \
curl \ curl \
git \ git \
make \ ninja-build \
libssl-dev \ cmake \
g++ \ software-properties-common \
# Needed to build VLLM & flash. python3.11-dev \
rocthrust-dev \ python3.11-venv && \
hipsparse-dev \ rm -rf /var/lib/apt/lists/*
hipblas-dev \
hipblaslt-dev \
rocblas-dev \
hiprand-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ARG MAMBA_VERSION=23.1.0-1 ENV PATH="$PATH:/root/.local/bin"
ARG PYTORCH_VERSION='2.3.0' RUN uv python install ${PYTHON_VERSION}
ARG ROCM_VERSION='6.0.2' RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
ARG PYTHON_VERSION='3.10.10' ENV VIRTUAL_ENV=/usr/src/.venv/
# Automatically set by buildx ENV PATH="$PATH:/usr/src/.venv/bin/"
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
mamba init && \
rm ~/mambaforge.sh
# Install flash-attention, torch dependencies FROM base AS build_hipblaslt
RUN pip install numpy einops ninja --no-cache-dir ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN . .venv/bin/activate && cd hipBLAS-common \
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
&& mkdir build \
&& cd build \
&& cmake .. \
&& make package \
&& dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
RUN . .venv/bin/activate && cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& cd build/release \
&& make package
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
RUN conda install intel::mkl-static intel::mkl-include FROM base AS build_rccl
RUN pip uninstall -y triton && \ ARG RCCL_BRANCH
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ ARG RCCL_REPO
cd triton/python && \ RUN git clone ${RCCL_REPO}
pip install . RUN . .venv/bin/activate && cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
RUN . .venv/bin/activate && cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
ARG _GLIBCXX_USE_CXX11_ABI="1" FROM base AS build_amdsmi
ARG CMAKE_PREFIX_PATH="/opt/conda" RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" && pip wheel . --wheel-dir=dist
ARG BUILD_CAFFE2="0" \ RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
USE_ROCM="1" \
BUILD_TEST="0" \
USE_FBGEMM="0" \
USE_NNPACK="0" \
USE_QNNPACK="0" \
USE_XNNPACK="0" \
USE_FLASH_ATTENTION="1" \
USE_MEM_EFF_ATTENTION="0"
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${FA_REPO}
RUN . .venv/bin/activate && cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm FROM base AS final
ENV HIP_FORCE_DEV_KERNARG=1 RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
. .venv/bin/activate && \
pip install /install/*.whl
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. ARG AITER_REPO
# However, Triton requires a tunning for each prompt length, which is prohibitive. ARG AITER_BRANCH
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 RUN git clone --recursive ${AITER_REPO}
RUN . .venv/bin/activate && cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
FROM base AS kernel-builder RUN rm -rf /var/lib/apt/lists/*
FROM final AS kernel-builder
# # Build vllm kernels # # Build vllm kernels
FROM kernel-builder AS vllm-builder FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
COPY server/Makefile-vllm Makefile COPY server/Makefile-vllm Makefile
RUN . .venv/bin/activate && pip install setuptools_scm
# Build specific version of vllm # Build specific version of vllm
RUN make build-vllm-rocm RUN . .venv/bin/activate && make build-vllm-rocm
# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom) # Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ . COPY server/custom_kernels/ .
RUN python setup.py build RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
# Build exllama kernels # Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ . COPY server/exllama_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build
# Build exllama v2 kernels # Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ . COPY server/exllamav2_kernels/ .
RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
RUN python setup.py build FROM kernel-builder AS marlin-kernels
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
FROM base AS base-copy FROM kernel-builder AS moe-kernels
ENV MOE_KERNELS_BRANCH=v0.8.2
ENV VLLM_TARGET_DEVICE=rocm
RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python3 setup.py bdist_wheel --dist-dir=dist
FROM final AS base-copy
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
# Copy builds artifacts from vllm builder ENV VIRTUAL_ENV=/app/.venv/
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages ENV PATH="$PATH:/app/.venv/bin/"
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
RUN cd server && \ RUN cd server && \
make gen-server && \ uv pip install grpcio-tools mypy-protobuf && \
pip install -r requirements_rocm.txt && \ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir make gen-server-raw
RUN cd server && \
pwd && \
text-generation-server --help
RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
uv pip install /install/*.whl
RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
uv pip install /install/*.whl
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -211,8 +291,24 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image # Final image
FROM base-copy FROM base-copy
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tunning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"] ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"] ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
# CMD ["--json-output"]

126
Dockerfile_gaudi Normal file
View File

@ -0,0 +1,126 @@
# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
ENV PYO3_PYTHON="/root/.local/bin/python" \
PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
PYO3_PYTHON_VERSION="3.10"
RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
&& . $HOME/.local/bin/env \
&& uv python install 3.10 --default --preview \
&& test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
# Text Generation Inference base image
ARG HABANA_VERSION
ARG PYTORCH_VERSION
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
WORKDIR /usr/src
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
make \
curl \
git \
&& rm -rf /var/lib/apt/lists/*
# Install server
COPY proto proto
COPY backends/gaudi/server server
COPY backends/gaudi/server/Makefile server/Makefile
ARG HABANA_VERSION
RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
bash ./dill-0.3.8-patch.sh && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base
ENV HF_HUB_ENABLE_HF_TRANSFER 1
ENV HABANA_VISIBLE_DEVICES all
ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -12,11 +12,14 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -29,20 +32,48 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA ARG GIT_SHA
ARG DOCKER_LABEL ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel # Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
USER root USER root
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
@ -52,40 +83,41 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
WORKDIR /usr/src WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest ENV TORCH_LLM_ALLREDUCE=1
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
@ -101,20 +133,31 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \ curl \
ca-certificates \ ca-certificates \
make \ make \
g++ \ g++-12 \
gcc-12 \
git \ git \
wget \ wget \
cmake cmake \
libnuma-dev
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
RUN update-alternatives --set cc /usr/bin/gcc
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
RUN update-alternatives --set c++ /usr/bin/g++
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
ARG MAMBA_VERSION=23.1.0-1 ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10' ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH ENV PATH=/opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba # Install mamba
@ -128,42 +171,40 @@ RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \ bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
RUN conda install -c conda-forge gperftools mkl RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install triton==3.1.0 py-libnuma
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install triton
WORKDIR /usr/src WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV KMP_BLOCKTIME=1
ENV KMP_TPAUSE=0
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_intel.txt && \ pip install -U pip uv && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -173,5 +214,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final FROM ${PLATFORM} AS final
ENV ATTENTION=flashdecoding-ipex
ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

88
Dockerfile_llamacpp Normal file
View File

@ -0,0 +1,88 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
ARG llamacpp_version=b4827
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
clang \
cmake \
curl \
git \
python3-dev \
libssl-dev \
pkg-config \
tar
ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN mkdir -p llama.cpp \
&& tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
&& cd llama.cpp \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DGGML_NATIVE=${llamacpp_native} \
-DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_SERVER=OFF \
&& cmake --build build --parallel --config Release \
&& cmake --install build
WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked
FROM deps AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
--recipe-path recipe.json \
--profile release \
--package text-generation-router-llamacpp
COPY . .
RUN cargo build \
--profile release \
--package text-generation-router-llamacpp --frozen
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
python3-venv \
python3-pip
RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
RUN pip3 install --no-cache-dir \
-r requirements.txt \
-e gguf-py
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1
ENTRYPOINT ["text-generation-router-llamacpp"]

158
Dockerfile_trtllm Normal file
View File

@ -0,0 +1,158 @@
ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
ARG actions_results_url=""
ARG actions_runtime_token=""
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
curl \
gcc-14 \
g++-14 \
git \
git-lfs \
lld \
libssl-dev \
libucx-dev \
libasan8 \
libubsan1 \
ninja-build \
pkg-config \
pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
wget --no-install-recommends && \
pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
WORKDIR /opt/src/mpi
ARG ompi_version
ENV OMPI_VERSION=${ompi_version}
ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_results_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --version ">=0.10.0" --locked
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_RESULTS_URL=${actions_results_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY router router
COPY backends backends
COPY benchmark benchmark
COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
ENV RUSTC_WRAPPER=sccache
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
RUN export CC=gcc-14 \
export CXX=g++-14 \
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
sccache --show-stats
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
pipx ensurepath && \
pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is used only for the CI/CD
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
pipx ensurepath && \
pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
# Basically we copy from target/debug instead of target/release
COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is the final image
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]

View File

@ -5,13 +5,13 @@ install-server-cpu:
cd server && make install-server cd server && make install-server
install-router: install-router:
cd router && cargo install --path . cargo install --path backends/v3/
install-launcher: install-launcher:
cd launcher && cargo install --path . cargo install --path launcher/
install-benchmark: install-benchmark:
cd benchmark && cargo install --path . cargo install --path benchmark/
install: install-server install-router install-launcher install: install-server install-router install-launcher
@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
clean: clean:
rm -rf target aml rm -rf target aml
preview_doc:
doc-builder preview text-generation-inference docs/source --not_python_module

107
README.md
View File

@ -1,7 +1,7 @@
<div align="center"> <div align="center">
<a href="https://www.youtube.com/watch?v=jlMAX2Oaht0"> <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
<img width=560 width=315 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png"> <img width=560 alt="Making TGI deployment optimal" src="https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png">
</a> </a>
# Text Generation Inference # Text Generation Inference
@ -13,7 +13,7 @@
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational"> <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
</a> </a>
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
to power Hugging Chat, the Inference API and Inference Endpoint. to power Hugging Chat, the Inference API and Inference Endpoint.
</div> </div>
@ -28,6 +28,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
- [Distributed Tracing](#distributed-tracing) - [Distributed Tracing](#distributed-tracing)
- [Architecture](#architecture) - [Architecture](#architecture)
- [Local install](#local-install) - [Local install](#local-install)
- [Local install (Nix)](#local-install-nix)
- [Optimized architectures](#optimized-architectures) - [Optimized architectures](#optimized-architectures)
- [Run locally](#run-locally) - [Run locally](#run-locally)
- [Run](#run) - [Run](#run)
@ -42,12 +43,15 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Tensor Parallelism for faster inference on multiple GPUs - Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE) - Token streaming using Server-Sent Events (SSE)
- Continuous batching of incoming requests for increased total throughput - Continuous batching of incoming requests for increased total throughput
- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API
- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with : - Quantization with :
- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [GPT-Q](https://arxiv.org/abs/2210.17323) - [GPT-Q](https://arxiv.org/abs/2210.17323)
- [EETQ](https://github.com/NetEase-FuXi/EETQ) - [EETQ](https://github.com/NetEase-FuXi/EETQ)
- [AWQ](https://github.com/casper-hansen/AutoAWQ) - [AWQ](https://github.com/casper-hansen/AutoAWQ)
- [Marlin](https://github.com/IST-DASLab/marlin)
- [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading - [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@ -80,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
``` ```
And then you can make requests like And then you can make requests like
@ -92,9 +96,32 @@ curl 127.0.0.1:8080/generate_stream \
-H 'Content-Type: application/json' -H 'Content-Type: application/json'
``` ```
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
```bash
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json'
```
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.3-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```
@ -114,23 +141,24 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
For example, if you want to serve the gated Llama V2 model variants: For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens 1. Go to https://huggingface.co/settings/tokens
2. Copy your cli READ token 2. Copy your CLI READ token
3. Export `HF_TOKEN=<your cli READ token>` 3. Export `HF_TOKEN=<your CLI READ token>`
or with Docker: or with Docker:
```shell ```shell
model=meta-llama/Llama-2-7b-chat-hf model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token> token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
``` ```
### A note on Shared Memory (shm) ### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by [`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
`PyTorch` to do distributed training/inference. `text-generation-inference` make `PyTorch` to do distributed training/inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models. use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
@ -163,18 +191,32 @@ overridden with the `--otlp-service-name` argument
![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) ![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)
Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
### Local install ### Local install
You can also opt to install `text-generation-inference` locally. You can also opt to install `text-generation-inference` locally.
First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least First clone the repository and change directory into it:
Python 3.9, e.g. using `conda`:
```shell
git clone https://github.com/huggingface/text-generation-inference
cd text-generation-inference
```
Then [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
Python 3.9, e.g. using `conda` or `python venv`:
```shell ```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
#using conda
conda create -n text-generation-inference python=3.11 conda create -n text-generation-inference python=3.11
conda activate text-generation-inference conda activate text-generation-inference
#using python venv
python3 -m venv .venv
source .venv/bin/activate
``` ```
You may also need to install Protoc. You may also need to install Protoc.
@ -208,6 +250,45 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
sudo apt-get install libssl-dev gcc -y sudo apt-get install libssl-dev gcc -y
``` ```
### Local install (Nix)
Another option is to install `text-generation-inference` locally using [Nix](https://nixos.org). Currently,
we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
be pulled from a binary cache, removing the need to build them locally.
First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference).
Setting up the cache is important, otherwise Nix will build many of the dependencies
locally, which can take hours.
After that you can run TGI with `nix run`:
```shell
cd text-generation-inference
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
to make the CUDA driver libraries visible to Nix packages.
For TGI development, you can use the `impure` dev shell:
```shell
nix develop .#impure
# Only needed the first time the devshell is started or after updating the protobuf.
(
cd server
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
)
```
All development dependencies (cargo, Python, Torch), etc. are available in this
dev shell.
## Optimized architectures ## Optimized architectures
TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models). TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
@ -232,7 +313,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
### Quantization ### Quantization
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:
```shell ```shell
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
@ -240,6 +321,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
## Develop ## Develop
```shell ```shell

BIN
assets/v3_benchmarks.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 209 KiB

View File

@ -107,20 +107,22 @@ impl Client {
#[instrument(skip_all)] #[instrument(skip_all)]
pub async fn warmup( pub async fn warmup(
&mut self, &mut self,
max_input_length: u32, max_input_tokens: Option<u32>,
max_prefill_tokens: u32, max_prefill_tokens: u32,
max_total_tokens: u32, max_total_tokens: Option<u32>,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
) -> Result<Option<u32>> { ) -> Result<(Option<u32>, u32, u32)> {
let mut n_tokens = 0; let mut n_tokens = 0;
let mut requests = Vec::new(); let mut requests = Vec::new();
// Create requests // Create requests
while n_tokens < max_prefill_tokens { while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}
let mut input_chunks = Vec::new(); let mut input_chunks = Vec::new();
input_chunks input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 { if n_tokens == 0 {
input_chunks.push( input_chunks.push(
Chunk::Image(Image { Chunk::Image(Image {
@ -136,7 +138,7 @@ impl Client {
// been updated to support chunks. // been updated to support chunks.
let mut inputs = String::new(); let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 { if n_tokens == 0 {
// 1 request is enough to test vision heads. // 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation. // Sending images on other queries messes up easily with truncation.
@ -145,6 +147,12 @@ impl Client {
)); ));
} }
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};
requests.push(Request { requests.push(Request {
id: 0, id: 0,
inputs, inputs,
@ -153,9 +161,13 @@ impl Client {
}), }),
// We truncate the input on the server side to be sure that it has the correct size // We truncate the input on the server side to be sure that it has the correct size
truncate, truncate,
// Most request will have that
add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention // Blocks and slots will be set on the server side if we use paged attention
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory // Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters { parameters: Some(NextTokenChooserParameters {
temperature: 0.9, temperature: 0.9,
@ -171,7 +183,7 @@ impl Client {
grammar_type: GrammarType::None as i32, grammar_type: GrammarType::None as i32,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate, max_new_tokens,
stop_sequences: vec![], stop_sequences: vec![],
ignore_eos_token: true, ignore_eos_token: true,
}), }),
@ -179,7 +191,7 @@ impl Client {
top_n_tokens: 20, top_n_tokens: 20,
adapter_id: None, adapter_id: None,
}); });
n_tokens += max_input_length; n_tokens += truncate;
// Check max_batch_size // Check max_batch_size
if Some(requests.len()) == max_batch_size { if Some(requests.len()) == max_batch_size {
@ -191,19 +203,23 @@ impl Client {
id: 0, id: 0,
size: requests.len() as u32, size: requests.len() as u32,
requests, requests,
max_tokens: max_input_length, max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0, max_blocks: 0,
}; };
let request = tonic::Request::new(WarmupRequest { let request = tonic::Request::new(WarmupRequest {
batch: Some(batch), batch: Some(batch),
max_input_length, max_input_tokens,
max_prefill_tokens, max_prefill_tokens,
max_total_tokens, max_total_tokens,
}) })
.inject_context(); .inject_context();
let response = self.stub.warmup(request).await?.into_inner(); let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens) Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
} }
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch
@ -214,8 +230,13 @@ impl Client {
pub async fn prefill( pub async fn prefill(
&mut self, &mut self,
batch: Batch, batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner(); let response = self.stub.prefill(request).await?.into_inner();
Ok(( Ok((
response.generations, response.generations,

View File

@ -0,0 +1,271 @@
/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<(Option<u32>, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
// Take the minimum value
// Different shards hold different parts of vocab, might yield
// different available block size.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
chunk_len: None,
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch, None).await?;
Ok(())
}
}

62
backends/gaudi/Makefile Normal file
View File

@ -0,0 +1,62 @@
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := ${mkfile_dir}/../..
HABANA_VERSION := 1.20.0
PYTORCH_VERSION := 2.6.0
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
image:
docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
run-local-dev-container:
docker run -it \
--runtime=habana \
--ipc=host \
--cap-add=sys_nice \
--net=host \
-e HABANA_VISIBLE_DEVICES=all \
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
-e LOG_LEVEL=debug \
-e PORT=8080 \
-v /home/ubuntu/.cache/huggingface:/data \
-v $(PWD):/text-generation-inference \
-w /text-generation-inference \
vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
install-dependencies:
pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
pip install outlines~=0.0.34
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
install-server:
make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3
install-router:
make -C ${root_dir} install-router
install-launcher:
make -C ${root_dir} install-launcher
# use source to load the rust in path
local-dev-install: install-dependencies
bash -c 'source "$$HOME/.cargo/env" && \
make install-server && \
make install-router && \
make install-launcher'
# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
DOCKER_VOLUME=${root_dir}/data \
HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

142
backends/gaudi/README.md Normal file
View File

@ -0,0 +1,142 @@
# Text-generation-inference - Gaudi backend
## Description
This is the TGI backend for Intel Gaudi. This backend is composed of the tgi server optimized for Gaudi hardware.
## Build your own image
The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
Option 1: From the project root directory:
```bash
make -C backends/gaudi image
```
Option 2: From the Gaudi backend directory:
```bash
cd backends/gaudi
make image
```
You can now run the server with the following command:
Option 1: Sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
```
Option 2: Non-sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
## Contributing
### Local Development
This is useful if you want to run the server locally for better debugging.
```bash
make -C backends/gaudi run-local-dev-container
```
Then run the following command inside the container to install tgi for gaudi:
```bash
make -C backends/gaudi local-dev-install
```
Add rust to path:
```bash
. "$HOME/.cargo/env"
```
Option 1: Run the server (sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
--model-id meta-llama/Llama-3.1-8B-Instruct \
--sharded true \
--num-shard 8 \
--max-input-tokens 512 \
--max-total-tokens 1024 \
--max-batch-size 8 \
--max-batch-prefill-tokens 2048
```
Option 2: Run the server (non-sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
--model-id meta-llama/Llama-3.1-8B-Instruct \
--max-input-tokens 512 \
--max-total-tokens 1024 \
--max-batch-size 4 \
--max-batch-prefill-tokens 2048
```
You can then test the server with the following curl command from another terminal (can be outside the container):
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
### Integration tests
To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```
Then run the following command to run the integration tests:
```bash
make -C backends/gaudi run-integration-tests
```
To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
#### How the integration tests works
The integration tests works as follows:
1. Start a tgi server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
2. Do a /generate request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```
This is the repeated for a set of models and configurations.

View File

@ -0,0 +1,283 @@
# Examples of Docker Commands for Gaudi Backend
This page gives a list of examples of docker run commands for some of the most popular models.
> **Note:** The parameters are chosen for Gaudi2 hardware to maximize performance on this given hardware, please adjust the parameters based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
## Default Precision (BF16)
### Llama3.1-8B on 1 card (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3.1-70B 8 cards (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llama2-7B on 1 Card (BF16)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama2-70B on 8 cards (BF16)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (BF16)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## FP8 Precision
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
## Llama3.1-8B on 1 Card (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama3.1-70B on 8 cards (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llama2-7B on 1 Card (FP8)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
## Llama2-70B on 8 Cards (FP8)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```

164
backends/gaudi/server/.gitignore vendored Normal file
View File

@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation_server/__pycache__/
text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors
flash-attention/
flash-attention-v2/
vllm/
llm-awq/
eetq/
mamba/

View File

@ -0,0 +1,38 @@
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
PROTO_PATH ?= ../proto/v3
unit-tests:
pytest -s -vv -m "not private" tests
gen-server:
# Compile protos
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install: gen-server
pip install pip --upgrade
pip install --no-deps -r requirements.txt
pip install -e "."
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
install-poetry:
curl -sSL https://install.python-poetry.org | python3 -
update-lock:
rm poetry.lock
poetry lock --no-update
export-requirements:
poetry export -o requirements.txt --without-hashes

View File

@ -0,0 +1,15 @@
# Fork that adds only the correct stream to this kernel in order
# to make cuda graphs work.
awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
awq:
rm -rf llm-awq
git clone https://github.com/huggingface/llm-awq
build-awq: awq
cd llm-awq/ && git fetch && git checkout $(awq_commit)
cd llm-awq/awq/kernels && python setup.py build
install-awq: build-awq
pip uninstall awq_inference_engine -y || true
cd llm-awq/awq/kernels && python setup.py install

View File

@ -0,0 +1,13 @@
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
eetq:
# Clone eetq
pip install packaging
git clone https://github.com/NetEase-FuXi/EETQ.git eetq
build-eetq: eetq
cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
cd eetq && python setup.py build
install-eetq: build-eetq
cd eetq && python setup.py install

View File

@ -1,10 +1,10 @@
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca fbgemm_commit := v0.8.0
build-fbgemm: build-fbgemm:
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ @if [ ! -d "fbgemm" ]; then \
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
cp fbgemm_remove_unused.patch fbgemm && \ fi
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
git submodule update --init --recursive && \ git submodule update --init --recursive && \
cd fbgemm_gpu && \ cd fbgemm_gpu && \
pip install -r requirements.txt && \ pip install -r requirements.txt && \

View File

@ -0,0 +1,12 @@
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
build-flash-attention:
if [ ! -d 'flash-attention' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/HazyResearch/flash-attention.git; \
fi
cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
install-flash-attention: build-flash-attention
cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install

View File

@ -0,0 +1,21 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
build-flash-attention-v2-cuda:
pip install -U packaging wheel
pip install flash-attn==$(flash_att_v2_commit_cuda)
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
echo "Flash v2 installed"
build-flash-attention-v2-rocm:
if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
fi
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
cd flash-attention-v2 && \
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

View File

@ -0,0 +1,28 @@
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
causal-conv1d:
rm -rf causal-conv1d
git clone https://github.com/Dao-AILab/causal-conv1d.git
build-causal-conv1d: causal-conv1d
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
install-causal-conv1d: build-causal-conv1d
pip uninstall causal-conv1d -y || true
cd causal-conv1d/ && pip install .
# selective-scan dependends on causal-conv1d
selective-scan:
rm -rf mamba
git clone https://github.com/state-spaces/mamba.git mamba
build-selective-scan: selective-scan
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
cd mamba && python setup.py build
install-selective-scan: install-causal-conv1d build-selective-scan
pip uninstall selective-scan-cuda -y || true
cd mamba && pip install .
build-all: build-causal-conv1d build-selective-scan

View File

@ -0,0 +1,23 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/Narsil/vllm.git vllm; \
fi
cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
install-vllm-cuda: build-vllm-cuda
cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
install-vllm-rocm: build-vllm-rocm
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .

View File

@ -0,0 +1,15 @@
# Text Generation Inference Python gRPC Server
A Python gRPC server for Text Generation Inference
## Install
```shell
make install
```
## Run
```shell
make run-dev
```

View File

@ -0,0 +1,91 @@
#!/bin/bash
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.7.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d0cf543..f6eb662 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index 74234ab..1be8d89 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
git apply dill-0.3.7.patch
python -m pip install .
popd
rm -fr dill

View File

@ -0,0 +1,91 @@
#!/bin/bash
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.8.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d42432f..1d251e6 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index e91068a..a921b43 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
git apply dill-0.3.8.patch
python -m pip install .
popd
rm -fr dill

View File

@ -0,0 +1,85 @@
import json
import os
from typing import Dict, Any, Generator
import pytest
from test_model import TEST_CONFIGS
UNKNOWN_CONFIGS = {
name: config
for name, config in TEST_CONFIGS.items()
if config["expected_greedy_output"] == "unknown"
or config["expected_batch_output"] == "unknown"
}
@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
"""Fixture that provides model configurations for testing."""
test_config = UNKNOWN_CONFIGS[request.param]
test_config["test_name"] = request.param
return test_config
@pytest.fixture(scope="module")
def test_name(test_config):
yield test_config["test_name"]
@pytest.fixture(scope="module")
def tgi_service(launcher, test_config, test_name) -> Generator:
"""Fixture that provides a TGI service for testing."""
with launcher(test_config["model_id"], test_name) as service:
yield service
@pytest.mark.asyncio
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
"""Test that captures expected outputs for models with unknown outputs."""
print(f"Testing {test_name} with {test_config['model_id']}")
# Wait for service to be ready
await tgi_service.health(1000)
client = tgi_service.client
# Test single request (greedy)
print("Testing single request...")
response = await client.generate(
test_config["input"],
max_new_tokens=32,
)
greedy_output = response.generated_text
# Test multiple requests (batch)
print("Testing batch requests...")
responses = []
for _ in range(4):
response = await client.generate(
test_config["input"],
max_new_tokens=32,
)
responses.append(response.generated_text)
# Store results in a JSON file
output_file = "server/integration-tests/expected_outputs.json"
results = {}
# Try to load existing results if file exists
if os.path.exists(output_file):
with open(output_file, "r") as f:
results = json.load(f)
# Update results for this model
results[test_name] = {
"model_id": test_config["model_id"],
"input": test_config["input"],
"greedy_output": greedy_output,
"batch_outputs": responses,
"args": test_config["args"],
}
# Save updated results
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults for {test_name} saved to {output_file}")

View File

@ -0,0 +1,292 @@
import asyncio
import contextlib
import os
import shlex
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
from typing import List
import socket
import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import TEST_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response
# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
HF_TOKEN = os.getenv("HF_TOKEN", None)
assert (
HF_TOKEN is not None
), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
if DOCKER_VOLUME is None:
logger.warning(
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
)
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
BASE_ENV = {
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"LOG_LEVEL": LOG_LEVEL,
"HF_TOKEN": os.getenv("HF_TOKEN", None),
}
HABANA_RUN_ARGS = {
"runtime": "habana",
"ipc_mode": "host",
"cap_add": ["sys_nice"],
}
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="INFO",
)
def stream_container_logs(container, test_name):
"""Stream container logs in a separate thread."""
try:
for log in container.logs(stream=True, follow=True):
print(
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
end="",
file=sys.stderr,
flush=True,
)
except Exception as e:
logger.error(f"Error streaming container logs: {str(e)}")
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
def _inner_health(self):
raise NotImplementedError
async def health(self, timeout: int = 60):
assert timeout > 0
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")
for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
raise RuntimeError("Launcher crashed")
try:
await self.client.generate("test")
elapsed = time.time() - start_time
logger.info(f"Health check passed after {elapsed:.1f}s")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
if attempt == timeout - 1:
logger.error(f"Health check failed after {timeout}s: {str(e)}")
raise RuntimeError(f"Health check failed: {str(e)}")
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
logger.debug(
f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
)
time.sleep(1)
except Exception as e:
logger.error(f"Unexpected error during health check: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int):
super(ContainerLauncherHandle, self).__init__(port)
self.docker_client = docker_client
self.container_name = container_name
def _inner_health(self) -> bool:
try:
container = self.docker_client.containers.get(self.container_name)
status = container.status
if status not in ["running", "created"]:
logger.warning(f"Container status is {status}")
# Get container logs for debugging
logs = container.logs().decode("utf-8")
logger.debug(f"Container logs:\n{logs}")
return status in ["running", "created"]
except Exception as e:
logger.error(f"Error checking container health: {str(e)}")
return False
class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int):
super(ProcessLauncherHandle, self).__init__(port)
self.process = process
def _inner_health(self) -> bool:
return self.process.poll() is None
@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
yield tmpdir.name
try:
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error cleaning up temporary directory: {str(e)}")
@pytest.fixture(scope="module")
def launcher(data_volume):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
test_name: str,
):
logger.info(
f"Starting docker launcher for model {model_id} and test {test_name}"
)
# Get a random available port
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
return port
port = get_free_port()
logger.debug(f"Using port {port}")
client = docker.from_env()
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
try:
container = client.containers.get(container_name)
logger.info(
f"Stopping existing container {container_name} for test {test_name}"
)
container.stop()
container.wait()
except NotFound:
pass
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")
model_name = next(
name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
)
tgi_args = TEST_CONFIGS[model_name]["args"].copy()
env = BASE_ENV.copy()
# Add model_id to env
env["MODEL_ID"] = model_id
# Add env config that is definied in the fixture parameter
if "env_config" in TEST_CONFIGS[model_name]:
env.update(TEST_CONFIGS[model_name]["env_config"].copy())
volumes = [f"{DOCKER_VOLUME}:/data"]
logger.debug(f"Using volume {volumes}")
try:
logger.info(f"Creating container with name {container_name}")
# Log equivalent docker run command for debugging, this is not actually executed
container = client.containers.run(
DOCKER_IMAGE,
command=tgi_args,
name=container_name,
environment=env,
detach=True,
volumes=volumes,
ports={"80/tcp": port},
**HABANA_RUN_ARGS,
)
logger.info(f"Container {container_name} started successfully")
# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container, test_name),
daemon=True, # This ensures the thread will be killed when the main program exits
)
log_thread.start()
# Add a small delay to allow container to initialize
time.sleep(2)
# Check container status after creation
status = container.status
logger.debug(f"Initial container status: {status}")
if status not in ["running", "created"]:
logs = container.logs().decode("utf-8")
logger.error(f"Container failed to start properly. Logs:\n{logs}")
yield ContainerLauncherHandle(client, container.name, port)
except Exception as e:
logger.error(f"Error starting container: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
finally:
try:
container = client.containers.get(container_name)
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
container.remove()
logger.info(f"Container {container_name} removed successfully")
except NotFound:
pass
except Exception as e:
logger.warning(f"Error cleaning up container: {str(e)}")
return docker_launcher
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
) -> List[Response]:
try:
futures = [
client.generate(
prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
)
for _ in range(n)
]
return await asyncio.gather(*futures)
except Exception as e:
logger.error(f"Error generating load: {str(e)}")
raise
return generate_load_inner

View File

@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto

View File

@ -0,0 +1,7 @@
pytest >= 8.3.5
pytest-asyncio >= 0.26.0
docker >= 7.1.0
Levenshtein >= 0.27.1
loguru >= 0.7.3
aiohttp >= 3.11.14
text-generation

View File

@ -0,0 +1,276 @@
from typing import Any, Dict
from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance
# The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
TEST_CONFIGS = {
"meta-llama/Llama-3.1-8B-Instruct-shared": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
"args": [
"--sharded",
"true",
"--num-shard",
"8",
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"8",
"--max-batch-prefill-tokens",
"2048",
],
},
"meta-llama/Llama-3.1-8B-Instruct": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"env_config": {},
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"meta-llama/Llama-2-7b-chat-hf": {
"model_id": "meta-llama/Llama-2-7b-chat-hf",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"mistralai/Mistral-7B-Instruct-v0.3": {
"model_id": "mistralai/Mistral-7B-Instruct-v0.3",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"bigcode/starcoder2-3b": {
"model_id": "bigcode/starcoder2-3b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"google/gemma-7b-it": {
"model_id": "google/gemma-7b-it",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"Qwen/Qwen2-0.5B-Instruct": {
"model_id": "Qwen/Qwen2-0.5B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"tiiuae/falcon-7b-instruct": {
"model_id": "tiiuae/falcon-7b-instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"microsoft/phi-1_5": {
"model_id": "microsoft/phi-1_5",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"openai-community/gpt2": {
"model_id": "openai-community/gpt2",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"facebook/opt-125m": {
"model_id": "facebook/opt-125m",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"EleutherAI/gpt-j-6b": {
"model_id": "EleutherAI/gpt-j-6b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
}
print(f"Testing {len(TEST_CONFIGS)} models")
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
"""Fixture that provides model configurations for testing."""
test_config = TEST_CONFIGS[request.param]
test_config["test_name"] = request.param
return test_config
@pytest.fixture(scope="module")
def model_id(test_config):
yield test_config["model_id"]
@pytest.fixture(scope="module")
def test_name(test_config):
yield test_config["test_name"]
@pytest.fixture(scope="module")
def expected_outputs(test_config):
return {
"greedy": test_config["expected_greedy_output"],
# "sampling": model_config["expected_sampling_output"],
"batch": test_config["expected_batch_output"],
}
@pytest.fixture(scope="module")
def input(test_config):
return test_config["input"]
@pytest.fixture(scope="module")
def tgi_service(launcher, model_id, test_name):
with launcher(model_id, test_name) as tgi_service:
yield tgi_service
@pytest.fixture(scope="module")
async def tgi_client(tgi_service) -> AsyncClient:
await tgi_service.health(1000)
return tgi_service.client
@pytest.mark.asyncio
async def test_model_single_request(
tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
):
# Bounded greedy decoding without input
response = await tgi_client.generate(
input,
max_new_tokens=32,
)
assert response.details.generated_tokens == 32
assert response.generated_text == expected_outputs["greedy"]
@pytest.mark.asyncio
async def test_model_multiple_requests(
tgi_client, generate_load, expected_outputs, input
):
num_requests = 4
responses = await generate_load(
tgi_client,
input,
max_new_tokens=32,
n=num_requests,
)
assert len(responses) == 4
expected = expected_outputs["batch"]
for r in responses:
assert r.details.generated_tokens == 32
# Compute the similarity with the expectation using the levenshtein distance
# We should not have more than two substitutions or additions
assert levenshtein_distance(r.generated_text, expected) < 3

3014
backends/gaudi/server/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,45 @@
[tool.poetry]
name = "text-generation-server"
version = "2.0.4"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^5.0"
grpcio = "^1.71.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.15.0"
loguru = "^0.7.3"
opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
optimum-habana = "1.17"
transformers = "^4.49"
numpy = "^1.26"
accelerate = "^0.33"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
pytest = "^8.3.5"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"

View File

@ -0,0 +1,101 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]

View File

@ -0,0 +1,30 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass

View File

@ -0,0 +1,471 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
# prepare pre-loaded lora weights for use in the model.
#
# this method processes and organizes lora weights for a specific layer type across all layers:
# - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
# - retrieves weights from `module_map` based on the `layer_type`.
# - processes `nlayers` number of layers.
# - converts weights to the specified `dtype`.
# - shards weights across `world_size` number of processes using the `process_group`.
# - maps weights to specific layers using `target_to_layer`.
# - tracks `unused_weight_names` to identify any unused weights.
#
# the method handles weight transposition, scaling, and padding to ensure compatibility
# with SGMV or BGMV operations.
@classmethod
def prepare_weights(
cls,
config: LoraConfig,
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
nlayers: int,
dtype: torch.dtype,
world_size: int,
process_group: ProcessGroup,
target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
) -> Optional[AdapterWeights]:
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
process_group=process_group,
),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
idx if idx in rank_indices else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v

View File

@ -0,0 +1,146 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor
# [num_adapters]
adapter_set: Set[int]
# [num_segments + 1]
adapter_segments: torch.Tensor
# [num_segments]
# maps from segment index to adapter index, i.e.:
# segment_indices[s] == adapter_indices[i]
segment_indices: List[int]
class AdapterWeights(ABC):
@abstractclassmethod
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
pass
@property
def speculative_tokens(self) -> int:
return 0
class BatchAdapterWeights(ABC):
@abstractclassmethod
def has_adapter(self, adapter_index: int) -> bool:
pass
@abstractclassmethod
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
pass
class LayerAdapterWeights:
"""Adapter weights that apply to a particular layer."""
def __init__(self):
self.adapter_weights: Dict[int, AdapterWeights] = {}
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
self.adapter_weights[adapter_idx] = weights
def remove_adapter(self, adapter_idx: int):
if adapter_idx not in self.adapter_weights:
return
del self.adapter_weights[adapter_idx]
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0
def get_data(
self,
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
for batch_type in adapter_weights.get_batch_types():
adapter_batch_types[batch_type][adapter_index] = adapter_weights
batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data = batched_weights
return batch_data
@dataclass
class AdapterBatchData:
meta: AdapterBatchMetadata
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]
prefill: bool
@staticmethod
def from_meta(
meta: AdapterBatchMetadata,
weights: Dict[str, LayerAdapterWeights],
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for lora_data in self.data.values():
if lora_data is None:
continue
for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)
return ranks
def layer_names(self) -> Set[str]:
return set(self.data.keys())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys
@property
def max_rank(self) -> int:
ranks = self.ranks()
return max(ranks) if len(ranks) > 0 else 0

View File

@ -0,0 +1,34 @@
import torch
from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch)
class Cache:
def __init__(self):
self.cache: Dict[int, B] = {}
def pop(self, batch_id: int) -> Optional[B]:
return self.cache.pop(batch_id, None)
def set(self, entry: B):
if entry is not None:
self.cache[entry.batch_id] = entry
def delete(self, batch_id: int):
batch = self.pop(batch_id)
if batch is not None:
del batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
def clear(self):
keys = list(self.cache.keys())
for k in keys:
self.delete(k)
def __len__(self):
return len(self.cache.keys())

View File

@ -0,0 +1,426 @@
import os
import psutil
import signal
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters
app = typer.Typer()
class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
fp8 = "fp8"
class Dtype(str, Enum):
float16 = "float16"
bloat16 = "bfloat16"
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
):
if sharded:
# assert (
# os.getenv("RANK", None) is not None
# ), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Import here after the logger is added to log potential import exceptions
from text_generation_server import server
from text_generation_server.tracing import setup_tracing
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
# and warn the user
if lora_adapters:
logger.warning("LoRA adapters enabled (experimental feature).")
if "CUDA_GRAPHS" in os.environ:
logger.warning(
"LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
)
global CUDA_GRAPHS
CUDA_GRAPHS = None
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
logger.info(f"quantize={quantize}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
"gptq",
"awq",
"fp8",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = (
f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
)
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(
model_id,
lora_adapters,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
auto_convert: bool = True,
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
merge_lora: bool = False,
):
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Import here after the logger is added to log potential import exceptions
from text_generation_server import utils
# Test if files were already download
try:
utils.weight_files(model_id, revision, extension)
logger.info("Files are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
if not is_local_model:
# TODO: maybe reverse the default value of merge_lora?
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
else:
try:
utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
except Exception:
pass
try:
import json
config = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(config, "r") as f:
config = json.load(f)
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id and base_model_id != model_id:
try:
logger.info(f"Downloading parent model {base_model_id}")
download_weights(
model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
# Successfully downloaded weights
return
# No weights found on the hub with this extension
except utils.EntryNotFoundError as e:
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
if not extension == ".safetensors" or not auto_convert:
raise e
elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model
try:
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
elif (Path(model_id) / "config.json").exists():
# Try to load as a local Medusa model
try:
import json
config = Path(model_id) / "config.json"
with open(config, "r") as f:
config = json.load(f)
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
try:
logger.info(f"Downloading parent model {base_model_id}")
download_weights(
model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
try:
local_pt_files = utils.weight_files(model_id, revision, ".bin")
except Exception:
local_pt_files = utils.weight_files(model_id, revision, ".pt")
# No local pytorch weights
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
if extension == ".safetensors":
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Downloading PyTorch weights."
)
# Try to see if there are pytorch weights on the hub
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
if auto_convert:
if not trust_remote_code:
logger.warning(
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
"Pickle files are unsafe and can essentially contain remote code execution!"
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
)
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights to safetensors."
)
# Safetensors final filenames
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
try:
import transformers
import json
if is_local_model:
config_filename = os.path.join(model_id, "config.json")
else:
config_filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(config_filename, "r") as f:
config = json.load(f)
architecture = config["architectures"][0]
class_ = getattr(transformers, architecture)
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)
@app.command()
def quantize(
model_id: str,
output_dir: str,
revision: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
upload_to_model_id: Optional[str] = None,
percdamp: float = 0.01,
act_order: bool = False,
groupsize: int = 128,
):
if revision is None:
revision = "main"
download_weights(
model_id=model_id,
revision=revision,
logger_level=logger_level,
json_output=json_output,
)
from text_generation_server.layers.gptq.quantize import quantize
quantize(
model_id=model_id,
bits=4,
groupsize=groupsize,
output_dir=output_dir,
revision=revision,
trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id,
percdamp=percdamp,
act_order=act_order,
sym=True,
)
if __name__ == "__main__":
app()

View File

@ -0,0 +1,53 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import habana_frameworks.torch as htorch
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import (
ScopedLinearAllReduce,
)
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)
def setup_quantization(model):
if is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if model.config.model_type in [
"llama",
"falcon",
"qwen2",
"starcoder2",
"gemma",
]:
patch_scoped_linear_all_reduce(model)
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
return model

View File

@ -0,0 +1,45 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
import grpc
from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
import traceback
import os
class ExceptionInterceptor(AsyncServerInterceptor):
async def intercept(
self,
method: Callable,
request_or_iterator: Any,
context: grpc.ServicerContext,
method_name: str,
) -> Any:
try:
response = method(request_or_iterator, context)
return await response
except Exception as err:
trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
method_name = method_name.split("/")[-1]
logger.exception(f"Method {method_name} encountered an error.")
# Runtime Error cannot be recovered from
if isinstance(err, RuntimeError):
exit(1)
if torch.cuda.is_available():
torch.cuda.empty_cache()
from .utils.debug import dbg_trace
dbg_trace("EXCEPTION", traceback.format_exc())
await context.abort_with_status(
rpc_status.to_status(
status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
)
)

View File

@ -0,0 +1,34 @@
from text_generation_server.layers.tensor_parallel import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
get_linear,
FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
__all__ = [
"get_linear",
"FastLinear",
"TensorParallelColumnLinear",
"TensorParallelRowLinear",
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",
"load_conv2d",
]

View File

@ -0,0 +1,28 @@
from .common import (
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
)
from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
)
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
"trim_attn_metadata",
]

View File

@ -0,0 +1,147 @@
from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections
_TYPE_CACHE = {}
@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]
def subtuple(
obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None,
):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if isinstance(obj, dict):
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
return _TYPE_CACHE[typename](**values)
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedAttentionMetadata",
[
"block_list",
"block_mapping",
"block_usage",
"block_scales",
"block_groups",
"attn_bias",
],
)
return attention_metadata
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
def trim_seqlen_metadata(metadata: Seqlen) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedSeqlen",
[
"input_lengths",
"cache_lengths",
"cu_seqlen_q",
"cu_seqlen_k",
],
)
return attention_metadata

View File

@ -0,0 +1,95 @@
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
SUPPORTS_WINDOWING = False
def fetch_from_cache(cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
return cache[: blocks.size(0)]
else:
return cache.index_select(0, blocks)
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
bs = seqlen.input_lengths.shape[0]
_, head_num, head_size = query.shape
_, kv_head_num, head_size = key.shape
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
return attn_output
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
]

View File

@ -0,0 +1,139 @@
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
from vllm_hpu_extension import cache_ops
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a GPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
self.kv_cache = (
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)

View File

@ -0,0 +1,97 @@
import torch
from typing import List
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, 32, 4, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
return imatrix
def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = AWQ_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is AWQ_PACK_ORDER
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
return imatrix
def fast_awq_to_gptq(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros

View File

@ -0,0 +1,3 @@
from .hpu import WQLinear
__all__ = ["WQLinear"]

View File

@ -0,0 +1,134 @@
from typing import Optional
import torch
import torch.nn as nn
try:
import habana_frameworks.torch.hpu # noqa: F401
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(
qzeros[:, :, None], shifts[None, None, :]
).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
else:
izeros = qzeros
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
return iweight, izeros
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros(
(normal.shape[0], normal.shape[1] // 32 * bits),
dtype=torch.int32,
device=input.device,
)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self._preprocessing()
def _preprocessing(self):
device = self.qweight.device
weight, zeros = unpack_weight_and_zeros(
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
)
self.qweight = pack_tensor(weight).to(device)
self.qzeros = pack_tensor(zeros).to(device)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
outputs = torch.matmul(x, weights)
outputs = outputs + self.bias if self.bias is not None else outputs
outputs = outputs.reshape(out_shape)
return outputs

View File

@ -0,0 +1,124 @@
from dataclasses import dataclass
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class BNBWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
class Linear8bitLt(torch.nn.Module):
def __init__(
self,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__()
assert (
not memory_efficient_backward
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
# Necessary for stacked layers
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
weight.data,
has_fp16_weights=has_fp16_weights,
requires_grad=has_fp16_weights,
)
self.weight.cuda(weight.device)
self.bias = bias
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
def forward(self, x: torch.Tensor):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
return out
@dataclass
class BNBFP4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="fp4")
@dataclass
class BNBNF4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="nf4")
class Linear4bit(torch.nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data,
requires_grad=False,
compress_statistics=True,
quant_type=quant_type,
)
self.compute_dtype = None
self.weight.cuda(weight.device)
self.bias = bias
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out

View File

@ -0,0 +1,41 @@
from accelerate import init_empty_weights
import torch
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = torch.nn.Parameter(weight)
conv2d.bias = torch.nn.Parameter(bias)
return conv2d
@classmethod
def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = torch.nn.Parameter(weight)
conv2d.bias = None
return conv2d
torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias

View File

@ -0,0 +1,78 @@
from dataclasses import dataclass
from typing import List, Union
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class Exl2Weight(Weight):
"""
Exllama2 exl2 quantized weights.
"""
q_weight: torch.Tensor
q_scale: torch.Tensor
q_invperm: torch.Tensor
q_scale_max: torch.Tensor
q_groups: torch.Tensor
def __post_init__(self):
self.q_scale_max /= 256
self.q_invperm = self.q_invperm.short()
@property
def device(self) -> torch.device:
return self.q_weight.device
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.gptq import ExllamaQuantLinear
return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)

View File

@ -0,0 +1,458 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, List
import torch
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()
if is_hpu_gaudi2():
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
):
"""
This function returns a reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
shape = weight.shape
qweight, scale = scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)
return qweight.reshape(shape), scale
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
@dataclass
class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
self.weight, bias, self.dtype
)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
)
class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}
def __init__(
self,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
self.dtype = dtype
self.qweight = qweight
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
self.scale_upper_bound = scale_upper_bound
self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=True)
return cls(
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
)
@classmethod
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
return cls(
qweight=weight,
scale=scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
weight_block_size=weight_block_size,
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1)

View File

@ -0,0 +1,357 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from .hpu import QuantLinear
@dataclass
class GPTQWeight(Weight):
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_awq_kernel: bool
use_exllama: bool
def __post_init__(self):
if self.scales.dtype == torch.float:
self.scales = self.scales.half()
@property
def device(self) -> torch.device:
return self.qweight.device
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
try:
from text_generation_server.layers.awq.quantize import WQLinear
return WQLinear(
w_bit=self.bits,
group_size=self.groupsize,
qweight=self.qweight,
qzeros=self.qzeros,
scales=self.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
else:
return QuantLinear(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
bias,
self.bits,
self.groupsize,
)
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)
use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
# Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
desc_act = True
from text_generation_server.layers.gptq import (
GPTQWeight,
)
if not desc_act and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def _get_gptq_params(self, weights: Weights):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"

View File

@ -0,0 +1,186 @@
import math
import numpy as np
import torch
import torch.nn as nn
try:
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.wf = torch.tensor(
list(range(0, 32, self.bits)), dtype=torch.int32
).unsqueeze(0)
self._preprocessing()
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
self.scales.dtype
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
return weight
def _preprocessing(self):
orig_device = self.qweight.device
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to(orig_device)
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.groupsize for i in range(columns)]
g_idx_trivial = torch.tensor(
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
)
assert torch.equal(
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
self.qzeros = self.qzeros.cpu()
zeros = self.unpack_zeros_from_cuda_old_format()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to(orig_device)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,56 @@
import torch
# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
"""
Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calcualted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
if y_pred.shape != y_real.shape:
raise ValueError(
f"Can not compute snr loss for tensors with different shape. "
f"({y_pred.shape} and {y_real.shape})"
)
reduction = str(reduction).lower()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)
y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == "mean":
return torch.mean(snr)
elif reduction == "sum":
return torch.sum(snr)
elif reduction == "none":
return snr
else:
raise ValueError("Unsupported reduction method.")

View File

@ -0,0 +1,67 @@
import torch
from torch import nn
from accelerate import init_empty_weights
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = torch.nn.Parameter(weight)
ln.bias = torch.nn.Parameter(bias)
return ln
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = torch.nn.Parameter(weight)
ln.bias = None
return ln
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
return super().forward(hidden_states), residual
class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
super().__init__()
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
from vllm_hpu_extension.kernels import rms_norm
orig_shape = hidden_states.shape
if residual is not None:
residual += hidden_states.view(residual.shape)
else:
residual = hidden_states
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
if len(orig_shape) == 2:
residual = residual.unsqueeze(0)
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual.view(orig_shape)

View File

@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F
class FastLinear(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = torch.nn.Parameter(bias, requires_grad=False)
else:
self.bias = None
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight, self.bias)
def get_linear(weight, bias):
# Weights that are loaded through methods that are not
# quantization-aware are still bare tensors. We may want
# to change this in the future.
if isinstance(weight, torch.Tensor):
return FastLinear(weight, bias)
return weight.get_linear(bias)

View File

@ -0,0 +1,279 @@
from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
def __init__(
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
):
super().__init__()
self.base_layer = base_layer
self.layer_id = layer_id
self.process_group = process_group
def forward_layer_type(
self,
result: torch.Tensor,
input: torch.Tensor,
adapter_data: "AdapterBatchData",
layer_type: str,
start_idx: int,
end_idx: int,
) -> torch.Tensor:
if adapter_data is None:
return result
data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on
# the layer type (e.g., attention vs. feed-forward layers). We define the current
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
# This approach ensures accurate LoRA application across various layer sizes and
# configurations, adapting to different model architectures and parallelization strategies.
#
# Example scenarios where this is necessary:
# 1. The adapter's size doesn't evenly divide across GPUs.
# 2. We're processing the last segment which might be smaller.
# 3. Different projection layers (q, k, v) have different sizes.
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
adapter_mask = (
(adapter_data.meta.adapter_indices == adapter_index)
.to(input.dtype)
.view(-1, 1)
)
layer_result = self.forward_lora(
input, data, adapter_index, adapter_mask
)
result[:, start_idx:end_idx] += layer_result
return result
def forward_lora(
self,
input: torch.Tensor,
data: "BatchLoraWeights",
adapter_index: int,
adapter_mask: torch.Tensor,
) -> torch.Tensor:
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, lora_b.size(0))
a_out = input @ lora_a
if self.process_group.size() > 1:
a_out = self.collect_lora_a(a_out)
result = (a_out @ lora_b) * adapter_mask
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
def __init__(
self,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
super().__init__(base_layer, layer_id, process_group)
self.layer_names = layer_names
self.sizes = sizes
@classmethod
def load(
cls,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
return TensorParallelMultiAdapterLinear(
base_layer, layer_id, layer_names, sizes, process_group
)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
# noop if no layer names are provided (e.g. for models without adapters)
if self.layer_names is None:
return result
# handle models like Bloom that have inputs of shape
# (batch_size, sequence_length, hidden_size)
# we need to reshape them to (batch_size * sequence_length, hidden_size)
# for the LoRA computation, then reshape back
prev_shape = result.shape
is_3d = len(input.shape) >= 3
if is_3d:
input = input.reshape(-1, input.shape[-1])
result = result.reshape(-1, result.shape[-1])
offset = 0
for i, layer_name in enumerate(self.layer_names):
start_idx = offset // self.process_group.size()
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
# different projection sizes across layers and model architectures.
if self.sizes is not None:
offset += self.sizes[i]
end_idx = offset // self.process_group.size()
else:
end_idx = result.shape[1]
result = self.forward_layer_type(
result, input, adapter_data, layer_name, start_idx, end_idx
)
if is_3d:
result = result.reshape(prev_shape)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
gathered_tensors = [
torch.empty_like(a_out) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(gathered_tensors, a_out)
return torch.cat(gathered_tensors, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_name, process_group):
super().__init__(base_layer, layer_id, process_group)
self.layer_name = layer_name
@classmethod
def load(cls, base_layer, layer_id, layer_name, process_group):
return cls(base_layer, layer_id, layer_name, process_group)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
if self.layer_name is None:
return result
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
stride = result.shape[-1] // self.process_group.size()
start_idx = self.process_group.rank() * stride
end_idx = (self.process_group.rank() + 1) * stride
self.forward_layer_type(
result, input, adapter_data, self.layer_name, start_idx, end_idx
)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
torch.distributed.all_reduce(a_out, group=self.process_group)
return a_out

View File

@ -0,0 +1,191 @@
import torch
from torch import nn
from typing import Tuple, Optional
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.layers.linear import FastLinear
from text_generation_server.layers.tensor_parallel import (
TensorParallelHead,
TensorParallelColumnLinear,
)
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, medusa_config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
for i in range(get_speculate())
]
)
def forward(self, x):
if not self.heads:
return None
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, medusa_config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(medusa_config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
class MedusaHeadV1(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
from pathlib import Path
from safetensors import safe_open
import json
speculator = config.speculator
path = speculator["path"]
medusa_config = str(Path(path) / "config.json")
for fname in speculator["model_paths"]:
filename = str(Path(path) / fname)
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
medusa = MedusaModel(config, medusa_config, weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MedusaHeadV1(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
if input.shape[0] > 128:
return logits, None
speculative_logits = self.medusa(input)
return logits, speculative_logits
class MedusaHeadV2(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
from pathlib import Path
from safetensors import safe_open
import json
speculator_path = config.speculator["path"]
medusa_config = str(Path(speculator_path) / "config.json")
filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.n_medusa_heads = get_speculate()
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
dim=0,
weights=weights,
bias=True,
)
self.process_group = weights.process_group
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
self.act = torch.nn.SiLU()
self.lm_head = TensorParallelHead.load(config, prefix, weights)
def forward(self, x):
# If we have too many tokens, we skip speculative logits
if x.shape[0] > 128:
logits = self.lm_head(x)
return logits, None
size = x.shape[-1]
block_size = (size + self.world_size - 1) // self.world_size
start = self.rank * block_size
stop = (self.rank + 1) * block_size
x_block = x[:, start:stop]
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res = self.act(self.linear(x)).reshape(
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
)
# Apply all residual medusa heads
output = x[:, start:stop].unsqueeze(-2) + medusa_res
# Gather medusa heads
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
# Stack x and medusa residual x
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
# Compute lm head on x + medusa residual x
logits = self.lm_head(stacked_x)
# Finally, split logits from speculative logits
logits, speculative_logits = torch.split(
logits, [1, self.n_medusa_heads], dim=-2
)
# Squeeze added dimension
logits = logits.squeeze(-2)
return logits, speculative_logits

View File

@ -0,0 +1,282 @@
import torch
import math
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate
class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
elementwise_scale_weight : torch.Tensor
learned scaling term after normalization?
elementwise_shift_bias : torch.Tensor
learned bias term after normalization?
eps : float
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
"""
def __init__(
self,
prefix,
config,
weights,
eps=1e-06,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = weights.get_tensor(f"{prefix}.weight")
self.bias = weights.get_tensor(f"{prefix}.bias")
self.eps = eps
def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
return x
INV_SQRT2 = 2**-0.5
def simple_norm(x: torch.Tensor, eps=1e-06):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
x = xf.type_as(x)
return x * INV_SQRT2
class MLPSpeculatorModelTied(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
self.proj0 = FastLinear.load(
config,
prefix=f"{prefix}.proj.0",
weights=weights,
bias=False,
)
self.proj1 = FastLinear.load(
config,
prefix=f"{prefix}.proj.1",
weights=weights,
bias=False,
)
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
self.ln = MLPSpeculatorLayerNorm(
prefix=f"{prefix}.ln.0",
config=config,
weights=weights,
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
self.activation = nn.GELU()
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
self.inner_dim / 2
)
self.emb.weight *= self.emb_weight
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
):
top_k_tokens_per_head = self.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids.unsqueeze(0)
all_probs = torch.empty(
b, self.n_predict, self.vsize, device=state.device
) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb(ind)
# z = z.mul(self.emb_weight) # b k d
if i == 0:
state = self.proj0(state) * self.state_weight + z
else:
state = self.proj1(state) * self.state_weight + z
state = self.activation(self.ln(state)) # b k d
probs = F.log_softmax(self.head(state), dim=-1) # b k v
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
# Update candidate set with new predictions
# Update distribution set with new logits
all_probs[:, i] = probs.exp()
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(
-1, -1, top_k_tokens_per_head[i], -1
) # b k k' d
state = state.reshape(-1, b, state.size(3)) # b kk' d
ind = preds.view(-1, b) # b kk'
speculative_logits = all_probs
return speculative_logits
class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = nn.ModuleList(
[
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
for i in range(self.n_predict)
]
)
self.proj = [
FastLinear.load(
config,
prefix=f"{prefix}.proj.{i}",
weights=weights,
bias=False,
)
for i in range(self.n_predict)
]
self.head = nn.ModuleList(
[
FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
for i in range(self.n_predict)
]
)
self.ln = nn.ModuleList(
[
MLPSpeculatorLayerNorm(
prefix=f"{prefix}.ln.{i}",
config=config,
weights=weights,
)
for i in range(self.n_predict)
]
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
self.activation = nn.GELU()
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
self.inner_dim / 2
)
self.emb.weight *= self.emb_weight
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
):
top_k_tokens_per_head = self.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids.unsqueeze(0)
all_probs = torch.empty(
b, self.n_predict, self.vsize, device=state.device
) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb[i](ind)
# z = z.mul(self.emb_weight) # b k d
state = self.proj[i](state) * self.state_weight + z
state = self.activation(self.ln[i](state)) # b k d
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
# Update candidate set with new predictions
# Update distribution set with new logits
all_probs[:, i] = probs.exp()
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(
-1, -1, top_k_tokens_per_head[i], -1
) # b k k' d
state = state.reshape(-1, b, state.size(3)) # b kk' d
ind = preds.view(-1, b) # b kk'
speculative_logits = all_probs
return speculative_logits
class MLPSpeculatorHead(nn.Module):
def __init__(self, lm_head, mlp_speculator, scale_input: bool):
super().__init__()
self.lm_head = lm_head
self.mlp_speculator = mlp_speculator
self.scale_input = scale_input
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
if input.shape[0] > 128:
return logits, None
input_ids = logits.argmax(dim=-1)
if self.scale_input:
input = simple_norm(input)
speculative_logits = self.mlp_speculator(input, input_ids)
return logits, speculative_logits
@staticmethod
def load(config, prefix: str, weights):
from pathlib import Path
from safetensors import safe_open
speculator_path = config.speculator["path"]
for fname in config.speculator["model_paths"]:
filename = str(Path(speculator_path) / fname)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
tie_weights = config.speculator_config.get("tie_weights", False)
if tie_weights:
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
else:
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
scale_input = config.speculator_config.get("scale_input", False)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)

View File

@ -0,0 +1,250 @@
from typing import Optional, Protocol, runtime_checkable
import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
Weights,
UnquantizedWeight,
)
from .fused_moe import fused_topk, grouped_topk
# NOTE: we are using a protocol here, because multiple inherance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is whacky.
@runtime_checkable
class MoELayer(Protocol):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ...
def forward(
self, x: torch.Tensor, *, gating_output: torch.Tensor
) -> torch.Tensor: ...
class DenseMoELayer(nn.Module):
"""
Layer for MoE that applies *all* experts to each tokens and then weights
their outputs based on the calculated routing. This layer is much slower
than `SparseMoELayer` and should only be used when no fused kernels are
available (e.g. for unsupported quantizers).
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
):
super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once(
logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer",
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.n_experts = n_experts
self.renormalize = renormalize
self.topk = topk
self.topk_group = topk_group
if "gelu" in hidden_act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh"
if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none"
),
)
elif "silu" in hidden_act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[hidden_act]
self.gate_proj = [
TensorParallelColumnLinear.load(
None,
prefix=f"{prefix}.{i}.{gate_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.up_proj = [
TensorParallelColumnLinear.load(
None,
prefix=f"{prefix}.{i}.{up_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.down_proj = [
TensorParallelRowLinear.load(
None,
prefix=f"{prefix}.{i}.{down_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.process_group = weights.process_group
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gating_output: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk(
x,
gating_output,
self.topk,
renormalize=self.renormalize,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
else:
topk_weights, topk_ids = fused_topk(
x, gating_output, self.topk, self.renormalize
)
topk_weights = topk_weights.to(x.dtype)
weights = torch.zeros(
topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
)
weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
out = torch.zeros_like(x)
for i in range(self.n_experts):
h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
h = self.down_proj[i](h, reduce=False)
out += h * weights[:, i].view(-1, 1)
return out
class SparseMoELayer(nn.Module):
"""
Layer for MoE that uses fused kernels to only apply the active experts
for each token (rather than applying all experts and selecting the
outputs of active experts).
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer
else:
raise ValueError(
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
)
log_once(
logger.info,
"Using MoE layer wih fused gemm",
)
self.moe = cls(
n_expert_group=n_expert_group,
n_experts=n_experts,
prefix=prefix,
renormalize=renormalize,
topk=topk,
topk_group=topk_group,
weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
down_proj_name=down_proj_name,
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return self.moe(x, gating_output=gating_output)
@staticmethod
def is_supported(weights: Weights) -> bool:
return (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader)

View File

@ -0,0 +1,173 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
try:
from .unquantized import fused_moe
except Exception:
fused_moe = None
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)

View File

@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
topk_weights = torch.nn.functional.softmax(
gating_output, dim=1, dtype=torch.float32
)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids

View File

@ -0,0 +1,121 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE
class UnquantizedSparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
for i in range(n_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return self.hpu_fused_moe(x, gating_output, self.topk)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
assert isinstance(weight, UnquantizedWeight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
all_weight[i] = weight.weight
assert all_weight is not None
return all_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, UnquantizedWeight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
all_weight[i] = weight.weight
assert all_weight is not None
return all_weight

View File

@ -0,0 +1,606 @@
import os
import math
import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
"type": os.environ["ROPE_SCALING"],
"factor": float(os.environ["ROPE_FACTOR"]),
}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, inv_freq.device, max_position_embeddings
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
num_tokens = query.shape[0]
head_size = query.shape[-1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
# and expansion of cos/sin tensors via concatenation
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
@classmethod
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if not hasattr(config, "max_position_embeddings") and hasattr(
config, "max_seq_len"
):
# handling for dbrx
config.max_position_embeddings = config.max_seq_len
if rope_scaling is not None:
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "mrope":
mrope_section = rope_scaling["mrope_section"]
if mrope_section is not None:
return RotaryPositionEmbeddingMultimodalSections(
inv_freq,
scaling_factor,
mrope_section,
config.max_position_embeddings,
)
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
elif rope_type == "llama3":
inv_freq = apply_llama3_scaling(
inv_freq,
scaling_factor=rope_scaling["factor"],
low_freq_factor=rope_scaling["low_freq_factor"],
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
elif rope_type == "yarn":
scaling_factor = rope_scaling["factor"]
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
)
elif rope_type in ["su", "longrope"]:
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)
short_inv_freq = 1.0 / (
short_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
)
long_factor = torch.tensor(
rope_scaling["long_factor"], dtype=torch.float32, device=device
)
long_inv_freq = 1.0 / (
long_factor
* base
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)
/ dim
)
)
original_max_position_embeddings = (
config.original_max_position_embeddings
)
max_position_embeddings = config.max_position_embeddings
if max_position_embeddings <= original_max_position_embeddings:
scaling_factor = 1.0
else:
scale = max_position_embeddings / original_max_position_embeddings
scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(original_max_position_embeddings)
)
# if short_mscale and long_mscale are provided we need to scale the freqs
# using the Phi3LongRoPEScaledRotaryEmbedding
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
return SuRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
scaling_factor=scaling_factor,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=config.max_position_embeddings,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
@classmethod
def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
elif rope_scaling["type"] == "yarn":
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor):
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
class SuRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq,
long_inv_freq,
scaling_factor,
original_max_position_embeddings,
max_position_embeddings,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
freqs = torch.cat([short_freqs, long_freqs])
self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq: torch.Tensor,
long_inv_freq: torch.Tensor,
max_position_embeddings: int,
short_mscale: float,
long_mscale: float,
original_max_position_embeddings: int,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.max_position_embeddings = max_position_embeddings
self.short_mscale = short_mscale
self.long_mscale = long_mscale
self.original_max_position_embeddings = original_max_position_embeddings
# cache
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
short_freqs = short_freqs * self.short_mscale
long_freqs = long_freqs * self.long_mscale
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
freqs[: self.original_max_position_embeddings] = short_freqs
freqs[self.original_max_position_embeddings :] = long_freqs
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device
)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
# Find dim range bounds based on rotations
def find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale: float = 1.0, mscale: float = 1.0):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings,
base,
device,
scaling_factor,
*,
extrapolation_factor,
attn_factor,
beta_fast,
beta_slow,
mscale: float,
mscale_all_dim: float,
):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale_all_dim = mscale_all_dim
self.scaling_factor = scaling_factor
self.mscale = float(
get_mscale(self.scaling_factor, mscale)
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings or True:
inv_freq_extrapolation = _create_inv_freq(
self.dim, self.base, self.inv_freq.device
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(
self.beta_fast,
self.beta_slow,
self.dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = (
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
self.inv_freq = inv_freq
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
def apply_llama3_scaling(
freqs: torch.Tensor,
*,
scaling_factor: int,
low_freq_factor: int,
high_freq_factor: int,
original_max_position_embeddings: int,
):
low_freq_wavelen = original_max_position_embeddings / low_freq_factor
high_freq_wavelen = original_max_position_embeddings / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(
self,
inv_freq: torch.Tensor,
scaling_factor: float,
sections: list,
max_position_embeddings,
):
self.sections = sections
self._cos_cached = None
self._sin_cached = None
self.section_indices = (
torch.arange(len(self.sections))
.repeat_interleave(torch.tensor(self.sections))
.view(1, 1, -1)
.to(inv_freq.device)
)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int
):
# always cache the cos/sin for the full sequence length to avoid
# recomputing if the sequence length is smaller than the cached one
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
self._sections = self.section_indices.expand(seqlen, -1, -1)
def get_cos_sin(
self,
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
return cos, sin

View File

@ -0,0 +1,52 @@
import torch
import json
from typing import Tuple, Optional
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead
class SpeculativeHead(torch.nn.Module):
def __init__(self, lm_head, speculator):
super().__init__()
self.head = lm_head
self.speculator = speculator
@staticmethod
def load(config, prefix: str, weights):
speculator = config.speculator
if speculator:
speculator_path = config.speculator["path"]
speculator_config = str(speculator_path / "config.json")
with open(speculator_config, "r") as f:
speculator_config = json.load(f)
config.speculator_config = speculator_config
try:
architecture = speculator_config["architectures"][0]
if architecture == "MLPSpeculatorPreTrainedModel":
speculator = MLPSpeculatorHead.load(config, prefix, weights)
else:
speculator = None
except KeyError:
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
except Exception:
speculator = MedusaHeadV2(config, prefix, weights)
lm_head = None
else:
lm_head = TensorParallelHead.load(config, prefix, weights)
speculator = None
return SpeculativeHead(lm_head, speculator)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.speculator is not None:
return self.speculator(input)
assert self.head is not None
logits = self.head(input)
return logits, None

View File

@ -0,0 +1,244 @@
import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
import habana_frameworks.torch as htorch
class LayerConcat(torch.nn.Module):
"""
Apply multiple layers to the input and concatenate their
outputs.
"""
def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
"""
`dim` is the dimension along which layer outputs are concatenated.
"""
super().__init__()
self.layers = layers
self.dim = dim
def forward(self, x: torch.Tensor):
outputs = [layer(x) for layer in self.layers]
return torch.cat(outputs, self.dim)
class SuperLayer(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
def forward(self, x):
return self.linear.forward(x)
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group, should_gather: bool):
super().__init__(linear)
self.process_group = process_group
self.should_gather = should_gather
@staticmethod
def load(config, prefix: str, weights):
if config.quantize == "exl2":
try:
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except Exception:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
except AssertionError:
# If the vocab size is not divisible by number of shards
# just load the entire thing.
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
else:
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,
should_gather=should_gather,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if not self.should_gather:
return super().forward(input)
world_size = self.process_group.size()
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
out_dim = self.linear.weight.shape[0]
if input.shape[0] == 1:
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)
gather_input = local_out
else:
world_out = input.new_empty(out_dim * world_size, input.shape[0])
gather_input = input.new_empty(out_dim, input.shape[0])
local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out)
htorch.core.mark_step()
torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
if input.shape[0] == 1:
return world_out
return world_out.T
output = super().forward(input)
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
htorch.core.mark_step()
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up(prefix)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
def load_qkv(
cls,
config,
prefix: str,
weights,
bias: bool,
num_heads: int,
num_key_value_heads: int,
):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
bias = None
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_col(prefix)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(prefixes, dim=dim)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias)
return cls(linear)
class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(
get_linear(weight, bias),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(torch.nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
process_group = weights.process_group
world_size = process_group.size()
rank = process_group.rank()
block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = weight.shape[
0
] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group
self.reduce = reduce
"""Additional 0 entry used for masking"""
self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
torch.distributed.all_reduce(out, group=self.process_group)
return out

View File

@ -0,0 +1,994 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum
# Needed to properly setup habana_frameworks
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
load_and_merge_adapters,
AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__ = [
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.flash_mllama import (
FlashMllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_llava_next import (
FlashLlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
FlashGPTJForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
Qwen2_5VLForConditionalGeneration,
Qwen2_5_VLConfig,
Qwen2_5_VLProcessor,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
"multimodal": True,
}
LLAMA = {
"type": "llama",
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
}
DBRX = {
"type": "dbrx",
"name": "Dbrx",
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "mamba",
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
MISTRAL = {
"type": "mistral",
"name": "Mistral",
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
}
MIXTRAL = {
"type": "mixtral",
"name": "Mixtral",
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
}
GPT_BIGCODE = {
"type": "gpt_bigcode",
"name": "Gpt Bigcode",
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
}
PHI = {
"type": "phi",
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
}
FALCON = {
"type": "falcon",
"name": "Falcon",
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
}
STARCODER2 = {
"type": "starcoder2",
"name": "StarCoder 2",
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
}
QWEN2 = {
"type": "qwen2",
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
QWEN2_5_VL = {
"type": "qwen2_5_vl",
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
"url": "https://huggingface.co/facebook/galactica-120b",
}
SANTACODER = {
"type": "santacoder",
"name": "SantaCoder",
"url": "https://huggingface.co/bigcode/santacoder",
}
GPT2 = {
"type": "gpt2",
"name": "Gpt2",
"url": "https://huggingface.co/openai-community/gpt2",
}
GPT_NEOX = {
"type": "gpt_neox",
"name": "Gpt Neox",
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
}
GPTJ = {
"type": "gptj",
"name": "Gptj",
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
if speculate is not None:
set_speculate(speculate)
else:
set_speculate(0)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type", None)
speculator = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError(
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
)
else:
set_speculate(speculate)
else:
set_speculate(speculate_medusa)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(medusa_model_id).exists()
if not is_local:
medusa_config = hf_hub_download(
medusa_model_id, revision=medusa_revision, filename="config.json"
)
hf_hub_download(
medusa_model_id,
revision=medusa_revision,
filename="medusa_lm_head.safetensors",
)
speculator = {
"path": Path(medusa_config).parent,
"model_paths": ["medusa_lm_head.safetensors"],
}
else:
speculator = {
"path": Path(medusa_model_id),
"model_paths": ["medusa_lm_head.safetensors"],
}
method = "medusa"
elif model_type == "mlp_speculator":
mlp_model_id = model_id
mlp_revision = revision
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_mlp = config_dict["n_predict"]
if speculate is not None:
if speculate > speculate_mlp:
raise RuntimeError(
f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
)
else:
set_speculate(speculate)
else:
set_speculate(speculate_mlp)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(mlp_model_id).exists()
extension = ".safetensors"
if not is_local:
mlp_speculator_config = hf_hub_download(
mlp_model_id, revision=mlp_revision, filename="config.json"
)
api = HfApi()
info = api.model_info(mlp_model_id, revision=mlp_revision)
filenames = [
s.rfilename
for s in info.siblings
if s.rfilename.endswith(extension)
and len(s.rfilename.split("/")) == 1
and "arguments" not in s.rfilename
and "args" not in s.rfilename
and "training" not in s.rfilename
]
for filename in filenames:
hf_hub_download(
mlp_model_id,
revision=mlp_revision,
filename=filename,
)
speculator_dir_path = Path(mlp_speculator_config).parent
# if these are downloaded, they get converted to safetensors
filenames.extend(
[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
)
speculator = {
"path": Path(mlp_speculator_config).parent,
"model_paths": filenames,
}
else:
speculator = Path(mlp_model_id)
filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
speculator = {"path": speculator, "model_paths": filenames}
method = "mlp_speculator"
else:
method = "n-gram"
speculate = get_speculate()
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict["model_type"]
kv_cache_dtype = dtype
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif model_type == DEEPSEEK_V3:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif (
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif model_type == GPT2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPTJ:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTJForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPT_NEOX:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif model_type == PHI:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == PHI_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == COHERE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == DBRX:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif (
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
and not sharded
and not config_dict.get("alibi", False)
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
elif model_type == MISTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MIXTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == STARCODER2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,
model_class=FlashMllamaForConditionalGeneration,
batch_class=FlashMllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif model_type == IDEFICS3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif model_type == PALIGEMMA:
return FlashVlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
return BLOOM(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "llava_next":
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "mllama":
return VlmCausalLM(
model_class=MllamaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")
# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
model_id: str,
lora_adapters: Optional[List[AdapterInfo]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
trust_remote_code: bool,
max_input_tokens: int,
adapter_to_index: Dict[str, int],
):
lora_adapter_ids = [adapter.id for adapter in lora_adapters]
model = get_model(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
max_input_tokens,
)
if len(lora_adapters) > 0:
target_to_layer = build_layer_weight_lookup(model.model)
for index, adapter in enumerate(lora_adapters):
# The AdapterParameters object allows for merging multiple adapters into a single adapter.
# At the moment, we only support loading a single adapter into the model, but we keep the
# AdapterParameters object for easier extension in the future.
adapter_parameters = AdapterParameters(
adapter_info=[adapter],
# when merging multiple adapters we can weight them differently
# if this is not set, all adapters will be weighted equally
# see: text_generation_server.utils.merges.strategies for impl
weights=None,
merge_strategy=0,
density=1.0,
majority_sign_method=0,
)
adapter_index = index + 1
adapter_to_index[adapter.id] = adapter_index
logger.info(
f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
)
weight_names = tuple([v[0] for v in target_to_layer.values()])
(
module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_and_merge_adapters(
model.model_id,
adapter_parameters,
adapter_index,
weight_names,
False,
)
unused_weight_names = adapter_weight_names.copy()
adapter_layers = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"qkv_proj",
]
for layer_name in adapter_layers:
nlayers = (
1 if layer_name == "lm_head" else len(model.model.model.layers)
)
adapter_weights = LoraWeights.prepare_weights(
config=adapter_config,
module_map=module_map,
layer_type=layer_name,
unused_weight_names=unused_weight_names,
nlayers=nlayers,
dtype=model.dtype,
world_size=model.world_size,
process_group=model.process_group,
target_to_layer=target_to_layer,
)
if adapter_weights is None:
continue
model.layer_to_adapter_weights[layer_name].add_adapter(
adapter_index, adapter_weights
)
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:
model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
model.loaded_adapters.add(adapter_index)
return model

View File

@ -0,0 +1,52 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
from typing import Optional, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More