Move JSON grammar -> regex grammar conversion to the router

This change moves the JSON grammar -> regex grammar conversion to the
router by adding a dependency on the `outlines-core` Rust crate. In
contrast to the Python implementation, the conversions are not LRU-cached
since they seem to be fast enough:

simple schema           time:   [5.8293 µs 5.8307 µs 5.8320 µs]
                        change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05)
                        Performance has improved.

complex schema          time:   [14.875 µs 14.881 µs 14.887 µs]
                        change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05)
                        Performance has improved.

Using the schemas from:
https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
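
For reference, here is a minimal sketch of the conversion as the router now
performs it. The `outlines_core::json_schema::to_regex` path and the
`to_regex(&schema, None, &schema)` call shape (schema value, optional
whitespace pattern, root document for `$ref` resolution) mirror the
validation.rs change below; the example schema and the standalone `main`
are illustrative and not part of this commit:

    use outlines_core::json_schema::to_regex as json_schema_to_regex;
    use serde_json::json;

    fn main() -> anyhow::Result<()> {
        // Hypothetical schema, for illustration only.
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name"]
        });

        // None = no custom whitespace pattern; the schema doubles as the
        // root document for resolving `$ref`s, matching the router call.
        let regex = json_schema_to_regex(&schema, None, &schema)?;
        println!("{}", regex.to_string());
        Ok(())
    }

On the benchmark schemas above this runs in microseconds, which is why the
router does not add an LRU cache around the conversion.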
Author: Daniël de Kok
Date:   2024-11-22 12:21:32 +00:00
Parent: 780531ec77
Commit: 7e87e868e6

5 changed files with 62 additions and 20 deletions

Cargo.lock (generated)

@@ -3005,6 +3005,17 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "outlines-core"
+version = "0.1.0"
+source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c"
+dependencies = [
+ "anyhow",
+ "regex",
+ "serde-pyobject",
+ "serde_json",
+]
+
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -3952,6 +3963,16 @@ dependencies = [
  "serde_derive",
 ]
 
+[[package]]
+name = "serde-pyobject"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
+dependencies = [
+ "pyo3",
+ "serde",
+]
+
 [[package]]
 name = "serde_cbor"
 version = "0.11.2"
@@ -3979,6 +4000,7 @@ version = "1.0.133"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
 dependencies = [
+ "indexmap 2.6.0",
  "itoa",
  "memchr",
  "ryu",
@@ -4430,6 +4452,7 @@ dependencies = [
 name = "text-generation-router"
 version = "2.4.2-dev0"
 dependencies = [
+ "anyhow",
  "async-stream",
  "async-trait",
  "axum 0.7.9",
@@ -4453,6 +4476,7 @@ dependencies = [
  "once_cell",
  "opentelemetry 0.20.0",
  "opentelemetry-otlp",
+ "outlines-core",
  "pyo3",
  "rand",
  "regex",

crate-hashes.json (new file)

@@ -0,0 +1,3 @@
+{
+  "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm"
+}

router/Cargo.toml

@@ -8,6 +8,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+anyhow = "1"
 async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
@@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"
+outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
 rand = "0.8.5"
 reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"

router/src/validation.rs

@@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
+use outlines_core::json_schema::to_regex as json_schema_to_regex;
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
@@ -351,11 +352,13 @@ impl Validation {
                     "Grammar must have a 'properties' field".to_string(),
                 ))?;
 
-                // Serialize json to string
-                ValidGrammar::Json(
-                    serde_json::to_string(&json)
-                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                )
+                // Do compilation in the router for performance. In the future, we
+                // should also move regex -> automaton compilation in the router,
+                // but this is not yet supported in pure Rust by outlines-core.
+                let grammar_regex = json_schema_to_regex(&json, None, &json)
+                    .map_err(ValidationError::RegexFromSchema)?;
+
+                ValidGrammar::Regex(grammar_regex.to_string())
             }
             GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
         };
@@ -810,6 +813,8 @@ pub enum ValidationError {
     Grammar,
     #[error("grammar is not valid: {0}")]
     InvalidGrammar(String),
+    #[error("cannot compile regex from schema: {0}")]
+    RegexFromSchema(anyhow::Error),
     #[error("base64 encoding is invalid: {0}")]
     InvalidBase64(#[from] base64::DecodeError),
     #[error("invalid image: {0}")]

server/text_generation_server/utils/logits_process.py

@@ -1,19 +1,19 @@
+from functools import lru_cache
 import math
+import time
 
 import torch
+from typing import List, Optional, DefaultDict
 from loguru import logger
 from typing import Dict, Union
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
-from functools import lru_cache
-from typing import List, Optional, DefaultDict
-import time
 
 from transformers import (
     LogitsWarper,
     LogitsProcessor,
+    PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexGuide
 
-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(
+        self,
+        tokenizer: Optional[PreTrainedTokenizerBase],
+        device: str,
+        grammar: str,
+        grammar_type: GrammarType,
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
@@ -519,18 +525,20 @@ class GrammarLogitProcessor(LogitsProcessor):
     # TODO: move grammar compilation into the router
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(grammar_type, schema, tokenizer):
+    def _cached_compile_fsm(
+        grammar_type: GrammarType,
+        schema: str,
+        tokenizer: Optional[PreTrainedTokenizerBase],
+    ):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-            try:
-                schema = build_regex_from_schema(schema)
-                # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
-            except Exception as e:
-                logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
-                # allows everything
-                schema = "(.*?)"
-        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
-            pass  # schema is already a regex just here for clarity
+            # JSON schema is compiled by the v3 router.
+            logger.error(
+                "Non-regex grammars must be compiled by the router, grammar won't be enforced"
+            )
+            # allows everything
+            schema = "(.*?)"
+
         fsm = RegexGuide.from_regex(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm