From fa8a8e05afa435bd39308974925a1098239dfcb4 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 21 Feb 2024 11:05:32 +0100 Subject: [PATCH 1/4] fix(router): fix openapi and add jsonschema validation (#1578) --- .github/workflows/tests.yaml | 4 + Cargo.lock | 157 ++++++++++++++++++ docs/openapi.json | 51 ++++++ .../test_flash_llama_grammar_json.json | 122 +++++++------- integration-tests/models/test_flash_awq.py | 3 - .../models/test_flash_awq_sharded.py | 2 - integration-tests/models/test_flash_medusa.py | 3 - .../models/test_flash_mistral.py | 3 - integration-tests/models/test_flash_phi.py | 3 - .../models/test_flash_starcoder_gptq.py | 3 - .../models/test_grammar_llama.py | 7 +- integration-tests/models/test_mamba.py | 3 - router/Cargo.toml | 1 + router/src/lib.rs | 37 +---- router/src/server.rs | 1 + router/src/validation.rs | 29 +++- server/text_generation_server/utils/tokens.py | 5 +- 17 files changed, 312 insertions(+), 122 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5b19eb8c..29ff6d45 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -41,6 +41,10 @@ jobs: components: rustfmt, clippy - name: Install Protoc uses: arduino/setup-protoc@v1 + - name: Clean unused files + run: | + sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android + sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET - name: Install sccache run: | curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache diff --git a/Cargo.lock b/Cargo.lock index f8151731..ccccdf3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" dependencies = [ "cfg-if", + "getrandom", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -265,6 +267,21 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -716,6 +733,16 @@ dependencies = [ "cc", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -780,6 +807,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "futures" version = "0.3.30" @@ -895,8 +932,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1181,6 +1220,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "iso8601" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153" +dependencies = [ + "nom", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1223,6 +1271,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonschema" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" +dependencies = [ + "ahash", + "anyhow", + "base64 0.21.7", + "bytecount", + "clap", + "fancy-regex", + "fraction", + "getrandom", + "iso8601", + "itoa", + "memchr", + "num-cmp", + "once_cell", + "parking_lot", + "percent-encoding", + "regex", + "reqwest", + "serde", + "serde_json", + "time", + "url", + "uuid", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1574,12 +1652,84 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -2874,6 +3024,7 @@ dependencies = [ "futures-util", "hf-hub", "init-tracing-opentelemetry", + "jsonschema", "metrics", "metrics-exporter-prometheus", "minijinja", @@ -3530,6 +3681,12 @@ dependencies = [ "zip", ] 
+[[package]] +name = "uuid" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" + [[package]] name = "valuable" version = "0.1.0" diff --git a/docs/openapi.json b/docs/openapi.json index d72d32e9..fad01aec 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1022,6 +1022,57 @@ } } }, + "GrammarType": { + "oneOf": [ + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "json" + ] + }, + "value": { + "type": "string", + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", + "example": { + "properties": { + "location": { + "type": "string" + } + } + } + } + } + }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "regex" + ] + }, + "value": { + "type": "string" + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, "Info": { "type": "object", "required": [ diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json index 7b12b158..d7fb620d 100644 --- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -135,129 +135,129 @@ "special": false, "text": "\",\"" }, - { - "id": 4230, - "logprob": -0.020492554, - "special": false, - "text": "last" - }, - { - "id": 1170, - "logprob": -0.0013818741, - "special": false, - "text": "Name" - }, - { - "id": 4710, - "logprob": -0.0067749023, - "special": false, - "text": "\":\"" - }, - { - "id": 29950, - "logprob": -0.11578369, - "special": false, - "text": "H" - }, - { - "id": 14339, - "logprob": -0.004131317, - "special": false, - "text": "olt" - }, - { - "id": 29920, - "logprob": -0.0033359528, - "special": false, - "text": "z" - }, - { - "id": 3284, - "logprob": -0.20471191, - "special": false, - "text": "\",\"" - }, { "id": 29882, - "logprob": -0.0069274902, + "logprob": -0.08862305, "special": false, "text": "h" }, { - "id": 20838, - "logprob": -0.19580078, + "id": 711, + "logprob": -0.66259766, "special": false, - "text": "obb" + "text": "ob" }, { - "id": 29891, - "logprob": -2.2649765e-06, + "id": 1609, + "logprob": -5.51939e-05, "special": false, - "text": "y" + "text": "by" }, { "id": 4710, - "logprob": -0.32080078, + "logprob": -0.23120117, "special": false, "text": "\":\"" }, { "id": 29911, - "logprob": -2.1035156, + "logprob": -2.3730469, "special": false, "text": "T" }, { "id": 11003, - "logprob": -0.020767212, + "logprob": -0.032104492, "special": false, "text": "rees" }, { "id": 3284, - "logprob": -0.6010742, + "logprob": -0.22021484, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.06726074, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.003501892, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0045661926, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.12512207, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.009552002, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": 
-0.00042438507, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.11651611, "special": false, "text": "\",\"" }, { "id": 29876, - "logprob": -0.57666016, + "logprob": -0.29736328, "special": false, "text": "n" }, { "id": 398, - "logprob": -0.0061073303, + "logprob": -0.003030777, "special": false, "text": "um" }, { "id": 29907, - "logprob": -0.45703125, + "logprob": -0.3774414, "special": false, "text": "C" }, { "id": 1446, - "logprob": -0.0002872944, + "logprob": -0.0003130436, "special": false, "text": "ats" }, { "id": 1115, - "logprob": -0.0021018982, + "logprob": -0.0021514893, "special": false, "text": "\":" }, { "id": 29906, - "logprob": -0.08996582, + "logprob": -0.071899414, "special": false, "text": "2" }, { "id": 29913, - "logprob": -0.021697998, + "logprob": -0.018997192, "special": false, "text": "}" }, @@ -270,5 +270,5 @@ ], "top_tokens": null }, - "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" } diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index 62a95f48..ead918c3 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", @@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index 1c687fc9..a83614ac 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot ): diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index a0ce0570..e0cc1039 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_simple(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep 
Learning?", max_new_tokens=10, decoder_input_details=True @@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_all_params(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep Learning?", @@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): responses = await generate_load( flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py index ace3328b..52b51928 100644 --- a/integration-tests/models/test_flash_mistral.py +++ b/integration-tests/models/test_flash_mistral.py @@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral(flash_mistral, response_snapshot): response = await flash_mistral.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral_all_params(flash_mistral, response_snapshot): response = await flash_mistral.generate( "Test request", @@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot): responses = await generate_load( flash_mistral, "Test request", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py index 0987b3a1..9d6ca566 100644 --- a/integration-tests/models/test_flash_phi.py +++ b/integration-tests/models/test_flash_phi.py @@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi(flash_phi, response_snapshot): response = await flash_phi.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi_all_params(flash_phi, response_snapshot): response = await flash_phi.generate( "Test request", @@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index 5e448d55..329158b7 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot): response = await flash_starcoder_gptq.generate( "def geometric_mean(L: List[float]):", @@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq_default_params( flash_starcoder_gptq, generous_response_snapshot ): @@ 
-43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params( @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq_load( flash_starcoder_gptq, generate_load, generous_response_snapshot ): diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index f068496c..585d0656 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "Whats Googles DNS", @@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot) @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "info: david holtz like trees and has two cats. ", @@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): assert response.details.generated_tokens == 30 assert ( response.generated_text - == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}' + == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' ) assert response == response_snapshot @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_load( flash_llama_grammar, generate_load, response_snapshot ): @@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load( # this is the same as the above test, but only fires off a single request # this is only to ensure that the parallel and single inference produce the same result @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_single_load_instance( flash_llama_grammar, generate_load, response_snapshot ): diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index 5ec2ec31..bf3701b4 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( "What is Deep Learning?", max_new_tokens=10 @@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( "blue, red, yellow, ", @@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba_load( fused_kernel_mamba, generate_load, generous_response_snapshot ): diff --git a/router/Cargo.toml b/router/Cargo.toml index 7d6dc017..170debda 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,6 +22,7 @@ text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { version = "0.3.0", features 
= ["tokio"] } +jsonschema = { version = "0.17.1", features = ["draft202012"] } metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.12.1", features = [] } nohash-hasher = "0.2.0" diff --git a/router/src/lib.rs b/router/src/lib.rs index b7285e65..c6928a5a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -64,39 +64,16 @@ impl HubTokenizerConfig { } } -mod json_object_or_string_to_string { - use serde::{Deserialize, Deserializer}; - use serde_json::Value; - - // A custom deserializer that treats both strings and objects as strings. - // This provides flexibility with input formats for the 'grammar' field. - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - match value { - Value::String(s) => Ok(s), - // Safely handle serialization and return an error if it fails - Value::Object(o) => { - serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string())) - } - _ => Err(serde::de::Error::custom( - "expected string or object for grammar", - )), - } - } -} - #[derive(Clone, Debug, Deserialize, ToSchema)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { - #[serde( - rename = "json", - deserialize_with = "json_object_or_string_to_string::deserialize" - )] - Json(String), + /// A string that represents a [JSON Schema](https://json-schema.org/). + /// + /// JSON Schema is a declarative language that allows to annotate JSON documents + /// with types and descriptions. + #[serde(rename = "json")] + #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] + Json(serde_json::Value), #[serde(rename = "regex")] Regex(String), } diff --git a/router/src/server.rs b/router/src/server.rs index 054ba5a2..ebde7133 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -893,6 +893,7 @@ pub async fn run( Info, CompatGenerateRequest, GenerateRequest, + GrammarType, ChatRequest, Message, ChatCompletionChoice, diff --git a/router/src/validation.rs b/router/src/validation.rs index bf85b12f..204dbf92 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,7 +1,9 @@ /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; +use serde_json::Value; use text_generation_client::{ GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, }; @@ -313,8 +315,29 @@ impl Validation { return Err(ValidationError::Grammar); } match grammar { - // currently both are handled the same way since compilation is done in Python - GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()), + GrammarType::Json(json) => { + let json = match json { + // if value is a string, we need to parse it again to make sure its + // a valid json + Value::String(s) => serde_json::from_str(&s) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string())), + Value::Object(_) => Ok(json), + _ => Err(ValidationError::Grammar), + }?; + + // Check if the json is a valid JSONSchema + JSONSchema::options() + .with_draft(Draft::Draft202012) + .compile(&json) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + + ( + // Serialize json to string + serde_json::to_string(&json) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, + ProtoGrammarType::Json.into(), + ) + } GrammarType::Regex(regex) => (regex, 
ProtoGrammarType::Regex.into()), } } @@ -486,6 +509,8 @@ pub enum ValidationError { Tokenizer(String), #[error("grammar is not supported")] Grammar, + #[error("grammar is not valid: {0}")] + InvalidGrammar(String), } #[cfg(test)] diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 72c6c21c..32789850 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser: scores = scores.view(B, S, -1) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) - mask = torch.full((scores.shape[-1],), -math.inf, device=self.device) for j in range(S): _scores = scores[:, j] @@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser: _scores = self.repetition_processor(input_ids, _scores) if self.frequency_processor is not None: _scores = self.frequency_processor(input_ids, _scores) - for warper in self.warpers: - _scores = warper(input_ids, _scores) if self.grammar_processor is not None: _scores = self.grammar_processor(_scores, self.fsm_grammar_states) + for warper in self.warpers: + _scores = warper(input_ids, _scores) _next_ids = self.choice(_scores) scores[:, j] = _scores next_ids[:, j] = _next_ids From c86f58d37cfff019fea878d8f2bf9b4da26c1d8e Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 21 Feb 2024 14:15:22 +0100 Subject: [PATCH 2/4] feat: add support for Gemma (#1583) --- integration-tests/conftest.py | 3 + .../test_flash_gemma/test_flash_gemma.json | 89 +++ .../test_flash_gemma_all_params.json | 89 +++ .../test_flash_gemma_load.json | 358 ++++++++++ integration-tests/models/test_flash_gemma.py | 61 ++ .../text_generation_server/models/__init__.py | 25 + .../custom_modeling/flash_gemma_modeling.py | 609 ++++++++++++++++++ .../models/flash_gemma.py | 104 +++ 8 files changed, 1338 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json create mode 100644 integration-tests/models/test_flash_gemma.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py create mode 100644 server/text_generation_server/models/flash_gemma.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e0228894..80457bc2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension): exclude=None, matcher=None, ): + if isinstance(data, Response): + data = data.dict() + if isinstance(data, List): data = [d.dict() for d in data] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json new file mode 100644 index 00000000..80f0d053 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": 
false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.8671875, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.4375, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8203125, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23242188, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.08544922, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.9375, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.671875, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.40429688, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.1875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json new file mode 100644 index 00000000..8253dc96 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 7539, + "logprob": -0.73046875, + "special": false, + "text": " forms" + }, + { + "id": 708, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 671, + "logprob": -1.703125, + "special": false, + "text": " an" + }, + { + "id": 8727, + "logprob": 0.0, + "special": false, + "text": " essential" + }, + { + "id": 1702, + "logprob": 0.0, + "special": false, + "text": " part" + }, + { + "id": 576, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 11859, + "logprob": -1.6953125, + "special": false, + "text": " lab" + }, + { + "id": 2185, + "logprob": -1.3125, + "special": false, + "text": " process" + }, + { + "id": 578, + "logprob": -1.5, + "special": false, + "text": " and" + } + ], + "top_tokens": null + }, + "generated_text": "Test request forms are an essential part of the lab process and" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json new file mode 100644 index 00000000..e69ee25d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, 
+ "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": 
-2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + } +] diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py new file mode 100644 index 00000000..2822b5e2 --- /dev/null +++ b/integration-tests/models/test_flash_gemma.py @@ -0,0 +1,61 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_handle(launcher): + with launcher("gg-hf/gemma-2b", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma(flash_gemma_handle): + await flash_gemma_handle.health(300) + return flash_gemma_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_all_params(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): + responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index da7d8416..abab3486 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -52,6 +52,9 @@ try: from text_generation_server.models.flash_llama import ( FlashLlama, ) + from text_generation_server.models.flash_gemma import ( + FlashGemma, + ) from text_generation_server.models.flash_santacoder import ( FlashSantacoderSharded, ) @@ -312,6 +315,28 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + if model_type == "gemma": + if FLASH_ATTENTION: + return FlashGemma( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + use_medusa=use_medusa, + ) + elif sharded: + raise 
NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate") + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]: if sharded: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py new file mode 100644 index 00000000..4a08bc2a --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -0,0 +1,609 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +import os +from shutil import copyfile + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple +from tokenizers import processors +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import logging + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, + FastRMSNorm, +) + +GemmaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = { + "vocab_file": "tokenizer.model", + "tokenizer_file": "tokenizer.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. 
If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class GemmaTokenizerFast(PreTrainedTokenizerFast): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + slow_tokenizer_class = GemmaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + use_default_system_prompt=use_default_system_prompt, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.use_default_system_prompt = use_default_system_prompt + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError("add_bos_token = True but bos_token = None") + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError("add_eos_token = True but eos_token = None") + + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary( + self, save_directory: str, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @property + def default_chat_template(self): + raise NotImplementedError + + # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + +class GemmaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class GemmaFastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + 1 + return cls(weight, eps) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) 
== [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashGemmaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GemmaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 
2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashGemmaLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashGemmaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashGemmaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + embed_norm = config.hidden_size**0.5 + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.layers = nn.ModuleList( + [ + FlashGemmaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = GemmaFastRMSNorm.load( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemmaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashGemmaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: 
List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py new file mode 100644 index 00000000..220b3992 --- /dev/null +++ b/server/text_generation_server/models/flash_gemma.py @@ -0,0 +1,104 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + GemmaTokenizerFast, + FlashGemmaForCausalLM, + GemmaConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashGemma(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + use_medusa: Optional[str] = None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError("FlashGemma is only available on GPU") + + tokenizer = GemmaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + use_fast=True, + from_slow=False, + ) + + config = GemmaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = FlashGemmaForCausalLM(config, weights) + if use_medusa: + from text_generation_server.utils.medusa import MedusaModel + from huggingface_hub import hf_hub_download + import json + import os + from pathlib import Path + + is_local_model = ( + Path(use_medusa).exists() and Path(use_medusa).is_dir() + ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None + + if not is_local_model: + medusa_config = hf_hub_download( + use_medusa, revision=revision, filename="config.json" + ) + medusa_head = hf_hub_download( + use_medusa, revision=revision, filename="medusa_lm_head.pt" + ) + else: + medusa_config = str(Path(use_medusa) / "config.json") + medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") + + with open(medusa_config, "r") as f: + config = json.load(f) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) + lm_head = model.lm_head + model.lm_head = MedusaModel(config, weights, lm_head) + + torch.distributed.barrier(group=self.process_group) + super(FlashGemma, self).__init__( + model=model, + tokenizer=tokenizer, + 
num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) From 9c1cb81cd8fc01c8f736c554f316ba42b0695717 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 21 Feb 2024 14:50:57 +0100 Subject: [PATCH 3/4] v1.4.2 (#1585) --- Cargo.lock | 138 +++++++++++++++---------------- Cargo.toml | 2 +- docs/openapi.json | 2 +- integration-tests/pyproject.toml | 2 +- router/src/main.rs | 2 +- server/pyproject.toml | 2 +- 6 files changed, 73 insertions(+), 75 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ccccdf3c..f52e8038 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" +checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" dependencies = [ "cfg-if", "getrandom", @@ -42,9 +42,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.11" +version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540" dependencies = [ "anstyle", "anstyle-parse", @@ -90,9 +90,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "arc-swap" @@ -130,7 +130,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -141,7 +141,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -305,9 +305,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.0" +version = "3.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" +checksum = "c764d619ca78fccbf3069b37bd7af92577f044bb15236036662d79b6559f25b7" [[package]] name = "bytecount" @@ -367,12 +367,9 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] +checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" [[package]] name = "cfg-if" @@ -382,9 +379,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.0" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f" +checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" dependencies = [ "clap_builder", "clap_derive", @@ -392,9 +389,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.0" 
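The FlashGemma modeling code added above carries two Gemma-specific details inside TGI's flash-attention scaffolding: the MLP computes the gate and up projections in one fused matmul and splits the result with a view, and the token embeddings are scaled by sqrt(hidden_size) once at load time. Below is a minimal, self-contained sketch of just those two ideas, using plain nn.Linear/nn.Embedding stand-ins for the tensor-parallel layers and assuming a tanh-approximated GELU (in the real model the activation comes from config.hidden_act); class and variable names here are illustrative only.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ToyGemmaMLP(nn.Module):
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.intermediate_size = intermediate_size
            # One matmul produces both the gate and the up projection (fused weights).
            self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            gate_up = self.gate_up_proj(hidden_states)
            # Split the fused output into [gate, up], mirroring view(-1, 2, intermediate_size).
            gate_up = gate_up.view(-1, 2, self.intermediate_size)
            gate, up = gate_up[:, 0], gate_up[:, 1]
            return self.down_proj(F.gelu(gate, approximate="tanh") * up)


    if __name__ == "__main__":
        hidden_size, intermediate_size, vocab = 16, 32, 100
        embed = nn.Embedding(vocab, hidden_size)
        # Gemma scales the embedding weights by sqrt(hidden_size) at load time.
        with torch.no_grad():
            embed.weight *= hidden_size**0.5
        mlp = ToyGemmaMLP(hidden_size, intermediate_size)
        tokens = torch.randint(0, vocab, (4,))
        print(mlp(embed(tokens)).shape)  # torch.Size([4, 16])
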
+version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99" +checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" dependencies = [ "anstream", "anstyle", @@ -411,7 +408,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -873,7 +870,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1441,7 +1438,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1529,7 +1526,7 @@ checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1804,9 +1801,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.63" +version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.4.2", "cfg-if", @@ -1825,7 +1822,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1836,9 +1833,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.99" +version = "0.9.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" +checksum = "ae94056a791d0e1217d18b6cbdccb02c61e3054fc69893607f4067e3bb0b1fd1" dependencies = [ "cc", "libc", @@ -2041,7 +2038,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2087,7 +2084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2160,7 +2157,7 @@ dependencies = [ "prost 0.12.3", "prost-types", "regex", - "syn 2.0.49", + "syn 2.0.50", "tempfile", "which", ] @@ -2188,7 +2185,7 @@ dependencies = [ "itertools 0.11.0", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2439,16 +2436,17 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", + "cfg-if", "getrandom", "libc", "spin 0.9.8", "untrusted 0.9.0", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2472,7 +2470,7 @@ dependencies = [ "quote", "rust-embed-utils", "shellexpand", - "syn 2.0.49", + "syn 2.0.50", "walkdir", ] @@ -2533,7 +2531,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" dependencies = [ "log", - "ring 0.17.7", + "ring 0.17.8", "rustls-pki-types", "rustls-webpki", "subtle", @@ -2561,7 +2559,7 @@ version = 
"0.102.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "rustls-pki-types", "untrusted 0.9.0", ] @@ -2574,9 +2572,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -2608,7 +2606,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -2637,38 +2635,38 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" dependencies = [ "serde", ] [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -2851,7 +2849,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2873,9 +2871,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.49" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", @@ -2961,7 +2959,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "1.4.1" +version = "1.4.2" dependencies = [ "average", "clap", @@ -2982,7 +2980,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "1.4.1" +version = "1.4.2" dependencies = [ "futures", "grpc-metadata", @@ -2998,7 +2996,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "1.4.1" +version = "1.4.2" dependencies = [ "clap", "ctrlc", @@ -3014,7 +3012,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "1.4.1" +version = "1.4.2" dependencies = [ "async-stream", "axum", @@ -3067,14 +3065,14 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "thread_local" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" dependencies = [ "cfg-if", "once_cell", @@ -3233,7 +3231,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3348,7 +3346,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3421,7 +3419,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3551,9 +3549,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] @@ -3662,7 +3660,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3767,7 +3765,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-shared", ] @@ -3801,7 +3799,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3828,7 +3826,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -4128,7 +4126,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ce8148d1..02acaf5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ resolver = "2" [workspace.package] -version = "1.4.1" +version = "1.4.2" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/docs/openapi.json b/docs/openapi.json index fad01aec..89ffa0d7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "1.4.1" + "version": "1.4.2" }, "paths": { "/": { diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml index e42c0d88..b0b4a07c 100644 --- a/integration-tests/pyproject.toml +++ b/integration-tests/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-integration-tests" -version = "1.4.1" +version = "1.4.2" description = "Text Generation Inference integration tests" authors = ["Nicolas Patry "] diff --git a/router/src/main.rs b/router/src/main.rs index 60a66a41..6a736b12 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -270,7 +270,7 @@ async fn main() -> Result<(), RouterError> { let compat_return_full_text = match &model_info.pipeline_tag { None => { tracing::warn!("no 
pipeline tag found for model {tokenizer_name}"); - false + true } Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation", }; diff --git a/server/pyproject.toml b/server/pyproject.toml index ad67745c..ba19c031 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "1.4.1" +version = "1.4.2" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] From 010508cec8e60e3995f4d0a0f98f30c499c28b59 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 21 Feb 2024 15:30:45 +0100 Subject: [PATCH 4/4] fix: fix openapi schema (#1586) --- docs/openapi.json | 117 +++++++++++++++++++++++++++++++++++++++---- router/src/lib.rs | 2 +- router/src/server.rs | 12 +++-- 3 files changed, 118 insertions(+), 13 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 89ffa0d7..f8e52a8d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -637,6 +637,35 @@ } } }, + "ChatCompletionComplete": { + "type": "object", + "required": [ + "index", + "message", + "finish_reason" + ], + "properties": { + "finish_reason": { + "type": "string" + }, + "index": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "logprobs": { + "allOf": [ + { + "$ref": "#/components/schemas/ChatCompletionLogprobs" + } + ], + "nullable": true + }, + "message": { + "$ref": "#/components/schemas/Message" + } + } + }, "ChatCompletionDelta": { "type": "object", "required": [ @@ -654,6 +683,59 @@ } } }, + "ChatCompletionLogprob": { + "type": "object", + "required": [ + "token", + "logprob", + "top_logprobs" + ], + "properties": { + "logprob": { + "type": "number", + "format": "float" + }, + "token": { + "type": "string" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionTopLogprob" + } + } + } + }, + "ChatCompletionLogprobs": { + "type": "object", + "required": [ + "content" + ], + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ChatCompletionLogprob" + } + } + } + }, + "ChatCompletionTopLogprob": { + "type": "object", + "required": [ + "token", + "logprob" + ], + "properties": { + "logprob": { + "type": "number", + "format": "float" + }, + "token": { + "type": "string" + } + } + }, "ChatRequest": { "type": "object", "required": [ @@ -1038,15 +1120,7 @@ ] }, "value": { - "type": "string", - "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", - "example": { - "properties": { - "location": { - "type": "string" - } - } - } + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." 
} } }, @@ -1362,6 +1436,31 @@ "items": { "$ref": "#/components/schemas/SimpleToken" } + }, + "Usage": { + "type": "object", + "required": [ + "prompt_tokens", + "completion_tokens", + "total_tokens" + ], + "properties": { + "completion_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "prompt_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "total_tokens": { + "type": "integer", + "format": "int32", + "minimum": 0 + } + } } } }, diff --git a/router/src/lib.rs b/router/src/lib.rs index c6928a5a..113e2642 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -347,7 +347,7 @@ pub(crate) struct ChatCompletionTopLogprob { logprob: f32, } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, diff --git a/router/src/server.rs b/router/src/server.rs index ebde7133..9fdd66cc 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -3,11 +3,12 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ - BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, - ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, + BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, + ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, + ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, - StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse, + StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -896,9 +897,13 @@ pub async fn run( GrammarType, ChatRequest, Message, + ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, ChatCompletionChunk, + ChatCompletionLogprob, + ChatCompletionLogprobs, + ChatCompletionTopLogprob, ChatCompletion, GenerateParameters, PrefillToken, @@ -913,6 +918,7 @@ pub async fn run( StreamDetails, ErrorResponse, GrammarType, + Usage, ) ), tags(
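For reference, the ChatCompletionComplete, ChatCompletionLogprobs and Usage components registered above compose as follows in a chat completion response. This is a hypothetical fragment: only the field names come from the schemas in this patch, the concrete values are invented, and total_tokens is conventionally the sum of the prompt and completion counters.

    # Illustrative only; values are made up, field names follow the OpenAPI components above.
    choice = {
        "index": 0,
        "finish_reason": "length",
        "message": {"role": "assistant", "content": "Hello!"},
        "logprobs": {
            "content": [
                {
                    "token": "Hello",
                    "logprob": -0.12,
                    "top_logprobs": [{"token": "Hi", "logprob": -2.3}],
                }
            ]
        },
    }
    usage = {"prompt_tokens": 9, "completion_tokens": 1, "total_tokens": 10}
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
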