Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
feat: support grammar compilation worker via pyo3
This commit is contained in:
parent
7dbaf9e901
commit
ad5f562aa5
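
The change wires a dedicated blocking worker, which owns the embedded Python interpreter, into the async router through channels: callers push a JSON-schema string plus a oneshot reply channel onto an unbounded mpsc queue, and the worker compiles the schema to a regex with outlines under the GIL and sends the result back. Below is a minimal, self-contained sketch of that pattern, not the router code itself; the names (compile_worker, CompileRequest) and the example schema are illustrative, and it assumes tokio, pyo3 0.20, and an importable outlines package.

use pyo3::prelude::*;
use tokio::sync::{mpsc, oneshot};

// One request = the JSON schema to compile plus a channel to send the regex back on.
type CompileRequest = (String, oneshot::Sender<Result<String, String>>);

// Runs on the blocking thread pool and owns all Python work.
fn compile_worker(mut rx: mpsc::UnboundedReceiver<CompileRequest>) -> PyResult<()> {
    // The embedded interpreter must be initialized before taking the GIL.
    pyo3::prepare_freethreaded_python();
    while let Some((schema, reply)) = rx.blocking_recv() {
        let result = Python::with_gil(|py| -> PyResult<String> {
            // JSON schema -> regex, using the same outlines helper the diff calls.
            let module = PyModule::import(py, "outlines.fsm.json_schema")?;
            let regex: String = module
                .getattr("build_regex_from_schema")?
                .call1((schema,))?
                .extract()?;
            Ok(regex)
        })
        .map_err(|e| e.to_string());
        // The caller may have gone away; ignore a closed reply channel.
        let _ = reply.send(result);
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    // Spawn the Python-bound worker off the async runtime.
    tokio::task::spawn_blocking(move || compile_worker(rx));

    let schema = r#"{"type": "object", "properties": {"name": {"type": "string"}}}"#.to_string();
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send((schema, reply_tx)).unwrap();

    match reply_rx.await.unwrap() {
        Ok(regex) => println!("compiled regex: {regex}"),
        Err(e) => eprintln!("grammar compilation failed: {e}"),
    }
}
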
Cargo.lock (generated): 85 changes
@@ -1401,6 +1401,15 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"

[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
 "autocfg",
]

[[package]]
name = "metrics"
version = "0.21.1"
@@ -2197,6 +2206,69 @@ dependencies = [
 "prost 0.12.3",
]

[[package]]
name = "pyo3"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
dependencies = [
 "cfg-if",
 "indoc",
 "libc",
 "memoffset",
 "parking_lot",
 "portable-atomic",
 "pyo3-build-config",
 "pyo3-ffi",
 "pyo3-macros",
 "unindent",
]

[[package]]
name = "pyo3-build-config"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
dependencies = [
 "once_cell",
 "target-lexicon",
]

[[package]]
name = "pyo3-ffi"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
dependencies = [
 "libc",
 "pyo3-build-config",
]

[[package]]
name = "pyo3-macros"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
dependencies = [
 "proc-macro2",
 "pyo3-macros-backend",
 "quote",
 "syn 2.0.51",
]

[[package]]
name = "pyo3-macros-backend"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
dependencies = [
 "heck",
 "proc-macro2",
 "pyo3-build-config",
 "quote",
 "syn 2.0.51",
]

[[package]]
name = "quanta"
version = "0.11.1"
@@ -2945,6 +3017,12 @@ dependencies = [
 "syn 1.0.109",
]

[[package]]
name = "target-lexicon"
version = "0.12.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f"

[[package]]
name = "tempfile"
version = "3.10.1"
@@ -3030,6 +3108,7 @@ dependencies = [
 "nohash-hasher",
 "opentelemetry",
 "opentelemetry-otlp",
 "pyo3",
 "rand",
 "reqwest",
 "serde",
@@ -3583,6 +3662,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"

[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"

[[package]]
name = "untrusted"
version = "0.7.1"
@@ -46,6 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"
pyo3 = "0.20.3"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
@@ -1,3 +1,5 @@
use std::collections::BTreeMap;

/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
@@ -14,6 +16,9 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};

use pyo3::prelude::*;
use pyo3::types::IntoPyDict;

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
@@ -26,6 +31,7 @@ pub struct Validation {
    disable_grammar_support: bool,
    /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
    grammar_compilation_sender: Option<mpsc::UnboundedSender<GrammarCompilationRequest>>,
}

impl Validation {
@@ -41,7 +47,7 @@ impl Validation {
        disable_grammar_support: bool,
    ) -> Self {
        // If we have a fast tokenizer
        let sender = if let Some(tokenizer) = tokenizer {
        let (sender, grammar_compilation_sender) = if let Some(tokenizer) = tokenizer {
            // Create round robin channel
            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
            let mut senders = Vec::with_capacity(workers);
@@ -58,17 +64,29 @@ impl Validation {
                });
            }

            // TODO: start more than one grammar compilation worker
            let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();

            // create single grammar compilation workers
            tokio::task::spawn_blocking(move || {
                grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
                    tracing::error!("Error in grammar compilation worker: {:?}", e);
                    e
                })
            });

            // Create tokenization round robin task
            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

            Some(validation_sender)
            (Some(validation_sender), Some(grammar_sender))
        } else {
            None
            (None, None)
        };

        Self {
            max_best_of,
            sender,
            grammar_compilation_sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
@@ -102,6 +120,27 @@ impl Validation {
        }
    }

    #[instrument(skip(self, inputs))]
    pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.grammar_compilation_sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here
            sender
                .send((inputs.clone(), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here
            let encoding = response_receiver.await.unwrap()?;
            return Ok(encoding);
        }

        Ok(inputs)
    }

    #[instrument(skip(self, inputs))]
    async fn validate_input(
        &self,
@@ -331,12 +370,12 @@ impl Validation {
                    .compile(&json)
                    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                (
                    // Serialize json to string
                    serde_json::to_string(&json)
                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
                    ProtoGrammarType::Json.into(),
                )
                // NOTE: this is the first step to compile the grammar
                let regex_compiled_grammar = self
                    .compile_grammar(serde_json::to_string(&json).unwrap())
                    .await
                    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
                (regex_compiled_grammar, ProtoGrammarType::Regex.into())
            }
            GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
        }
@@ -418,6 +457,32 @@ fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<
    }
}

/// Start grammar compilation workers
fn grammar_compilation_worker(
    tokenizer: Tokenizer,
    mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
) -> Result<(), PyErr> {
    // initialize python runtime
    pyo3::prepare_freethreaded_python();

    // load in outlines for all workers
    Python::with_gil(|py| {
        PyModule::import(py, "outlines")?;
        Ok::<_, PyErr>(())
    })?;

    // Loop over requests
    while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            response_tx
                .send(compile_grammar(inputs, &tokenizer))
                .unwrap_or(())
        })
    }

    Ok(())
}

/// Get input length and optionally truncate it
fn prepare_input(
    mut inputs: String,
@@ -442,6 +507,120 @@ fn prepare_input(
    Ok((encoding, inputs))
}

/// Compile a grammar
fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
    let start_time = std::time::Instant::now();
    let schema = Python::with_gil(|py| -> PyResult<String> {
        let fun: Py<PyAny> = PyModule::from_code(
            py,
            r#"
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
import time
from transformers.file_utils import SPIECE_UNDERLINE

class Tokenizer:
    def __init__(self, vocab, special_tokens):
        self.vocabulary = vocab
        self.special_tokens = special_tokens
        self.eos_token_id = 0

    def get_vocab(self, with_added_tokens):
        return self.vocabulary

    def encode(self, text, add_special_tokens):
        return text

    def decode(self, text, skip_special_tokens):
        return text

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)

def adapt_tokenizer(vocab, special_tokens):
    start_time = time.time()
    tokenizer = Tokenizer(vocab, special_tokens)

    def convert_token_to_string(token: str) -> str:

        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
            return " " + string

        return string

    tokenizer.convert_token_to_string = convert_token_to_string
    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
    return tokenizer

def compile_regex_grammar(inputs, vocab):
    start_time = time.time()
    print("🔥 starting compile_regex_grammar", inputs)
    schema = build_regex_from_schema(inputs)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    tokenizer = adapt_tokenizer(vocab, special_tokens)
    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
    fsm = RegexFSM(schema, tokenizer)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    return fsm

def convert_grammar_to_regex(inputs):
    start_time = time.time()
    print("🔥 starting convert_grammar_to_regex", inputs)
    schema = build_regex_from_schema(inputs)
    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
    return schema
"#,
            "",
            "",
        )?
        // .getattr("compile_regex_grammar")?
        .getattr("convert_grammar_to_regex")?
        .into();

        let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
        let special_tokens: Vec<String> = vec![];

        let regex_fsm = fun.call(py, (inputs.clone(),), None)?;

        if false {
            let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
            let regex_fsm_ref = regex_fsm.into_ref(py);

            let states_to_token_maps = regex_fsm_ref
                .getattr("states_to_token_maps")?
                .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;

            println!("🔥 elapsed: {:?}", start_time.elapsed());

            // size of serialized states_to_token_maps
            let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
            println!(
                "🔥 states_to_token_maps size: {:.2}MB",
                states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
            );
        }

        let result = regex_fsm.into_ref(py).extract().unwrap();

        println!("result: {:?}", result);

        Ok(result)
    })
    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
    let elapsed = start_time.elapsed();
    println!("🔥 elapsed: {:?}", elapsed);
    Ok(schema)
}

type GrammarCompilationRequest = (
    String,
    oneshot::Sender<Result<String, ValidationError>>,
    Span,
);

type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
@@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines= { version = "^0.0.27", optional = true }
outlines= { version = "^0.0.34", optional = true }

[tool.poetry.extras]
torch = ["torch"]
@@ -6,7 +6,7 @@ from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
@@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
    def _cached_compile_fsm(grammar_type, schema, tokenizer):
        start_time = time.time()
        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
            schema = build_regex_from_object(schema)
            schema = build_regex_from_schema(schema)
        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
            pass # schema is already a regex just here for clarity
        fsm = RegexFSM(schema, tokenizer)