From 55acb86f42fae5dc7260a938cb31a33b86dcced3 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 15 Feb 2024 04:28:10 -0500 Subject: [PATCH] Outlines guided generation (#1539) This WIP PR starts to add grammar support via outlines, currently this PR supports very simple regex grammars and does not optimize for precompiling or caching grammar fsm's. todo: - [X] add simple outlines guidance to `NextTokenChooser` - [X] update protos for grammar - [X] update generation params API - [X] constrain simple grammar - [ ] support parsing more complex grammar into fsm - [ ] support all outline support grammar types - [ ] explore optimizations to avoid recompiling grammars guided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data-raw '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6, "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+" } }' | jq ``` response ```json { "generated_text": "david@example.com" } ``` unguided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6 } }' | jq ``` response ```json { "generated_text": " email = 'david" } ``` --- benchmark/src/lib.rs | 4 +- clients/python/text_generation/client.py | 10 + clients/python/text_generation/types.py | 28 +- docs/source/basic_tutorials/launcher.md | 8 + integration-tests/conftest.py | 30 +- .../test_flash_llama_grammar.json | 89 ++++ .../test_flash_llama_grammar_json.json | 274 ++++++++++ .../test_flash_llama_grammar_load.json | 478 +++++++++++++++++ .../test_flash_llama_grammar_regex.json | 109 ++++ ...sh_llama_grammar_single_load_instance.json | 73 +++ .../models/test_grammar_llama.py | 151 ++++++ launcher/src/main.rs | 10 + proto/generate.proto | 10 + router/client/src/client.rs | 4 + router/client/src/lib.rs | 4 +- router/src/health.rs | 3 + router/src/lib.rs | 40 ++ router/src/main.rs | 4 + router/src/queue.rs | 8 +- router/src/server.rs | 3 + router/src/validation.rs | 46 +- server/poetry.lock | 492 +++++++++++++++++- server/pyproject.toml | 1 + .../models/causal_lm.py | 5 +- .../models/flash_causal_lm.py | 24 +- .../models/flash_mistral.py | 2 +- .../models/galactica.py | 2 +- .../models/idefics_causal_lm.py | 5 +- server/text_generation_server/models/mamba.py | 5 +- .../models/seq2seq_lm.py | 5 +- .../utils/logits_process.py | 142 ++++- server/text_generation_server/utils/tokens.py | 92 +++- 32 files changed, 2128 insertions(+), 33 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json create mode 100644 integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json create mode 100644 integration-tests/models/test_grammar_llama.py diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 6deae48d..638c6514 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -8,7 +8,7 @@ use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; -use text_generation_client::{NextTokenChooserParameters, ShardedClient}; +use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; @@ -45,6 +45,8 @@ pub async fn run( repetition_penalty: repetition_penalty.unwrap_or(1.0), frequency_penalty: frequency_penalty.unwrap_or(0.0), watermark, + grammar: String::new(), + grammar_type: GrammarType::None as i32, }; // Initialize terminal properties diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 0bf80f8c..bbccbf1d 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -10,6 +10,7 @@ from text_generation.types import ( Response, Request, Parameters, + Grammar, ) from text_generation.errors import parse_error @@ -76,6 +77,7 @@ class Client: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text @@ -138,6 +140,7 @@ class Client: watermark=watermark, decoder_input_details=decoder_input_details, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -169,6 +172,7 @@ class Client: typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -227,6 +231,7 @@ class Client: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -326,6 +331,7 @@ class AsyncClient: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -370,6 +376,7 @@ class AsyncClient: Returns: Response: generated response """ + # Validate parameters parameters = Parameters( best_of=best_of, @@ -388,6 +395,7 @@ class AsyncClient: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -417,6 +425,7 @@ class AsyncClient: typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -475,6 +484,7 @@ class AsyncClient: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index aa02d8d8..3426411b 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,10 +1,24 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List +from typing import Optional, List, Union from text_generation.errors import ValidationError +# enum for grammar type +class GrammarType(str, Enum): + Json = "json" + Regex = "regex" + + +# Grammar type and value +class Grammar(BaseModel): + # Grammar type + type: GrammarType + # Grammar value + value: Union[str, dict] + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False @@ -41,6 +55,8 @@ class Parameters(BaseModel): decoder_input_details: bool = False # Return the N most likely tokens at each step top_n_tokens: Optional[int] = None + # grammar to use for generation + grammar: Optional[Grammar] = None @validator("best_of") def valid_best_of(cls, field_value, values): @@ -109,6 +125,14 @@ class Parameters(BaseModel): raise ValidationError("`top_n_tokens` must be strictly positive") return v + @validator("grammar") + def valid_grammar(cls, v): + if v is not None: + if v.type == GrammarType.Regex and not v.value: + raise ValidationError("`value` cannot be empty for `regex` grammar") + if v.type == GrammarType.Json and not v.value: + raise ValidationError("`value` cannot be empty for `json` grammar") + return v class Request(BaseModel): # Prompt @@ -157,7 +181,7 @@ class Token(BaseModel): # Token text text: str # Logprob - logprob: float + logprob: Optional[float] = None # Is the token a special token # Can be used to ignore tokens when concatenating special: bool diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index be31a7a4..36fa1241 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -378,6 +378,14 @@ Options: [env: TOKENIZER_CONFIG_PATH=] +``` +## DISABLE_GRAMMAR_SUPPORT +```shell + --disable-grammar-support + Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar + + [env: DISABLE_GRAMMAR_SUPPORT=] + ``` ## ENV ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index efeda08d..e0228894 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from text_generation import AsyncClient -from text_generation.types import Response, Details, InputToken, Token, BestOfSequence +from text_generation.types import ( + Response, + Details, + InputToken, + Token, + BestOfSequence, + Grammar, +) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -224,6 +231,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + disable_grammar_support: bool = False, dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -247,6 +255,8 @@ def launcher(event_loop): env = os.environ + if disable_grammar_support: + args.append("--disable-grammar-support") if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) if quantize is not None: @@ -287,12 +297,15 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + disable_grammar_support: bool = False, dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) args = ["--model-id", model_id, "--env"] + if disable_grammar_support: + args.append("--disable-grammar-support") if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) if quantize is not None: @@ -370,11 +383,22 @@ def launcher(event_loop): @pytest.fixture(scope="module") def generate_load(): async def generate_load_inner( - client: AsyncClient, prompt: str, max_new_tokens: int, n: int + client: AsyncClient, + prompt: str, + max_new_tokens: int, + n: int, + seed: Optional[int] = None, + grammar: Optional[Grammar] = None, + stop_sequences: Optional[List[str]] = None, ) -> List[Response]: futures = [ client.generate( - prompt, max_new_tokens=max_new_tokens, decoder_input_details=True + prompt, + max_new_tokens=max_new_tokens, + decoder_input_details=True, + seed=seed, + grammar=grammar, + stop_sequences=stop_sequences, ) for _ in range(n) ] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json new file mode 100644 index 00000000..0e87f59e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -13.90625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.328125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0566406, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.5253906, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.7578125, + "special": false, + "text": "I" + }, + { + "id": 4966, + "logprob": -1.9033203, + "special": false, + "text": " hope" + }, + { + "id": 445, + "logprob": -0.5019531, + "special": false, + "text": " this" + }, + { + "id": 6911, + "logprob": -0.21264648, + "special": false, + "text": " helps" + }, + { + "id": 29991, + "logprob": -0.5991211, + "special": false, + "text": "!" + }, + { + "id": 2803, + "logprob": -0.37475586, + "special": false, + "text": " Let" + }, + { + "id": 592, + "logprob": -0.018463135, + "special": false, + "text": " me" + }, + { + "id": 1073, + "logprob": -0.0008597374, + "special": false, + "text": " know" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI hope this helps! Let me know" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json new file mode 100644 index 00000000..7b12b158 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2324219, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.140625, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5742188, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0507812, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3164062, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6694336, + "text": "ats" + }, + { + "id": 29889, + "logprob": -0.9995117, + "text": "." + }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.14916992, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13598633, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017669678, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.00085639954, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0054016113, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13549805, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8852539, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16394043, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.020492554, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.0013818741, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0067749023, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.11578369, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.004131317, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.0033359528, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.20471191, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.0069274902, + "special": false, + "text": "h" + }, + { + "id": 20838, + "logprob": -0.19580078, + "special": false, + "text": "obb" + }, + { + "id": 29891, + "logprob": -2.2649765e-06, + "special": false, + "text": "y" + }, + { + "id": 4710, + "logprob": -0.32080078, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.1035156, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.020767212, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.6010742, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.57666016, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.0061073303, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.45703125, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0002872944, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.0021018982, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.08996582, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.021697998, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json new file mode 100644 index 00000000..b7b26a2c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json @@ -0,0 +1,478 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.03125, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04244995, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4863281, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32714844, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.2376709, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.01008606, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64160156, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.5, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46557617, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5341797, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022907257, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + } +] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json new file mode 100644 index 00000000..1ba9ae1e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json @@ -0,0 +1,109 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 806, + "logprob": -11.890625, + "text": "Wh" + }, + { + "id": 1446, + "logprob": -3.6699219, + "text": "ats" + }, + { + "id": 2921, + "logprob": -7.8203125, + "text": "Go" + }, + { + "id": 468, + "logprob": -8.0703125, + "text": "og" + }, + { + "id": 793, + "logprob": -2.1875, + "text": "les" + }, + { + "id": 16332, + "logprob": -9.7109375, + "text": "DNS" + } + ], + "seed": null, + "tokens": [ + { + "id": 29946, + "logprob": -1.4765625, + "special": false, + "text": "4" + }, + { + "id": 29906, + "logprob": -0.9199219, + "special": false, + "text": "2" + }, + { + "id": 29889, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -1.1367188, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -1.4648438, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.40722656, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -0.17419434, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.20251465, + "special": false, + "text": "1" + }, + { + "id": 29900, + "logprob": -1.5527344, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": -1.3710938, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": "42.1.1.101" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json new file mode 100644 index 00000000..7ffb17cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33666992, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.009979248, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.53759766, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.0008878708, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.002275467, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" +} diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py new file mode 100644 index 00000000..f068496c --- /dev/null +++ b/integration-tests/models/test_grammar_llama.py @@ -0,0 +1,151 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar(flash_llama_grammar_handle): + await flash_llama_grammar_handle.health(300) + return flash_llama_grammar_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "42.1.1.101" + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Json, # "json" + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}' + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_load( + flash_llama_grammar, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_grammar, + "name: david. email: ", + max_new_tokens=10, + n=4, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + assert len(responses) == 4 + + expected = "123456@gmail.com" + + for response in responses: + assert response.generated_text == expected + + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +# this is the same as the above test, but only fires off a single request +# this is only to ensure that the parallel and single inference produce the same result +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_single_load_instance( + flash_llama_grammar, generate_load, response_snapshot +): + response = await flash_llama_grammar.generate( + "name: david. email: ", + max_new_tokens=10, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + # assert response.details.generated_tokens == 30 + assert response.generated_text == "123456@gmail.com" + + assert response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b5b476fa..35e23316 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -384,6 +384,11 @@ struct Args { #[clap(long, env)] tokenizer_config_path: Option, + /// Disable outlines grammar constrained generation. + /// This is a feature that allows you to generate text that follows a specific grammar. + #[clap(long, env)] + disable_grammar_support: bool, + /// Display a lot of information about your runtime environment #[clap(long, short, action)] env: bool, @@ -1061,6 +1066,11 @@ fn spawn_webserver( args.model_id, ]; + // Grammar support + if args.disable_grammar_support { + router_args.push("--disable-grammar-support".to_string()); + } + // Tokenizer config path if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { router_args.push("--tokenizer-config-path".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index dde344f6..5bcf792d 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,6 +51,12 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +enum GrammarType { + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; +} + message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; @@ -70,6 +76,10 @@ message NextTokenChooserParameters { float frequency_penalty = 9; /// token watermarking using "A Watermark for Large Language Models" bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 592338fa..51b75a49 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -221,6 +221,8 @@ impl Client { repetition_penalty: 1.0, frequency_penalty: 0.0, watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, }) } else { Some(NextTokenChooserParameters { @@ -233,6 +235,8 @@ impl Client { repetition_penalty: 1.2, frequency_penalty: 0.1, watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, }) }; requests.push(Request { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index c38b931b..6782d9ff 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,8 +9,8 @@ pub use client::Client; pub use pb::generate::v2::HealthResponse; pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, - Request, StoppingCriteriaParameters, Tokens, + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/health.rs b/router/src/health.rs index e830a3c3..b05b3094 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -1,5 +1,6 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use text_generation_client::GrammarType as ProtoGrammarType; use text_generation_client::{ Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; @@ -45,6 +46,8 @@ impl Health { repetition_penalty: 1.0, frequency_penalty: 0.0, watermark: false, + grammar: String::new(), + grammar_type: ProtoGrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/router/src/lib.rs b/router/src/lib.rs index a9d783bb..87873821 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -45,6 +45,43 @@ impl HubTokenizerConfig { } } +mod json_object_or_string_to_string { + use serde::{Deserialize, Deserializer}; + use serde_json::Value; + + // A custom deserializer that treats both strings and objects as strings. + // This provides flexibility with input formats for the 'grammar' field. + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) => Ok(s), + // Safely handle serialization and return an error if it fails + Value::Object(o) => { + serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string())) + } + _ => Err(serde::de::Error::custom( + "expected string or object for grammar", + )), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "type", content = "value")] +pub(crate) enum GrammarType { + #[serde( + rename = "json", + deserialize_with = "json_object_or_string_to_string::deserialize" + )] + Json(String), + #[serde(rename = "regex")] + Regex(String), +} + mod token_serde { use super::*; use serde::de; @@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] pub top_n_tokens: Option, + #[serde(default)] + pub grammar: Option, } fn default_max_new_tokens() -> Option { @@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters { decoder_input_details: false, seed: None, top_n_tokens: None, + grammar: None, } } diff --git a/router/src/main.rs b/router/src/main.rs index 1757e459..472e9f7f 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -78,6 +78,8 @@ struct Args { ngrok_edge: Option, #[clap(long, env, default_value_t = false)] messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, } #[tokio::main] @@ -111,6 +113,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, messages_api_enabled, + disable_grammar_support, } = args; // Launch Tokio runtime @@ -388,6 +391,7 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, tokenizer_config, messages_api_enabled, + disable_grammar_support, ) .await?; Ok(()) diff --git a/router/src/queue.rs b/router/src/queue.rs index 00021812..e98bc43d 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -472,7 +472,9 @@ enum QueueCommand { #[cfg(test)] mod tests { use super::*; - use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; + use text_generation_client::{ + GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, + }; use tracing::info_span; fn default_queue() -> Queue { @@ -495,7 +497,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: "".to_string(), + inputs: String::new(), input_length: 0, truncate: 0, decoder_input_details: false, @@ -509,6 +511,8 @@ mod tests { repetition_penalty: 0.0, frequency_penalty: 0.0, watermark: false, + grammar: String::new(), + grammar_type: ProtoGrammarType::None as i32, }, stopping_parameters: StoppingCriteriaParameters { ignore_eos_token: false, diff --git a/router/src/server.rs b/router/src/server.rs index 450494df..7d9c453c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -616,6 +616,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: None, + grammar: None, }, }; @@ -781,6 +782,7 @@ pub async fn run( ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, messages_api_enabled: bool, + grammar_support: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -842,6 +844,7 @@ pub async fn run( max_top_n_tokens, max_input_length, max_total_tokens, + grammar_support, ); let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); diff --git a/router/src/validation.rs b/router/src/validation.rs index 5802391d..b3a4ef1b 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -2,10 +2,12 @@ /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; -use crate::{GenerateParameters, GenerateRequest}; +use crate::{GenerateParameters, GenerateRequest, GrammarType}; use rand::{thread_rng, Rng}; use std::env; -use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; +use text_generation_client::{ + GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, +}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokenizers::TruncationDirection; @@ -22,6 +24,7 @@ pub struct Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + disable_grammar_support: bool, /// Channel to communicate with the background tokenization task sender: Option>, skip_tokenizer_in_tgi: bool, @@ -36,6 +39,7 @@ impl Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + disable_grammar_support: bool, ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { @@ -74,6 +78,7 @@ impl Validation { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, skip_tokenizer_in_tgi, } } @@ -199,6 +204,7 @@ impl Validation { watermark, decoder_input_details, top_n_tokens, + grammar, .. } = request.parameters; @@ -317,6 +323,28 @@ impl Validation { .validate_input(request.inputs, truncate, max_new_tokens) .await?; + // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar + // NOTE: this is currently difficult because we need the tokenizer in Python to build + // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which + // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM + // compiler and use that to build the FSM here. + + // Validate grammar and unpack the grammar and type for the proto message + let (grammar, grammar_type) = match grammar { + Some(grammar) => { + // Ensure that grammar is not set if it's not supported + if self.disable_grammar_support { + return Err(ValidationError::Grammar); + } + match grammar { + // currently both are handled the same way since compilation is done in Python + GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()), + GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()), + } + } + None => (String::new(), ProtoGrammarType::None.into()), + }; + let parameters = NextTokenChooserParameters { temperature, repetition_penalty, @@ -327,6 +355,8 @@ impl Validation { do_sample, seed, watermark, + grammar, + grammar_type, }; let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, @@ -478,6 +508,8 @@ pub enum ValidationError { StopSequence(usize, usize), #[error("tokenizer error {0}")] Tokenizer(String), + #[error("grammar is not supported")] + Grammar, #[error("`watermark` = true is not allowed with FP8 quantization.")] WatermarkWithQuantization, } @@ -497,6 +529,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -505,6 +538,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); let max_new_tokens = 10; @@ -525,6 +559,7 @@ mod tests { let max_top_n_tokens = 4; let max_input_length = 5; let max_total_tokens = 6; + let disable_grammar_support = true; let workers = 1; let validation = Validation::new( workers, @@ -534,6 +569,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); let max_new_tokens = 10; @@ -555,6 +591,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -563,6 +600,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { @@ -589,6 +627,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -597,6 +636,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { @@ -652,6 +692,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -660,6 +701,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { diff --git a/server/poetry.lock b/server/poetry.lock index f1445b1d..7afe275e 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -140,6 +140,17 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + [[package]] name = "async-timeout" version = "4.0.3" @@ -305,6 +316,17 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + [[package]] name = "colorama" version = "0.4.6" @@ -439,6 +461,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." +optional = false +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -945,6 +978,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "interegular" +version = "0.3.3" +description = "a regex intersection checker" +optional = false +python-versions = ">=3.7" +files = [ + {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"}, + {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, +] + [[package]] name = "jinja2" version = "3.1.3" @@ -962,6 +1006,99 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.0" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.0-py3-none-any.whl", hash = "sha256:42942470d4062537be4d54c83511186da1fc14ba354961a2114da91efa9a4ed7"}, + {file = "joblib-1.4.0.tar.gz", hash = "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c"}, +] + +[[package]] +name = "jsonschema" +version = "4.21.1" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.21.1-py3-none-any.whl", hash = "sha256:7996507afae316306f9e2290407761157c6f78002dcf7419acb99822143d1c6f"}, + {file = "jsonschema-4.21.1.tar.gz", hash = "sha256:85727c00279f5fa6bedbe6238d2aa6403bedd8b4864ab11207d07df3cc1b2ee5"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + +[[package]] +name = "lark" +version = "1.1.9" +description = "a modern parsing library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "lark-1.1.9-py3-none-any.whl", hash = "sha256:a0dd3a87289f8ccbb325901e4222e723e7d745dbfc1803eaf5f3d2ace19cf2db"}, + {file = "lark-1.1.9.tar.gz", hash = "sha256:15fa5236490824c2c4aba0e22d2d6d823575dcaf4cdd1848e34b6ad836240fba"}, +] + +[package.extras] +atomic-cache = ["atomicwrites"] +interegular = ["interegular (>=0.3.1,<0.4.0)"] +nearley = ["js2py"] +regex = ["regex"] + +[[package]] +name = "llvmlite" +version = "0.42.0" +description = "lightweight wrapper around basic LLVM functionality" +optional = false +python-versions = ">=3.9" +files = [ + {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, + {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, + {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, + {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, + {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, + {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, + {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, +] + [[package]] name = "loguru" version = "0.6.0" @@ -1189,6 +1326,17 @@ files = [ [package.dependencies] dill = ">=0.3.8" +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "networkx" version = "3.2.1" @@ -1207,6 +1355,40 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "numba" +version = "0.59.1" +description = "compiling Python code using LLVM" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, +] + +[package.dependencies] +llvmlite = "==0.42.*" +numpy = ">=1.22,<1.27" + [[package]] name = "numpy" version = "1.26.4" @@ -1613,6 +1795,39 @@ transformers = ">=4.37.0,<4.38.0" quality = ["hf-doc-builder", "ruff"] tests = ["GitPython", "datasets", "optuna", "parameterized", "psutil", "pytest", "safetensors", "sentencepiece"] +[[package]] +name = "outlines" +version = "0.0.27" +description = "Probabilistic Generative Model Programming" +optional = false +python-versions = ">=3.8" +files = [ + {file = "outlines-0.0.27-py3-none-any.whl", hash = "sha256:dd614f49760ff8870a5d491fad4a372d7b7d4da5c1646f1b42f12a9d5e34db4b"}, + {file = "outlines-0.0.27.tar.gz", hash = "sha256:debc49f0db4d968eea05a4a6134516b3e49c6c8607106aa097410a4b53d5af6c"}, +] + +[package.dependencies] +cloudpickle = "*" +diskcache = "*" +interegular = "*" +jinja2 = "*" +joblib = "*" +jsonschema = "*" +lark = "*" +nest-asyncio = "*" +numba = "*" +numpy = "*" +pydantic = ">=2.0" +referencing = "*" +requests = "*" +scipy = "*" +torch = ">=2.1.0" +transformers = "*" + +[package.extras] +serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] + [[package]] name = "packaging" version = "24.0" @@ -1945,6 +2160,116 @@ files = [ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, ] +[[package]] +name = "pydantic" +version = "2.7.1" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.7.1-py3-none-any.whl", hash = "sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5"}, + {file = "pydantic-2.7.1.tar.gz", hash = "sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.18.2" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.18.2" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.18.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9e08e867b306f525802df7cd16c44ff5ebbe747ff0ca6cf3fde7f36c05a59a81"}, + {file = "pydantic_core-2.18.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f0a21cbaa69900cbe1a2e7cad2aa74ac3cf21b10c3efb0fa0b80305274c0e8a2"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0680b1f1f11fda801397de52c36ce38ef1c1dc841a0927a94f226dea29c3ae3d"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:95b9d5e72481d3780ba3442eac863eae92ae43a5f3adb5b4d0a1de89d42bb250"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fcf5cd9c4b655ad666ca332b9a081112cd7a58a8b5a6ca7a3104bc950f2038"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b5155ff768083cb1d62f3e143b49a8a3432e6789a3abee8acd005c3c7af1c74"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:553ef617b6836fc7e4df130bb851e32fe357ce36336d897fd6646d6058d980af"}, + {file = "pydantic_core-2.18.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89ed9eb7d616ef5714e5590e6cf7f23b02d0d539767d33561e3675d6f9e3857"}, + {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:75f7e9488238e920ab6204399ded280dc4c307d034f3924cd7f90a38b1829563"}, + {file = "pydantic_core-2.18.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ef26c9e94a8c04a1b2924149a9cb081836913818e55681722d7f29af88fe7b38"}, + {file = "pydantic_core-2.18.2-cp310-none-win32.whl", hash = "sha256:182245ff6b0039e82b6bb585ed55a64d7c81c560715d1bad0cbad6dfa07b4027"}, + {file = "pydantic_core-2.18.2-cp310-none-win_amd64.whl", hash = "sha256:e23ec367a948b6d812301afc1b13f8094ab7b2c280af66ef450efc357d2ae543"}, + {file = "pydantic_core-2.18.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:219da3f096d50a157f33645a1cf31c0ad1fe829a92181dd1311022f986e5fbe3"}, + {file = "pydantic_core-2.18.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc1cfd88a64e012b74e94cd00bbe0f9c6df57049c97f02bb07d39e9c852e19a4"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b7133a6e6aeb8df37d6f413f7705a37ab4031597f64ab56384c94d98fa0e90"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:224c421235f6102e8737032483f43c1a8cfb1d2f45740c44166219599358c2cd"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b14d82cdb934e99dda6d9d60dc84a24379820176cc4a0d123f88df319ae9c150"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2728b01246a3bba6de144f9e3115b532ee44bd6cf39795194fb75491824a1413"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:470b94480bb5ee929f5acba6995251ada5e059a5ef3e0dfc63cca287283ebfa6"}, + {file = "pydantic_core-2.18.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:997abc4df705d1295a42f95b4eec4950a37ad8ae46d913caeee117b6b198811c"}, + {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:75250dbc5290e3f1a0f4618db35e51a165186f9034eff158f3d490b3fed9f8a0"}, + {file = "pydantic_core-2.18.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4456f2dca97c425231d7315737d45239b2b51a50dc2b6f0c2bb181fce6207664"}, + {file = "pydantic_core-2.18.2-cp311-none-win32.whl", hash = "sha256:269322dcc3d8bdb69f054681edff86276b2ff972447863cf34c8b860f5188e2e"}, + {file = "pydantic_core-2.18.2-cp311-none-win_amd64.whl", hash = "sha256:800d60565aec896f25bc3cfa56d2277d52d5182af08162f7954f938c06dc4ee3"}, + {file = "pydantic_core-2.18.2-cp311-none-win_arm64.whl", hash = "sha256:1404c69d6a676245199767ba4f633cce5f4ad4181f9d0ccb0577e1f66cf4c46d"}, + {file = "pydantic_core-2.18.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242"}, + {file = "pydantic_core-2.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72"}, + {file = "pydantic_core-2.18.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c"}, + {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241"}, + {file = "pydantic_core-2.18.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3"}, + {file = "pydantic_core-2.18.2-cp312-none-win32.whl", hash = "sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038"}, + {file = "pydantic_core-2.18.2-cp312-none-win_amd64.whl", hash = "sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438"}, + {file = "pydantic_core-2.18.2-cp312-none-win_arm64.whl", hash = "sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec"}, + {file = "pydantic_core-2.18.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9fdad8e35f278b2c3eb77cbdc5c0a49dada440657bf738d6905ce106dc1de439"}, + {file = "pydantic_core-2.18.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d90c3265ae107f91a4f279f4d6f6f1d4907ac76c6868b27dc7fb33688cfb347"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:390193c770399861d8df9670fb0d1874f330c79caaca4642332df7c682bf6b91"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:82d5d4d78e4448683cb467897fe24e2b74bb7b973a541ea1dcfec1d3cbce39fb"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4774f3184d2ef3e14e8693194f661dea5a4d6ca4e3dc8e39786d33a94865cefd"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4d938ec0adf5167cb335acb25a4ee69a8107e4984f8fbd2e897021d9e4ca21b"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e8b1be28239fc64a88a8189d1df7fad8be8c1ae47fcc33e43d4be15f99cc70"}, + {file = "pydantic_core-2.18.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:868649da93e5a3d5eacc2b5b3b9235c98ccdbfd443832f31e075f54419e1b96b"}, + {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:78363590ef93d5d226ba21a90a03ea89a20738ee5b7da83d771d283fd8a56761"}, + {file = "pydantic_core-2.18.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:852e966fbd035a6468fc0a3496589b45e2208ec7ca95c26470a54daed82a0788"}, + {file = "pydantic_core-2.18.2-cp38-none-win32.whl", hash = "sha256:6a46e22a707e7ad4484ac9ee9f290f9d501df45954184e23fc29408dfad61350"}, + {file = "pydantic_core-2.18.2-cp38-none-win_amd64.whl", hash = "sha256:d91cb5ea8b11607cc757675051f61b3d93f15eca3cefb3e6c704a5d6e8440f4e"}, + {file = "pydantic_core-2.18.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:ae0a8a797a5e56c053610fa7be147993fe50960fa43609ff2a9552b0e07013e8"}, + {file = "pydantic_core-2.18.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:042473b6280246b1dbf530559246f6842b56119c2926d1e52b631bdc46075f2a"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a388a77e629b9ec814c1b1e6b3b595fe521d2cdc625fcca26fbc2d44c816804"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25add29b8f3b233ae90ccef2d902d0ae0432eb0d45370fe315d1a5cf231004b"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f459a5ce8434614dfd39bbebf1041952ae01da6bed9855008cb33b875cb024c0"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eff2de745698eb46eeb51193a9f41d67d834d50e424aef27df2fcdee1b153845"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8309f67285bdfe65c372ea3722b7a5642680f3dba538566340a9d36e920b5f0"}, + {file = "pydantic_core-2.18.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f93a8a2e3938ff656a7c1bc57193b1319960ac015b6e87d76c76bf14fe0244b4"}, + {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:22057013c8c1e272eb8d0eebc796701167d8377441ec894a8fed1af64a0bf399"}, + {file = "pydantic_core-2.18.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfeecd1ac6cc1fb2692c3d5110781c965aabd4ec5d32799773ca7b1456ac636b"}, + {file = "pydantic_core-2.18.2-cp39-none-win32.whl", hash = "sha256:0d69b4c2f6bb3e130dba60d34c0845ba31b69babdd3f78f7c0c8fae5021a253e"}, + {file = "pydantic_core-2.18.2-cp39-none-win_amd64.whl", hash = "sha256:d9319e499827271b09b4e411905b24a426b8fb69464dfa1696258f53a3334641"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b"}, + {file = "pydantic_core-2.18.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae"}, + {file = "pydantic_core-2.18.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374"}, + {file = "pydantic_core-2.18.2.tar.gz", hash = "sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pyreadline3" version = "3.4.1" @@ -2063,6 +2388,21 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "referencing" +version = "0.35.0" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.0-py3-none-any.whl", hash = "sha256:8080727b30e364e5783152903672df9b6b091c926a146a759080b62ca3126cd6"}, + {file = "referencing-0.35.0.tar.gz", hash = "sha256:191e936b0c696d0af17ad7430a3dc68e88bc11be6514f4757dc890f04ab05889"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" version = "2024.4.16" @@ -2186,6 +2526,114 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rpds-py" +version = "0.18.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.18.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:5b4e7d8d6c9b2e8ee2d55c90b59c707ca59bc30058269b3db7b1f8df5763557e"}, + {file = "rpds_py-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c463ed05f9dfb9baebef68048aed8dcdc94411e4bf3d33a39ba97e271624f8f7"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01e36a39af54a30f28b73096dd39b6802eddd04c90dbe161c1b8dbe22353189f"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d62dec4976954a23d7f91f2f4530852b0c7608116c257833922a896101336c51"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd18772815d5f008fa03d2b9a681ae38d5ae9f0e599f7dda233c439fcaa00d40"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:923d39efa3cfb7279a0327e337a7958bff00cc447fd07a25cddb0a1cc9a6d2da"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39514da80f971362f9267c600b6d459bfbbc549cffc2cef8e47474fddc9b45b1"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a34d557a42aa28bd5c48a023c570219ba2593bcbbb8dc1b98d8cf5d529ab1434"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:93df1de2f7f7239dc9cc5a4a12408ee1598725036bd2dedadc14d94525192fc3"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:34b18ba135c687f4dac449aa5157d36e2cbb7c03cbea4ddbd88604e076aa836e"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c0b5dcf9193625afd8ecc92312d6ed78781c46ecbf39af9ad4681fc9f464af88"}, + {file = "rpds_py-0.18.0-cp310-none-win32.whl", hash = "sha256:c4325ff0442a12113a6379af66978c3fe562f846763287ef66bdc1d57925d337"}, + {file = "rpds_py-0.18.0-cp310-none-win_amd64.whl", hash = "sha256:7223a2a5fe0d217e60a60cdae28d6949140dde9c3bcc714063c5b463065e3d66"}, + {file = "rpds_py-0.18.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:3a96e0c6a41dcdba3a0a581bbf6c44bb863f27c541547fb4b9711fd8cf0ffad4"}, + {file = "rpds_py-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30f43887bbae0d49113cbaab729a112251a940e9b274536613097ab8b4899cf6"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcb25daa9219b4cf3a0ab24b0eb9a5cc8949ed4dc72acb8fa16b7e1681aa3c58"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d68c93e381010662ab873fea609bf6c0f428b6d0bb00f2c6939782e0818d37bf"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b34b7aa8b261c1dbf7720b5d6f01f38243e9b9daf7e6b8bc1fd4657000062f2c"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e6d75ab12b0bbab7215e5d40f1e5b738aa539598db27ef83b2ec46747df90e1"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8612cd233543a3781bc659c731b9d607de65890085098986dfd573fc2befe5"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aec493917dd45e3c69d00a8874e7cbed844efd935595ef78a0f25f14312e33c6"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:661d25cbffaf8cc42e971dd570d87cb29a665f49f4abe1f9e76be9a5182c4688"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1df3659d26f539ac74fb3b0c481cdf9d725386e3552c6fa2974f4d33d78e544b"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1ce3ba137ed54f83e56fb983a5859a27d43a40188ba798993812fed73c70836"}, + {file = "rpds_py-0.18.0-cp311-none-win32.whl", hash = "sha256:69e64831e22a6b377772e7fb337533c365085b31619005802a79242fee620bc1"}, + {file = "rpds_py-0.18.0-cp311-none-win_amd64.whl", hash = "sha256:998e33ad22dc7ec7e030b3df701c43630b5bc0d8fbc2267653577e3fec279afa"}, + {file = "rpds_py-0.18.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7f2facbd386dd60cbbf1a794181e6aa0bd429bd78bfdf775436020172e2a23f0"}, + {file = "rpds_py-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1d9a5be316c15ffb2b3c405c4ff14448c36b4435be062a7f578ccd8b01f0c4d8"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd5bf1af8efe569654bbef5a3e0a56eca45f87cfcffab31dd8dde70da5982475"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5417558f6887e9b6b65b4527232553c139b57ec42c64570569b155262ac0754f"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:56a737287efecafc16f6d067c2ea0117abadcd078d58721f967952db329a3e5c"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8f03bccbd8586e9dd37219bce4d4e0d3ab492e6b3b533e973fa08a112cb2ffc9"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4457a94da0d5c53dc4b3e4de1158bdab077db23c53232f37a3cb7afdb053a4e3"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0ab39c1ba9023914297dd88ec3b3b3c3f33671baeb6acf82ad7ce883f6e8e157"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9d54553c1136b50fd12cc17e5b11ad07374c316df307e4cfd6441bea5fb68496"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0af039631b6de0397ab2ba16eaf2872e9f8fca391b44d3d8cac317860a700a3f"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:84ffab12db93b5f6bad84c712c92060a2d321b35c3c9960b43d08d0f639d60d7"}, + {file = "rpds_py-0.18.0-cp312-none-win32.whl", hash = "sha256:685537e07897f173abcf67258bee3c05c374fa6fff89d4c7e42fb391b0605e98"}, + {file = "rpds_py-0.18.0-cp312-none-win_amd64.whl", hash = "sha256:e003b002ec72c8d5a3e3da2989c7d6065b47d9eaa70cd8808b5384fbb970f4ec"}, + {file = "rpds_py-0.18.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:08f9ad53c3f31dfb4baa00da22f1e862900f45908383c062c27628754af2e88e"}, + {file = "rpds_py-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0013fe6b46aa496a6749c77e00a3eb07952832ad6166bd481c74bda0dcb6d58"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e32a92116d4f2a80b629778280103d2a510a5b3f6314ceccd6e38006b5e92dcb"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e541ec6f2ec456934fd279a3120f856cd0aedd209fc3852eca563f81738f6861"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bed88b9a458e354014d662d47e7a5baafd7ff81c780fd91584a10d6ec842cb73"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2644e47de560eb7bd55c20fc59f6daa04682655c58d08185a9b95c1970fa1e07"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8916ae4c720529e18afa0b879473049e95949bf97042e938530e072fde061d"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:465a3eb5659338cf2a9243e50ad9b2296fa15061736d6e26240e713522b6235c"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ea7d4a99f3b38c37eac212dbd6ec42b7a5ec51e2c74b5d3223e43c811609e65f"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:67071a6171e92b6da534b8ae326505f7c18022c6f19072a81dcf40db2638767c"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:41ef53e7c58aa4ef281da975f62c258950f54b76ec8e45941e93a3d1d8580594"}, + {file = "rpds_py-0.18.0-cp38-none-win32.whl", hash = "sha256:fdea4952db2793c4ad0bdccd27c1d8fdd1423a92f04598bc39425bcc2b8ee46e"}, + {file = "rpds_py-0.18.0-cp38-none-win_amd64.whl", hash = "sha256:7cd863afe7336c62ec78d7d1349a2f34c007a3cc6c2369d667c65aeec412a5b1"}, + {file = "rpds_py-0.18.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5307def11a35f5ae4581a0b658b0af8178c65c530e94893345bebf41cc139d33"}, + {file = "rpds_py-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f195baa60a54ef9d2de16fbbfd3ff8b04edc0c0140a761b56c267ac11aa467"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39f5441553f1c2aed4de4377178ad8ff8f9d733723d6c66d983d75341de265ab"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a00312dea9310d4cb7dbd7787e722d2e86a95c2db92fbd7d0155f97127bcb40"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f2fc11e8fe034ee3c34d316d0ad8808f45bc3b9ce5857ff29d513f3ff2923a1"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:586f8204935b9ec884500498ccc91aa869fc652c40c093bd9e1471fbcc25c022"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddc2f4dfd396c7bfa18e6ce371cba60e4cf9d2e5cdb71376aa2da264605b60b9"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5ddcba87675b6d509139d1b521e0c8250e967e63b5909a7e8f8944d0f90ff36f"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7bd339195d84439cbe5771546fe8a4e8a7a045417d8f9de9a368c434e42a721e"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:d7c36232a90d4755b720fbd76739d8891732b18cf240a9c645d75f00639a9024"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6b0817e34942b2ca527b0e9298373e7cc75f429e8da2055607f4931fded23e20"}, + {file = "rpds_py-0.18.0-cp39-none-win32.whl", hash = "sha256:99f70b740dc04d09e6b2699b675874367885217a2e9f782bdf5395632ac663b7"}, + {file = "rpds_py-0.18.0-cp39-none-win_amd64.whl", hash = "sha256:6ef687afab047554a2d366e112dd187b62d261d49eb79b77e386f94644363294"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ad36cfb355e24f1bd37cac88c112cd7730873f20fb0bdaf8ba59eedf8216079f"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:36b3ee798c58ace201289024b52788161e1ea133e4ac93fba7d49da5fec0ef9e"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8a2f084546cc59ea99fda8e070be2fd140c3092dc11524a71aa8f0f3d5a55ca"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e4461d0f003a0aa9be2bdd1b798a041f177189c1a0f7619fe8c95ad08d9a45d7"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8db715ebe3bb7d86d77ac1826f7d67ec11a70dbd2376b7cc214199360517b641"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793968759cd0d96cac1e367afd70c235867831983f876a53389ad869b043c948"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e6a3af5a75363d2c9a48b07cb27c4ea542938b1a2e93b15a503cdfa8490795"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ef0befbb5d79cf32d0266f5cff01545602344eda89480e1dd88aca964260b18"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d4acf42190d449d5e89654d5c1ed3a4f17925eec71f05e2a41414689cda02d1"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:a5f446dd5055667aabaee78487f2b5ab72e244f9bc0b2ffebfeec79051679984"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:9dbbeb27f4e70bfd9eec1be5477517365afe05a9b2c441a0b21929ee61048124"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:22806714311a69fd0af9b35b7be97c18a0fc2826e6827dbb3a8c94eac6cf7eeb"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:b34ae4636dfc4e76a438ab826a0d1eed2589ca7d9a1b2d5bb546978ac6485461"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c8370641f1a7f0e0669ddccca22f1da893cef7628396431eb445d46d893e5cd"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c8362467a0fdeccd47935f22c256bec5e6abe543bf0d66e3d3d57a8fb5731863"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11a8c85ef4a07a7638180bf04fe189d12757c696eb41f310d2426895356dcf05"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b316144e85316da2723f9d8dc75bada12fa58489a527091fa1d5a612643d1a0e"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf1ea2e34868f6fbf070e1af291c8180480310173de0b0c43fc38a02929fc0e3"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e546e768d08ad55b20b11dbb78a745151acbd938f8f00d0cfbabe8b0199b9880"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4901165d170a5fde6f589acb90a6b33629ad1ec976d4529e769c6f3d885e3e80"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:618a3d6cae6ef8ec88bb76dd80b83cfe415ad4f1d942ca2a903bf6b6ff97a2da"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ed4eb745efbff0a8e9587d22a84be94a5eb7d2d99c02dacf7bd0911713ed14dd"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c81e5f372cd0dc5dc4809553d34f832f60a46034a5f187756d9b90586c2c307"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:43fbac5f22e25bee1d482c97474f930a353542855f05c1161fd804c9dc74a09d"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d7faa6f14017c0b1e69f5e2c357b998731ea75a442ab3841c0dbbbfe902d2c4"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:08231ac30a842bd04daabc4d71fddd7e6d26189406d5a69535638e4dcb88fe76"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:044a3e61a7c2dafacae99d1e722cc2d4c05280790ec5a05031b3876809d89a5c"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f26b5bd1079acdb0c7a5645e350fe54d16b17bfc5e71f371c449383d3342e17"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:482103aed1dfe2f3b71a58eff35ba105289b8d862551ea576bd15479aba01f66"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1374f4129f9bcca53a1bba0bb86bf78325a0374577cf7e9e4cd046b1e6f20e24"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:635dc434ff724b178cb192c70016cc0ad25a275228f749ee0daf0eddbc8183b1"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:bc362ee4e314870a70f4ae88772d72d877246537d9f8cb8f7eacf10884862432"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4832d7d380477521a8c1644bbab6588dfedea5e30a7d967b5fb75977c45fd77f"}, + {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"}, +] + [[package]] name = "safetensors" version = "0.4.3" @@ -2308,6 +2756,48 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scipy" +version = "1.13.0" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"}, + {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"}, + {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"}, + {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"}, + {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"}, + {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"}, + {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"}, + {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"}, + {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"}, + {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"}, + {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"}, + {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"}, + {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"}, + {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"}, + {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"}, + {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"}, + {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"}, + {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"}, + {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"}, + {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"}, + {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"}, + {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"}, + {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"}, + {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"}, + {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"}, +] + +[package.dependencies] +numpy = ">=1.22.4,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "sentencepiece" version = "0.1.99" @@ -3102,4 +3592,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "5f39b7146fc36cc846070a3e12db863f2748b49a24ab55878129be3480c89685" +content-hash = "8a74ea25e989cdb332816e4bc0ad661c6230d60d5ae8e712fa56297aae4083b3" diff --git a/server/pyproject.toml b/server/pyproject.toml index 517ef42f..ee1a9d8a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -25,6 +25,7 @@ peft = "^0.8.2" optimum-habana = "1.10.4" transformers = "4.37.2" accelerate = "0.27.2" +outlines="^0.0.27" [tool.poetry.group.dev.dependencies] grpcio-tools = "*" diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 771c7694..2dcdaa34 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -407,6 +407,7 @@ class CausalLMBatch(Batch): parameters, batches[dst_batch_idx].next_token_chooser.dtype, batches[dst_batch_idx].next_token_chooser.device, + batches[dst_batch_idx].next_token_chooser.tokenizer, hq_env.is_quantization_enabled ) @@ -462,7 +463,7 @@ class CausalLMBatch(Batch): parameters.append(parameters[0]) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, dtype, device, hq_env.is_quantization_enabled + parameters, dtype, device, tokenizer, hq_env.is_quantization_enabled ) tokenized_inputs = tokenizer( [r.data.inputs for r in requests] + dummy_inputs, @@ -1040,7 +1041,7 @@ class CausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip( + for top_token_ids, top_token_logprobs in zip( top_token_ids, top_token_logprobs ): toptoken_texts = self.tokenizer.batch_decode( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 886fe486..7ec8c2fc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch): ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, dtype, device, tokenizer ) start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -593,6 +593,7 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, device=batches[0].next_token_chooser.device, + tokenizer=batches[0].next_token_chooser.tokenizer, ) speculative_ids = ( @@ -869,7 +870,11 @@ class FlashCausalLM(Model): # Try to find an associated cuda graph cuda_graph = self.cuda_graphs.get(padded_bs, None) - if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None: + if ( + cu_seqlen_prefill is not None + or cuda_graph is None + or batch.speculative_ids is not None + ): return self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1013,9 +1018,9 @@ class FlashCausalLM(Model): # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: if len(batch) > 1: - prefill_tokens_indices[ - out_start_index : out_end_index - 1 - ] = batch.input_ids[start_index + 1 : start_index + out_length] + prefill_tokens_indices[out_start_index : out_end_index - 1] = ( + batch.input_ids[start_index + 1 : start_index + out_length] + ) else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids[ @@ -1028,6 +1033,7 @@ class FlashCausalLM(Model): cumulative_length += input_length + # Update values batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids @@ -1166,7 +1172,7 @@ class FlashCausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip( + for top_token_ids, top_token_logprobs in zip( top_token_ids, top_token_logprobs ): toptoken_texts = self.tokenizer.batch_decode( @@ -1204,6 +1210,12 @@ class FlashCausalLM(Model): generations.append(generation) + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id) + + # Update values batch.input_lengths[i] = input_length + n_accepted_ids if batch.input_lengths[i] > batch.max_seqlen: diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 34a50194..70669c8d 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch): ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, dtype, device, tokenizer ) start_slots = torch.tensor(start_slots, dtype=torch.int64) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 42ff1c80..a2c30255 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 2f28688d..5ea2db87 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -815,6 +815,9 @@ class IdeficsCausalLM(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 868db6aa..4585f4b9 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -124,7 +124,7 @@ class MambaBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -694,6 +694,9 @@ class Mamba(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 25042a32..459f4256 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch): inputs.append(r.inputs) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -789,6 +789,9 @@ class Seq2SeqLM(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.decoder_input_ids[i] = next_token_id batch.all_decoder_input_ids[i] = all_decoder_input_ids batch.input_lengths[i] = input_length diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 9013d200..64f74433 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -2,8 +2,17 @@ import math import torch import habana_frameworks.torch.core as htcore +import json +from loguru import logger from functools import lru_cache from typing import Optional, List, Dict, Union +from text_generation_server.pb.generate_pb2 import GrammarType + +from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_object +from functools import lru_cache +from typing import List, Optional, DefaultDict +import time from transformers import ( LogitsWarper, @@ -124,9 +133,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor): ) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then penalty has to be multiplied to reduce the previous token probability - score = -torch.where( - score < 0, score * self.penalty, score / self.penalty - ) + score = -torch.where(score < 0, score * self.penalty, score / self.penalty) return scores.scatter_add_(1, input_ids, score) @@ -435,3 +442,132 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): self.processors = new_processors return self return None + + +class GrammarLogitProcessor(LogitsProcessor): + fsm_state: DefaultDict[int, int] + fsm: RegexFSM + + def __init__(self, tokenizer, device, grammar, grammar_type): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type, grammar, self.tokenizer + ) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_state: int, + ): + if fsm_grammar_state == -1 or self.fsm is None: + return logits + allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + mask = torch.full((logits.shape[-1],), -math.inf, device=self.device) + mask[allowed_tokens] = 0 + biased_scores = logits + mask + return biased_scores + + def advance(self, next_token_id, fsm_grammar_state): + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsm + ) + + @staticmethod + def _advance(next_token_id, fsm_grammar_state, fsm): + if fsm_grammar_state == -1: + return fsm_grammar_state + return fsm.next_state(fsm_grammar_state, next_token_id) + + # TODO: move grammar compilation into the router + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_compile_fsm(grammar_type, schema, tokenizer): + start_time = time.time() + if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: + schema = build_regex_from_object(schema) + elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: + pass # schema is already a regex just here for clarity + fsm = RegexFSM(schema, tokenizer) + logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") + return fsm + + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_adapt_tokenizer(tokenizer): + """Adapt tokenizer to work with the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + start_time = time.time() + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s") + return tokenizer + + def filter(self, indices): + new_fsms = [] + for i in indices: + new_fsms.append(self.fsms[i]) + self.fsms = new_fsms + return self + + +class HeterogeneousGrammarLogitProcessor(LogitsProcessor): + def __init__(self, tokenizer, device, grammars, grammar_type): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsms = [] + for i in range(len(grammars)): + fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type[i], grammars[i], self.tokenizer + ) + self.fsms.append(fsm) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_states: List[int], + mask: torch.Tensor, + ): + mask = torch.full_like(logits, -math.inf) + for i in range(logits.shape[0]): + fsm = self.fsms[i] + if fsm_grammar_states[i] == -1 or fsm is None: + continue + allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) + mask[i, allowed_tokens] = 0 + logits += mask + return logits + + def advance_batch(self, next_token_ids, fsm_grammar_states, grammars): + return [ + GrammarLogitProcessor._advance( + next_token_ids[i], fsm_grammar_states[i], self.fsms[i] + ) + for i in range(len(next_token_ids)) + ] + + def advance_at_index(self, next_token_id, fsm_grammar_state, index): + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsms[index] + ) + + def filter(self, indices): + return GrammarLogitProcessor.filter(self, indices) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9d918ac5..ff0d22e2 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -3,11 +3,13 @@ import re from typing import List, Optional, Tuple +import math import torch from text_generation_server.pb import generate_pb2 -from text_generation_server.pb.generate_pb2 import FinishReason +from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType from text_generation_server.utils.logits_process import ( FrequencyPenaltyLogitsProcessor, + GrammarLogitProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousFrequencyPenaltyLogitsProcessor, @@ -15,6 +17,7 @@ from text_generation_server.utils.logits_process import ( HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, + HeterogeneousGrammarLogitProcessor, static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -35,8 +38,14 @@ class NextTokenChooser: do_sample=False, seed=0, device="cpu", + tokenizer: Optional[PreTrainedTokenizerBase] = None, + grammar: str = "", + grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, + fsm_grammar_state: int = 0, ): - self.watermark_processor = WatermarkLogitsProcessor(device=device) if watermark else None + self.watermark_processor = ( + WatermarkLogitsProcessor(device=device) if watermark else None + ) self.repetition_processor = ( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty and repetition_penalty != 1.0 @@ -47,6 +56,12 @@ class NextTokenChooser: if frequency_penalty and frequency_penalty != 0.0 else None ) + self.grammar_processor = ( + GrammarLogitProcessor(tokenizer, device, grammar, grammar_type) + if grammar != "" + else None + ) + self.tokenizer = tokenizer has_warpers = ( (temperature is not None and temperature != 1.0) @@ -60,7 +75,10 @@ class NextTokenChooser: self.static_warper = None sampling = do_sample or has_warpers + self.choice = Sampling(seed, device) if sampling else Greedy() + self.fsm_grammar_state = fsm_grammar_state + self.grammar = grammar def __call__(self, input_ids, scores): if self.watermark_processor is not None: @@ -69,6 +87,8 @@ class NextTokenChooser: scores = self.repetition_processor(input_ids, scores) if self.frequency_processor is not None: scores = self.frequency_processor(input_ids, scores) + if self.grammar_processor is not None: + scores = self.grammar_processor(scores, self.fsm_grammar_state) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -79,11 +99,19 @@ class NextTokenChooser: return next_id, next_logprob + def advance_grammar(self, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_state = self.grammar_processor.advance( + next_id, self.fsm_grammar_state + ) + return self + @classmethod def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -96,6 +124,9 @@ class NextTokenChooser: do_sample=pb.do_sample, seed=pb.seed, device=device, + tokenizer=tokenizer, + grammar=pb.grammar, + grammar_type=pb.grammar_type, ) @@ -198,6 +229,10 @@ class HeterogeneousNextTokenChooser: typical_p: List[float], do_sample: List[bool], seeds: List[int], + tokenizer: PreTrainedTokenizerBase, + grammars: List[str], + grammar_types: List[int], + fsm_grammar_states: List[int], quantization_enabled: bool, ): warpers = [] @@ -231,6 +266,14 @@ class HeterogeneousNextTokenChooser: else None ) + self.grammar_processor = ( + HeterogeneousGrammarLogitProcessor( + tokenizer, device, grammars, grammar_types + ) + if any([grammar != "" for grammar in grammars]) + else None + ) + if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -262,6 +305,10 @@ class HeterogeneousNextTokenChooser: self.do_sample = do_sample self.dtype = dtype self.device = device + self.tokenizer = tokenizer + self.fsm_grammar_states = fsm_grammar_states + self.grammars = grammars + self.grammar_types = grammar_types def __call__( self, @@ -282,6 +329,8 @@ class HeterogeneousNextTokenChooser: scores = scores.view(B, S, -1) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) + mask = torch.full((scores.shape[-1],), -math.inf, device=self.device) + for j in range(S): _scores = scores[:, j] if self.watermark_processor is not None: @@ -290,10 +339,10 @@ class HeterogeneousNextTokenChooser: _scores = self.repetition_processor(input_ids, _scores) if self.frequency_processor is not None: _scores = self.frequency_processor(input_ids, _scores) - for warper in self.warpers: _scores = warper(input_ids, _scores) - + if self.grammar_processor is not None: + _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask) _next_ids = self.choice(_scores) scores[:, j] = _scores next_ids[:, j] = _next_ids @@ -351,6 +400,21 @@ class HeterogeneousNextTokenChooser: return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids + def advance_grammar(self, next_ids: List[int]): + if self.grammar_processor is not None: + other_new_states = self.grammar_processor.advance_batch( + next_ids, self.fsm_grammar_states, self.grammars + ) + self.fsm_grammar_states = other_new_states + return self + + def advance_grammar_single(self, grammar_state_index: int, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index( + next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index + ) + return self + def filter(self, indices): if self.watermark_processor is not None: self.watermark_processor = self.watermark_processor.filter(indices) @@ -361,6 +425,9 @@ class HeterogeneousNextTokenChooser: if self.frequency_processor is not None: self.frequency_processor = self.frequency_processor.filter(indices) + if self.grammar_processor is not None: + self.grammar_processor = self.grammar_processor.filter(indices) + filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) @@ -371,6 +438,18 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + new_grammars = [] + new_fsm_grammar_states = [] + new_grammar_types = [] + for i in indices: + new_grammars.append(self.grammars[i]) + new_fsm_grammar_states.append(self.fsm_grammar_states[i]) + new_grammar_types.append(self.grammar_types[i]) + + self.grammars = new_grammars + self.fsm_grammar_states = new_fsm_grammar_states + self.grammar_types = new_grammar_types + if any(self.do_sample): self.choice.filter(indices) else: @@ -384,6 +463,7 @@ class HeterogeneousNextTokenChooser: pb: List[generate_pb2.NextTokenChooserParameters], dtype: torch.dtype, device: torch.device, + tokenizer: PreTrainedTokenizerBase, quantization_enabled: bool, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( @@ -398,6 +478,10 @@ class HeterogeneousNextTokenChooser: seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, + tokenizer=tokenizer, + grammars=[pb_.grammar for pb_ in pb], + grammar_types=[pb_.grammar_type for pb_ in pb], + fsm_grammar_states=[0] * len(pb), quantization_enabled=quantization_enabled, )