diff --git a/Cargo.lock b/Cargo.lock
index 012a8c02..19acba99 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1401,6 +1401,15 @@ version = "2.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
 
+[[package]]
+name = "memoffset"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "metrics"
 version = "0.21.1"
@@ -2197,6 +2206,69 @@ dependencies = [
  "prost 0.12.3",
 ]
 
+[[package]]
+name = "pyo3"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
+dependencies = [
+ "cfg-if",
+ "indoc",
+ "libc",
+ "memoffset",
+ "parking_lot",
+ "portable-atomic",
+ "pyo3-build-config",
+ "pyo3-ffi",
+ "pyo3-macros",
+ "unindent",
+]
+
+[[package]]
+name = "pyo3-build-config"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
+dependencies = [
+ "once_cell",
+ "target-lexicon",
+]
+
+[[package]]
+name = "pyo3-ffi"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
+dependencies = [
+ "libc",
+ "pyo3-build-config",
+]
+
+[[package]]
+name = "pyo3-macros"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
+dependencies = [
+ "proc-macro2",
+ "pyo3-macros-backend",
+ "quote",
+ "syn 2.0.51",
+]
+
+[[package]]
+name = "pyo3-macros-backend"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "pyo3-build-config",
+ "quote",
+ "syn 2.0.51",
+]
+
 [[package]]
 name = "quanta"
 version = "0.11.1"
@@ -2945,6 +3017,12 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "target-lexicon"
+version = "0.12.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f"
+
 [[package]]
 name = "tempfile"
 version = "3.10.1"
@@ -3030,6 +3108,7 @@ dependencies = [
  "nohash-hasher",
  "opentelemetry",
  "opentelemetry-otlp",
+ "pyo3",
  "rand",
  "reqwest",
  "serde",
@@ -3583,6 +3662,12 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
 
+[[package]]
+name = "unindent"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
+
 [[package]]
 name = "untrusted"
 version = "0.7.1"
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 170debda..10778a37 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -46,6 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
 minijinja = "1.0.10"
 futures-util = "0.3.30"
+pyo3 = "0.20.3"
 
 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 204dbf92..5aa87c49 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest, GrammarType};
@@ -14,6 +16,9 @@ use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 
+use pyo3::prelude::*;
+use pyo3::types::IntoPyDict;
+
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
@@ -26,6 +31,7 @@ pub struct Validation {
     disable_grammar_support: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    grammar_compilation_sender: Option<mpsc::UnboundedSender<GrammarCompilationRequest>>,
 }
 
 impl Validation {
@@ -41,7 +47,7 @@
         disable_grammar_support: bool,
     ) -> Self {
         // If we have a fast tokenizer
-        let sender = if let Some(tokenizer) = tokenizer {
+        let (sender, grammar_compilation_sender) = if let Some(tokenizer) = tokenizer {
             // Create round robin channel
             let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
             let mut senders = Vec::with_capacity(workers);
@@ -58,17 +64,29 @@
                 });
             }
 
+            // TODO: start more than one grammar compilation worker
+            let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+
+            // Create a single grammar compilation worker
+            tokio::task::spawn_blocking(move || {
+                grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
+                    tracing::error!("Error in grammar compilation worker: {:?}", e);
+                    e
+                })
+            });
+
             // Create tokenization round robin task
             tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
 
-            Some(validation_sender)
+            (Some(validation_sender), Some(grammar_sender))
         } else {
-            None
+            (None, None)
         };
 
         Self {
             max_best_of,
             sender,
+            grammar_compilation_sender,
             max_stop_sequences,
             max_top_n_tokens,
             max_input_length,
@@ -102,6 +120,27 @@
         }
     }
 
+    #[instrument(skip(self, inputs))]
+    pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
+        // If we have a grammar compilation worker
+        if let Some(sender) = &self.grammar_compilation_sender {
+            // Create response channel
+            let (response_sender, response_receiver) = oneshot::channel();
+            // Send request to the background compilation task
+            // Unwrap is safe here
+            sender
+                .send((inputs.clone(), response_sender, Span::current()))
+                .unwrap();
+
+            // Await on response channel
+            // Unwrap is safe here
+            let compiled_grammar = response_receiver.await.unwrap()?;
+            return Ok(compiled_grammar);
+        }
+
+        Ok(inputs)
+    }
+
     #[instrument(skip(self, inputs))]
     async fn validate_input(
         &self,
@@ -331,12 +370,12 @@
                         .compile(&json)
                         .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
-                    (
-                        // Serialize json to string
-                        serde_json::to_string(&json)
-                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                        ProtoGrammarType::Json.into(),
-                    )
+                    // NOTE: compile the JSON schema to a regex on the router, so the server receives a regex grammar
+                    let regex_compiled_grammar = self
+                        .compile_grammar(serde_json::to_string(&json).unwrap())
+                        .await
+                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+                    (regex_compiled_grammar, ProtoGrammarType::Regex.into())
                 }
                 GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
             }
@@ -418,6 +457,32 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
     }
 }
 
+/// Start a grammar compilation worker
+fn grammar_compilation_worker(
+    tokenizer: Tokenizer,
+    mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
+) -> Result<(), PyErr> {
+    // initialize python runtime
+    pyo3::prepare_freethreaded_python();
+
+    // import outlines once up front so later compilations are fast
+    Python::with_gil(|py| {
+        PyModule::import(py, "outlines")?;
+        Ok::<_, PyErr>(())
+    })?;
+
+    // Loop over requests
+    while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
+        parent_span.in_scope(|| {
+            response_tx
+                .send(compile_grammar(inputs, &tokenizer))
+                .unwrap_or(())
+        })
+    }
+
+    Ok(())
+}
+
 /// Get input length and optionally truncate it
 fn prepare_input(
     mut inputs: String,
@@ -442,6 +507,120 @@
     Ok((encoding, inputs))
 }
 
+/// Compile a grammar
+fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
+    let start_time = std::time::Instant::now();
+    let schema = Python::with_gil(|py| -> PyResult<String> {
+        let fun: Py<PyAny> = PyModule::from_code(
+            py,
+            r#"
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+import time
+from transformers.file_utils import SPIECE_UNDERLINE
+
+class Tokenizer:
+    def __init__(self, vocab, special_tokens):
+        self.vocabulary = vocab
+        self.special_tokens = special_tokens
+        self.eos_token_id = 0
+
+    def get_vocab(self, with_added_tokens):
+        return self.vocabulary
+
+    def encode(self, text, add_special_tokens):
+        return text
+
+    def decode(self, text, skip_special_tokens):
+        return text
+
+    def convert_tokens_to_string(self, tokens):
+        return " ".join(tokens)
+
+def adapt_tokenizer(vocab, special_tokens):
+    start_time = time.time()
+    tokenizer = Tokenizer(vocab, special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces in HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
+    return tokenizer
+
+def compile_regex_grammar(inputs, vocab, special_tokens):
+    start_time = time.time()
+    print("🔥 starting compile_regex_grammar", inputs)
+    schema = build_regex_from_schema(inputs)
+    print(f"Built regex in {time.time() - start_time:.2f}s")
+    tokenizer = adapt_tokenizer(vocab, special_tokens)
+    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
+    fsm = RegexFSM(schema, tokenizer)
+    print(f"Compiled FSM in {time.time() - start_time:.2f}s")
+    return fsm
+
+def convert_grammar_to_regex(inputs):
+    start_time = time.time()
+    print("🔥 starting convert_grammar_to_regex", inputs)
+    schema = build_regex_from_schema(inputs)
+    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
+    return schema
+"#,
+            "",
+            "",
+        )?
+        // .getattr("compile_regex_grammar")?
+        .getattr("convert_grammar_to_regex")?
+        .into();
+
+        let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
+        let special_tokens: Vec<String> = vec![];
+
+        let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
+
+        if false {
+            let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
+            let regex_fsm_ref = regex_fsm.into_ref(py);
+
+            let states_to_token_maps = regex_fsm_ref
+                .getattr("states_to_token_maps")?
+                .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
+
+            println!("🔥 elapsed: {:?}", start_time.elapsed());
+
+            // size of serialized states_to_token_maps
+            let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
+            println!(
+                "🔥 states_to_token_maps size: {:.2}MB",
+                states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
+            );
+        }
+
+        let result = regex_fsm.into_ref(py).extract().unwrap();
+
+        println!("result: {:?}", result);
+
+        Ok(result)
+    })
+    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+    let elapsed = start_time.elapsed();
+    println!("🔥 elapsed: {:?}", elapsed);
+
+    Ok(schema)
+}
+
+type GrammarCompilationRequest = (
+    String,
+    oneshot::Sender<Result<String, ValidationError>>,
+    Span,
+);
+
 type TokenizerRequest = (
     (String, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 1e539ce3..7e88889f 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
-outlines= { version = "^0.0.27", optional = true }
+outlines= { version = "^0.0.34", optional = true }
 
 [tool.poetry.extras]
 torch = ["torch"]
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index cd7efec8..b4ffb863 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -6,7 +6,7 @@ from typing import Dict, Union
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
+from outlines.fsm.json_schema import build_regex_from_schema
 from functools import lru_cache
 from typing import List, Optional, DefaultDict
 import time
@@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
     def _cached_compile_fsm(grammar_type, schema, tokenizer):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-            schema = build_regex_from_object(schema)
+            schema = build_regex_from_schema(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
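
For context, the router-side change boils down to a single call into outlines: build_regex_from_schema turns the JSON schema into a regex string, and that regex (not the schema) is what gets forwarded to the server as a Regex grammar. The snippet below is a standalone sketch of that call path, not part of the diff: it assumes pyo3 0.20.x with the interpreter embedded via prepare_freethreaded_python, a Python environment where outlines >= 0.0.34 is importable, and a purely illustrative schema.

// Standalone sketch (not part of this diff): the same outlines call that
// the router's compile_grammar() wraps, without the worker channel plumbing.
use pyo3::prelude::*;

fn schema_to_regex(schema: &str) -> PyResult<String> {
    Python::with_gil(|py| {
        // outlines.fsm.json_schema.build_regex_from_schema(schema) returns a regex string
        PyModule::import(py, "outlines.fsm.json_schema")?
            .getattr("build_regex_from_schema")?
            .call1((schema,))?
            .extract::<String>()
    })
}

fn main() -> PyResult<()> {
    // Embed the Python interpreter, as the grammar compilation worker does.
    pyo3::prepare_freethreaded_python();

    // Hypothetical schema, for illustration only.
    let schema = r#"{"type": "object", "properties": {"name": {"type": "string"}}}"#;
    let regex = schema_to_regex(schema)?;
    println!("compiled regex is {} characters long", regex.len());
    Ok(())
}

With this change the router forwards JSON-schema grammars as ProtoGrammarType::Regex, so for requests that pass through the router the schema-to-regex step no longer needs to run inside _cached_compile_fsm on the server; the server keeps its GRAMMAR_TYPE_JSON branch only for grammars that arrive as raw schemas.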