mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
v1.4.0 (#1494)
This commit is contained in:
parent
ac580f515b
commit
efd4b97d15
12
.github/workflows/delete_doc_comment.yml
vendored
12
.github/workflows/delete_doc_comment.yml
vendored
@ -1,12 +0,0 @@
|
||||
name: Delete doc comment
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [ closed ]
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
12
Cargo.lock
generated
12
Cargo.lock
generated
@ -1398,9 +1398,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "minijinja"
|
||||
version = "1.0.10"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
|
||||
checksum = "fb5c5e3d2b4c0a6832bd3d571f7c19a7c1c1f05f11a6e85ae1a29f76be5f9455"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
@ -2811,7 +2811,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap",
|
||||
@ -2832,7 +2832,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
@ -2849,7 +2849,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"ctrlc",
|
||||
@ -2865,7 +2865,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
|
@ -9,7 +9,7 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
1294
docs/openapi.json
1294
docs/openapi.json
File diff suppressed because one or more lines are too long
@ -19,6 +19,6 @@ docker run --gpus all \
|
||||
--shm-size 1g \
|
||||
-e HUGGING_FACE_HUB_TOKEN=$token \
|
||||
-p 8080:80 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
|
||||
--model-id $model
|
||||
```
|
||||
|
@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
|
||||
model=tiiuae/falcon-7b-instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
|
||||
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
|
||||
|
||||
```bash
|
||||
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model
|
||||
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
|
||||
```
|
||||
|
||||
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
|
||||
@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
|
||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```bash
|
||||
docker run ghcr.io/huggingface/text-generation-inference:1.3 --help
|
||||
docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
@ -21,7 +21,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == ": {request}\")\n response = self"
|
||||
assert response.generated_text == ': {request}")\n response = self'
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -52,14 +52,12 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_phi, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"{[r.generated_text for r in responses]}"
|
||||
assert responses[0].generated_text == ": {request}\")\n response = self"
|
||||
assert responses[0].generated_text == ': {request}")\n response = self'
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-integration-tests"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
description = "Text Generation Inference integration tests"
|
||||
authors = ["Nicolas Patry <nicolas@huggingface.co>"]
|
||||
|
||||
|
82
server/poetry.lock
generated
82
server/poetry.lock
generated
@ -1812,13 +1812,13 @@ xmp = ["defusedxml"]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.4.0"
|
||||
version = "1.5.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"},
|
||||
{file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"},
|
||||
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
|
||||
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@ -1886,51 +1886,51 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "15.0.2"
|
||||
version = "16.0.0"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"},
|
||||
{file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"},
|
||||
{file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"},
|
||||
{file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"},
|
||||
{file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"},
|
||||
{file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"},
|
||||
{file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:22a1fdb1254e5095d629e29cd1ea98ed04b4bbfd8e42cc670a6b639ccc208b60"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:574a00260a4ed9d118a14770edbd440b848fcae5a3024128be9d0274dbcaf858"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0815d0ddb733b8c1b53a05827a91f1b8bde6240f3b20bf9ba5d650eb9b89cdf"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df0080339387b5d30de31e0a149c0c11a827a10c82f0c67d9afae3981d1aabb7"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edf38cce0bf0dcf726e074159c60516447e4474904c0033f018c1f33d7dac6c5"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:91d28f9a40f1264eab2af7905a4d95320ac2f287891e9c8b0035f264fe3c3a4b"},
|
||||
{file = "pyarrow-16.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:99af421ee451a78884d7faea23816c429e263bd3618b22d38e7992c9ce2a7ad9"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d22d0941e6c7bafddf5f4c0662e46f2075850f1c044bf1a03150dd9e189427ce"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:266ddb7e823f03733c15adc8b5078db2df6980f9aa93d6bb57ece615df4e0ba7"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cc23090224b6594f5a92d26ad47465af47c1d9c079dd4a0061ae39551889efe"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56850a0afe9ef37249d5387355449c0f94d12ff7994af88f16803a26d38f2016"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:705db70d3e2293c2f6f8e84874b5b775f690465798f66e94bb2c07bab0a6bb55"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5448564754c154997bc09e95a44b81b9e31ae918a86c0fcb35c4aa4922756f55"},
|
||||
{file = "pyarrow-16.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:729f7b262aa620c9df8b9967db96c1575e4cfc8c25d078a06968e527b8d6ec05"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fb8065dbc0d051bf2ae2453af0484d99a43135cadabacf0af588a3be81fbbb9b"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ce707d9aa390593ea93218b19d0eadab56390311cb87aad32c9a869b0e958c"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5823275c8addbbb50cd4e6a6839952682a33255b447277e37a6f518d6972f4e1"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab8b9050752b16a8b53fcd9853bf07d8daf19093533e990085168f40c64d978"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42e56557bc7c5c10d3e42c3b32f6cff649a29d637e8f4e8b311d334cc4326730"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a7abdee4a4a7cfa239e2e8d721224c4b34ffe69a0ca7981354fe03c1328789b"},
|
||||
{file = "pyarrow-16.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:ef2f309b68396bcc5a354106741d333494d6a0d3e1951271849787109f0229a6"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ed66e5217b4526fa3585b5e39b0b82f501b88a10d36bd0d2a4d8aa7b5a48e2df"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc8814310486f2a73c661ba8354540f17eef51e1b6dd090b93e3419d3a097b3a"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c2f5e239db7ed43e0ad2baf46a6465f89c824cc703f38ef0fde927d8e0955f7"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f293e92d1db251447cb028ae12f7bc47526e4649c3a9924c8376cab4ad6b98bd"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:dd9334a07b6dc21afe0857aa31842365a62eca664e415a3f9536e3a8bb832c07"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d91073d1e2fef2c121154680e2ba7e35ecf8d4969cc0af1fa6f14a8675858159"},
|
||||
{file = "pyarrow-16.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:71d52561cd7aefd22cf52538f262850b0cc9e4ec50af2aaa601da3a16ef48877"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:b93c9a50b965ee0bf4fef65e53b758a7e8dcc0c2d86cebcc037aaaf1b306ecc0"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d831690844706e374c455fba2fb8cfcb7b797bfe53ceda4b54334316e1ac4fa4"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35692ce8ad0b8c666aa60f83950957096d92f2a9d8d7deda93fb835e6053307e"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dd3151d098e56f16a8389c1247137f9e4c22720b01c6f3aa6dec29a99b74d80"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bd40467bdb3cbaf2044ed7a6f7f251c8f941c8b31275aaaf88e746c4f3ca4a7a"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:00a1dcb22ad4ceb8af87f7bd30cc3354788776c417f493089e0a0af981bc8d80"},
|
||||
{file = "pyarrow-16.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fda9a7cebd1b1d46c97b511f60f73a5b766a6de4c5236f144f41a5d5afec1f35"},
|
||||
{file = "pyarrow-16.0.0.tar.gz", hash = "sha256:59bb1f1edbbf4114c72415f039f1359f1a57d166a331c3229788ccbfbb31689a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.6,<2"
|
||||
numpy = ">=1.16.6"
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow-hotfix"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-server"
|
||||
version = "1.3.4"
|
||||
version = "1.4.0"
|
||||
description = "Text Generation Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
|
@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -3,24 +3,27 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class ProcessGroup:
|
||||
def __init__(self, rank: int, world_size: int):
|
||||
self._rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
def size(self)->int:
|
||||
def size(self) -> int:
|
||||
return self.world_size
|
||||
|
||||
def rank(self)->int:
|
||||
def rank(self) -> int:
|
||||
return self._rank
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
|
||||
self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim)
|
||||
self.weight = (
|
||||
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
|
||||
)
|
||||
self.process_group = ProcessGroup(rank, world_size)
|
||||
|
||||
|
||||
def get_partial_sharded(self, name:str, dim: int):
|
||||
def get_partial_sharded(self, name: str, dim: int):
|
||||
assert dim == 0
|
||||
|
||||
rank = self.process_group.rank()
|
||||
@ -35,10 +38,11 @@ class Weights:
|
||||
def get_shape(self, name: str):
|
||||
return self.weight.shape
|
||||
|
||||
|
||||
def test_weight_hub_files_offline_error():
|
||||
|
||||
vocab_size= 17
|
||||
weights = Weights(rank=0, world_size=1, vocab_size = vocab_size,hidden_dim = 256)
|
||||
vocab_size = 17
|
||||
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
|
||||
embeddings = TensorParallelEmbedding("", weights)
|
||||
|
||||
input_ids = torch.arange(vocab_size)
|
||||
@ -47,18 +51,27 @@ def test_weight_hub_files_offline_error():
|
||||
assert embeddings.max_id == 17
|
||||
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
|
||||
|
||||
weights_0_2 = Weights(rank=0, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
|
||||
weights_1_2 = Weights(rank=1, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
|
||||
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
|
||||
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
|
||||
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
|
||||
assert embeddings_0_2.min_id == 0
|
||||
assert embeddings_0_2.max_id == 9
|
||||
torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float())
|
||||
torch.testing.assert_close(
|
||||
embeddings_0_2.weight,
|
||||
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
|
||||
.view(10, 256)
|
||||
.float(),
|
||||
)
|
||||
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
|
||||
assert embeddings_1_2.min_id == 9
|
||||
assert embeddings_1_2.max_id == 17
|
||||
torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float())
|
||||
torch.testing.assert_close(
|
||||
embeddings_1_2.weight,
|
||||
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
|
||||
.view(9, 256)
|
||||
.float(),
|
||||
)
|
||||
output_tp_0 = embeddings_0_2.forward(input_ids)
|
||||
output_tp_1 = embeddings_1_2.forward(input_ids)
|
||||
|
||||
torch.testing.assert_close(output, output_tp_0 + output_tp_1)
|
||||
|
||||
|
@ -270,7 +270,7 @@ def download_weights(
|
||||
pass
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
elif (Path(model_id) / "adapter_config.json").exists():
|
||||
# Try to load as a local PEFT model
|
||||
try:
|
||||
|
@ -17,6 +17,7 @@ from text_generation_server.utils.layers import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
@ -25,15 +26,15 @@ class PhiConfig(PretrainedConfig):
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
hidden_act="gelu_fast", # llama uses silu
|
||||
layer_norm_eps=1e-05, # rms in llama,
|
||||
hidden_act="gelu_fast", # llama uses silu
|
||||
layer_norm_eps=1e-05, # rms in llama,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
resid_pdrop=0.1, # llama doesn't have this
|
||||
partial_rotary_factor=0.5, # important difference between llama and phi
|
||||
resid_pdrop=0.1, # llama doesn't have this
|
||||
partial_rotary_factor=0.5, # important difference between llama and phi
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -55,6 +56,7 @@ class PhiConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
@ -68,6 +70,7 @@ def load_attention(config, prefix, weights):
|
||||
bias=True,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
@ -94,6 +97,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
get_linear(weight, bias=True, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashPhiAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -173,8 +177,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
#
|
||||
# Apply partial positional embeddings in place
|
||||
self.rotary_emb(
|
||||
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
|
||||
cos, sin
|
||||
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
|
||||
)
|
||||
|
||||
# Reshape key and value and cache
|
||||
@ -210,7 +213,8 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
@ -256,7 +260,9 @@ class FlashPhiLayer(nn.Module):
|
||||
)
|
||||
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
|
||||
|
||||
@ -287,10 +293,13 @@ class FlashPhiLayer(nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
|
||||
hidden_states = self.resid_dropout(attn_output).add(
|
||||
self.resid_dropout(self.mlp(hidden_states))
|
||||
)
|
||||
|
||||
return hidden_states, res
|
||||
|
||||
|
||||
class FlashPhiModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
@ -361,6 +370,7 @@ class FlashPhiModel(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashPhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
@ -380,7 +390,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
@ -54,9 +54,19 @@ def load_col(config, prefix, weights, bias):
|
||||
bias_h = bias_h[0]
|
||||
bias_block_size = bias_h // bias_size
|
||||
|
||||
bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
|
||||
bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
|
||||
bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
|
||||
bias_q_part = bias_slice_[
|
||||
bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
bias_k_part = bias_slice_[
|
||||
bias_h
|
||||
+ bias_rank * bias_block_size : bias_h
|
||||
+ (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
bias_v_part = bias_slice_[
|
||||
2 * bias_h
|
||||
+ bias_rank * bias_block_size : 2 * bias_h
|
||||
+ (bias_rank + 1) * bias_block_size
|
||||
]
|
||||
|
||||
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
|
||||
if bias.dtype != torch.int32:
|
||||
@ -352,8 +362,12 @@ class MultiheadAttention(nn.Module):
|
||||
hidden_size = config.d_model
|
||||
head_dim = hidden_size // self.n_heads
|
||||
|
||||
self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
|
||||
self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
|
||||
self.q_ln = LPLayerNorm(
|
||||
d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
|
||||
)
|
||||
self.k_ln = LPLayerNorm(
|
||||
self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
|
||||
)
|
||||
if self.attn_impl == "flash":
|
||||
self.attn_fn = flash_attn_fn
|
||||
elif self.attn_impl == "triton":
|
||||
@ -684,7 +698,6 @@ class LPLayerNorm(torch.nn.LayerNorm):
|
||||
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
|
||||
self.normalized_shape = self.weight.shape
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
module_device = x.device
|
||||
downcast_x = _cast_if_autocast_enabled(x)
|
||||
@ -798,7 +811,7 @@ class MPTModel(MPTPreTrainedModel):
|
||||
self.wte = TensorParallelEmbedding("transformer.wte", weights)
|
||||
|
||||
if not self.alibi:
|
||||
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
|
||||
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
|
||||
|
@ -62,14 +62,12 @@ class PhiConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# RotaryEmbedding is a class that implements the rotary embedding.
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
inv_freq = [
|
||||
1.0 / 10000.0 ** (i / dim)
|
||||
for i in range(0, dim, 2)
|
||||
]
|
||||
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
|
||||
inv_freq_len = len(inv_freq)
|
||||
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
|
||||
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
|
||||
@ -131,6 +129,7 @@ class PhiCausalLMHead(nn.Module):
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
|
||||
class PhiMHA(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
@ -172,19 +171,27 @@ class PhiMHA(nn.Module):
|
||||
v = torch.cat([prev_v, v], dim=1)
|
||||
|
||||
past_kv_cache = [k, v]
|
||||
attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
|
||||
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
|
||||
|
||||
if attention_mask is not None:
|
||||
seqlen_k = k.shape[1]
|
||||
seqlen_q = q.shape[1]
|
||||
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
|
||||
1,
|
||||
)
|
||||
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
|
||||
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
|
||||
attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
|
||||
attn_output = (
|
||||
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
|
||||
.transpose(1, 2)
|
||||
.flatten(-2)
|
||||
)
|
||||
return self.out_proj(attn_output), past_kv_cache
|
||||
|
||||
|
||||
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
@ -204,19 +211,22 @@ class PhiMLP(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
self.activation = torch.nn.functional.gelu
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
|
||||
class PhiBlock(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
|
||||
self.layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
|
||||
|
||||
@ -228,11 +238,14 @@ class PhiBlock(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
|
||||
attn_outputs, past_kv_cache = self.mixer(
|
||||
hidden_states, kv_cache, attention_mask
|
||||
)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
out = attn_outputs + feed_forward_hidden_states + residual
|
||||
return out, past_kv_cache
|
||||
|
||||
|
||||
# PhiModel implements the embedding layer and the transformer blocks.
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
@ -241,9 +254,12 @@ class PhiModel(nn.Module):
|
||||
self.tp_world_size = weights.process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="transformer.embd.wte", weights=weights
|
||||
)
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
|
||||
[
|
||||
PhiBlock(f"transformer.h.{layer_id}", config, weights)
|
||||
for layer_id in range(config.n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -258,14 +274,19 @@ class PhiModel(nn.Module):
|
||||
seq_len = hidden_states.shape[1]
|
||||
mask = None if seq_len <= 1 else attention_mask
|
||||
|
||||
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||
past_key_values = (
|
||||
[None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||
)
|
||||
|
||||
for index, block in enumerate(self.blocks):
|
||||
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
|
||||
hidden_states, new_key_values = block(
|
||||
hidden_states, past_key_values[index], mask
|
||||
)
|
||||
past_key_values[index] = new_key_values
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
|
||||
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
||||
class PhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
@ -290,12 +311,15 @@ class PhiForCausalLM(torch.nn.Module):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = nn.CrossEntropyLoss()(
|
||||
logits[:, :-1].view(-1, logits.size(-1)),
|
||||
labels[:, 1:].view(-1)
|
||||
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
|
||||
return (
|
||||
((loss,) + (logits,) + model_output[1:])
|
||||
if loss is not None
|
||||
else (logits,) + model_output[1:]
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
@ -304,5 +328,3 @@ class PhiForCausalLM(torch.nn.Module):
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -73,11 +73,11 @@ class FlashLlama(FlashCausalLM):
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
@ -88,7 +88,7 @@ class FlashLlama(FlashCausalLM):
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
|
@ -63,11 +63,11 @@ class FlashPhi(FlashCausalLM):
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
@ -78,7 +78,7 @@ class FlashPhi(FlashCausalLM):
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
|
@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import (
|
||||
PhiConfig,
|
||||
PhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class Phi(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
@ -60,4 +64,3 @@ class Phi(CausalLM):
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module):
|
||||
block_size = (num_embeddings + world_size - 1) // world_size
|
||||
self.min_id = rank * block_size
|
||||
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||
self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size.
|
||||
self.null_idx = weight.shape[
|
||||
0
|
||||
] # Usually block_size, might be less in non even vocab_size.
|
||||
self.process_group = weights.process_group
|
||||
self.reduce = reduce
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user