feat: support grammar compilation worker via pyo3

drbh 2024-03-05 17:44:16 +00:00
parent 7dbaf9e901
commit ad5f562aa5
5 changed files with 277 additions and 12 deletions

Cargo.lock (generated)

@ -1401,6 +1401,15 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
"autocfg",
]
[[package]]
name = "metrics"
version = "0.21.1"
@ -2197,6 +2206,69 @@ dependencies = [
"prost 0.12.3", "prost 0.12.3",
] ]
[[package]]
name = "pyo3"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.51",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.51",
]
[[package]]
name = "quanta"
version = "0.11.1"
@ -2945,6 +3017,12 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "target-lexicon"
version = "0.12.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f"
[[package]]
name = "tempfile"
version = "3.10.1"
@ -3030,6 +3108,7 @@ dependencies = [
"nohash-hasher", "nohash-hasher",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"pyo3",
"rand", "rand",
"reqwest", "reqwest",
"serde", "serde",
@ -3583,6 +3662,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "untrusted"
version = "0.7.1"


@ -46,6 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"
pyo3 = "0.20.3"
[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
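
For orientation: pyo3 is the crate that lets the router embed a CPython interpreter and call into it from Rust. The following is a minimal sketch of that pattern, separate from this commit, assuming only the pyo3 0.20 dependency added above:

use pyo3::prelude::*;

fn main() -> PyResult<()> {
    // Start the embedded interpreter; needed because Rust, not Python, is the host process.
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        // Evaluate a small expression and pull the result back into a Rust value.
        let sum: i64 = py.eval("2 + 3", None, None)?.extract()?;
        println!("python evaluated 2 + 3 = {sum}");
        Ok(())
    })
}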


@ -1,3 +1,5 @@
use std::collections::BTreeMap;
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
@ -14,6 +16,9 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
@ -26,6 +31,7 @@ pub struct Validation {
disable_grammar_support: bool,
/// Channel to communicate with the background tokenization task
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
grammar_compilation_sender: Option<mpsc::UnboundedSender<GrammarCompilationRequest>>,
}
impl Validation {
@ -41,7 +47,7 @@ impl Validation {
disable_grammar_support: bool,
) -> Self {
// If we have a fast tokenizer
-let sender = if let Some(tokenizer) = tokenizer {
+let (sender, grammar_compilation_sender) = if let Some(tokenizer) = tokenizer {
// Create round robin channel
let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
let mut senders = Vec::with_capacity(workers);
@ -58,17 +64,29 @@ impl Validation {
});
}
// TODO: start more than one grammar compilation worker
let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
// create a single grammar compilation worker
tokio::task::spawn_blocking(move || {
grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
tracing::error!("Error in grammar compilation worker: {:?}", e);
e
})
});
// Create tokenization round robin task
tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
-Some(validation_sender)
+(Some(validation_sender), Some(grammar_sender))
} else {
-None
+(None, None)
};
Self {
max_best_of,
sender,
grammar_compilation_sender,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
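
The grammar worker above reuses the request/response shape of the existing tokenizer workers: an unbounded mpsc channel feeds a task on the blocking pool, and each request carries a oneshot sender for the reply. Here is a standalone sketch of that pattern with illustrative names (Request, worker), not the commit's actual types:

use tokio::sync::{mpsc, oneshot};

// Illustrative request type: a payload plus a oneshot channel for the reply.
type Request = (String, oneshot::Sender<String>);

fn worker(mut receiver: mpsc::UnboundedReceiver<Request>) {
    // blocking_recv is acceptable here because the worker runs via spawn_blocking.
    while let Some((input, reply)) = receiver.blocking_recv() {
        // Stand-in for the expensive, CPU-bound compilation step.
        let output = input.to_uppercase();
        // The requester may have been dropped; ignore send errors.
        let _ = reply.send(output);
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    // Run the worker on the blocking thread pool so it never stalls the async runtime.
    tokio::task::spawn_blocking(move || worker(rx));

    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send(("hello".to_string(), reply_tx)).unwrap();
    println!("worker replied: {}", reply_rx.await.unwrap());
}
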
@ -102,6 +120,27 @@ impl Validation {
}
}
#[instrument(skip(self, inputs))]
pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
// If we have a grammar compilation worker
if let Some(sender) = &self.grammar_compilation_sender {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send request to the background grammar compilation task
// Unwrap is safe here
sender
.send((inputs.clone(), response_sender, Span::current()))
.unwrap();
// Await on response channel
// Unwrap is safe here
let encoding = response_receiver.await.unwrap()?;
return Ok(encoding);
}
Ok(inputs)
}
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
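
compile_grammar also forwards Span::current() with each request so the worker can re-enter the request's tracing span via parent_span.in_scope, keeping its logs attributed to the originating request. A small sketch of that propagation, separate from the commit and assuming a tracing-subscriber dependency for output:

use tracing::{info, info_span, Span};

// Illustrative job type: a payload plus the span of the request that produced it.
struct Job {
    input: String,
    parent_span: Span,
}

fn handle(job: Job) {
    // Re-enter the request's span so anything logged here is attributed to it.
    job.parent_span.in_scope(|| {
        info!(input = %job.input, "compiling grammar");
    });
}

fn main() {
    tracing_subscriber::fmt().init();
    let request_span = info_span!("request", id = 42);
    let job = request_span.in_scope(|| Job {
        input: "{}".to_string(),
        parent_span: Span::current(),
    });
    handle(job);
}
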
@ -331,12 +370,12 @@ impl Validation {
.compile(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
-(
-// Serialize json to string
-serde_json::to_string(&json)
-.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-ProtoGrammarType::Json.into(),
-)
+// NOTE: this is the first step to compile the grammar
+let regex_compiled_grammar = self
+.compile_grammar(serde_json::to_string(&json).unwrap())
+.await
+.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+(regex_compiled_grammar, ProtoGrammarType::Regex.into())
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
@ -418,6 +457,32 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
}
}
/// Start grammar compilation worker
fn grammar_compilation_worker(
tokenizer: Tokenizer,
mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
) -> Result<(), PyErr> {
// initialize python runtime
pyo3::prepare_freethreaded_python();
// load in outlines for all workers
Python::with_gil(|py| {
PyModule::import(py, "outlines")?;
Ok::<_, PyErr>(())
})?;
// Loop over requests
while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
response_tx
.send(compile_grammar(inputs, &tokenizer))
.unwrap_or(())
})
}
Ok(())
}
/// Get input length and optionally truncate it
fn prepare_input(
mut inputs: String,
@ -442,6 +507,120 @@ fn prepare_input(
Ok((encoding, inputs))
}
/// Compile a grammar
fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
let start_time = std::time::Instant::now();
let schema = Python::with_gil(|py| -> PyResult<String> {
let fun: Py<PyAny> = PyModule::from_code(
py,
r#"
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
import time
from transformers.file_utils import SPIECE_UNDERLINE

class Tokenizer:
    def __init__(self, vocab, special_tokens):
        self.vocabulary = vocab
        self.special_tokens = special_tokens
        self.eos_token_id = 0

    def get_vocab(self, with_added_tokens):
        return self.vocabulary

    def encode(self, text, add_special_tokens):
        return text

    def decode(self, text, skip_special_tokens):
        return text

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)

def adapt_tokenizer(vocab, special_tokens):
    start_time = time.time()
    tokenizer = Tokenizer(vocab, special_tokens)

    def convert_token_to_string(token: str) -> str:
        string = tokenizer.convert_tokens_to_string([token])
        # A hack to handle missing spaces in HF's Llama tokenizers
        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
            return " " + string
        return string

    tokenizer.convert_token_to_string = convert_token_to_string
    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
    return tokenizer

def compile_regex_grammar(inputs, vocab):
    start_time = time.time()
    print("🔥 starting compile_regex_grammar", inputs)
    schema = build_regex_from_schema(inputs)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    tokenizer = adapt_tokenizer(vocab, special_tokens)
    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
    fsm = RegexFSM(schema, tokenizer)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    return fsm

def convert_grammar_to_regex(inputs):
    start_time = time.time()
    print("🔥 starting convert_grammar_to_regex", inputs)
    schema = build_regex_from_schema(inputs)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    return schema
"#,
"",
"",
)?
// .getattr("compile_regex_grammar")?
.getattr("convert_grammar_to_regex")?
.into();
let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
let special_tokens: Vec<String> = vec![];
let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
if false {
let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
let regex_fsm_ref = regex_fsm.into_ref(py);
let states_to_token_maps = regex_fsm_ref
.getattr("states_to_token_maps")?
.extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
println!("🔥 elapsed: {:?}", start_time.elapsed());
// size of serialized states_to_token_maps
let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
println!(
"🔥 states_to_token_maps size: {:.2}MB",
states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
);
}
let result = regex_fsm.into_ref(py).extract().unwrap();
println!("result: {:?}", result);
Ok(result)
})
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
let elapsed = start_time.elapsed();
println!("🔥 elapsed: {:?}", elapsed);
Ok(schema)
}
type GrammarCompilationRequest = (
String,
oneshot::Sender<Result<String, ValidationError>>,
Span,
);
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
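
The compile_grammar function above defines Python code with PyModule::from_code, calls into it, and extracts the result (and, in the disabled branch, pulls states_to_token_maps into a BTreeMap). Below is a self-contained sketch of that round-trip, with a toy Python helper standing in for outlines:

use pyo3::prelude::*;
use std::collections::BTreeMap;

fn main() -> PyResult<()> {
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        // Define a tiny module from source, mirroring how the worker loads its helper code.
        let module = PyModule::from_code(
            py,
            r#"
def transition_table():
    # toy stand-in for an FSM's state -> (token -> next state) map
    return {0: {7: 1}, 1: {9: 2}}
"#,
            "helper.py",
            "helper",
        )?;
        // Call the helper and convert the Python dict-of-dicts into nested Rust maps.
        let table: BTreeMap<u32, BTreeMap<u32, u32>> = module
            .getattr("transition_table")?
            .call0()?
            .extract()?;
        println!("states: {}", table.len());
        Ok(())
    })
}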


@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
-outlines= { version = "^0.0.27", optional = true }
+outlines= { version = "^0.0.34", optional = true }
[tool.poetry.extras]
torch = ["torch"]


@ -6,7 +6,7 @@ from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
+from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
def _cached_compile_fsm(grammar_type, schema, tokenizer):
    start_time = time.time()
    if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-        schema = build_regex_from_object(schema)
+        schema = build_regex_from_schema(schema)
    elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
        pass # schema is already a regex just here for clarity
    fsm = RegexFSM(schema, tokenizer)
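
The last two hunks track the outlines 0.0.34 rename of build_regex_from_object to build_regex_from_schema, which is also the function the new Rust worker reaches through pyo3. A sketch of that call from Rust, assuming outlines >= 0.0.34 is importable by the embedded interpreter:

use pyo3::prelude::*;

fn json_schema_to_regex(schema: &str) -> PyResult<String> {
    Python::with_gil(|py| {
        // outlines >= 0.0.34 exposes build_regex_from_schema (formerly build_regex_from_object).
        PyModule::import(py, "outlines.fsm.json_schema")?
            .getattr("build_regex_from_schema")?
            .call1((schema,))?
            .extract()
    })
}

fn main() -> PyResult<()> {
    pyo3::prepare_freethreaded_python();
    let regex = json_schema_to_regex(r#"{"type": "boolean"}"#)?;
    println!("regex: {regex}");
    Ok(())
}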