Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
Move JSON grammar -> regex grammar conversion to the router
This change moves the JSON grammar -> regex grammar conversion to the router by adding a dependency on the `outlines-core` Rust crate. In contrast to the Python implementation, the conversions are not LRU-cached, since they seem to be fast enough:

simple schema           time:   [5.8293 µs 5.8307 µs 5.8320 µs]
                        change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05)
                        Performance has improved.

complex schema          time:   [14.875 µs 14.881 µs 14.887 µs]
                        change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05)
                        Performance has improved.

Using the schemas from: https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
This commit is contained in:
parent 780531ec77
commit 7e87e868e6
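A minimal sketch of the conversion that now happens in the router, assuming the `to_regex` signature used in validation.rs below (schema value, optional whitespace-pattern override, and the schema again for reference resolution) and an `anyhow::Result` return type; the schema here is a made-up example, not one from the benchmarks:

use outlines_core::json_schema::to_regex;
use serde_json::json;

fn main() -> anyhow::Result<()> {
    // A toy schema; in the router, the input comes from the request's grammar field.
    let schema = json!({
        "type": "object",
        "properties": { "name": { "type": "string" } },
        "required": ["name"]
    });
    // `None` keeps outlines-core's default whitespace pattern.
    let regex = to_regex(&schema, None, &schema)?;
    println!("{regex}");
    Ok(())
}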
Cargo.lock (generated): 24 changed lines
@@ -3005,6 +3005,17 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "outlines-core"
+version = "0.1.0"
+source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c"
+dependencies = [
+ "anyhow",
+ "regex",
+ "serde-pyobject",
+ "serde_json",
+]
+
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -3952,6 +3963,16 @@ dependencies = [
  "serde_derive",
 ]
 
+[[package]]
+name = "serde-pyobject"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
+dependencies = [
+ "pyo3",
+ "serde",
+]
+
 [[package]]
 name = "serde_cbor"
 version = "0.11.2"
@@ -3979,6 +4000,7 @@ version = "1.0.133"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
 dependencies = [
+ "indexmap 2.6.0",
  "itoa",
  "memchr",
  "ryu",
@@ -4430,6 +4452,7 @@ dependencies = [
 name = "text-generation-router"
 version = "2.4.2-dev0"
 dependencies = [
+ "anyhow",
  "async-stream",
  "async-trait",
  "axum 0.7.9",
@@ -4453,6 +4476,7 @@ dependencies = [
  "once_cell",
  "opentelemetry 0.20.0",
  "opentelemetry-otlp",
+ "outlines-core",
  "pyo3",
  "rand",
  "regex",
crate-hashes.json (new file): 3 changed lines
@@ -0,0 +1,3 @@
+{
+  "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm"
+}
router/Cargo.toml
@@ -8,6 +8,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+anyhow = "1"
 async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
@@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"
+outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
 rand = "0.8.5"
 reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
router/src/validation.rs
@@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
+use outlines_core::json_schema::to_regex as json_schema_to_regex;
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
@@ -351,11 +352,13 @@ impl Validation {
                     "Grammar must have a 'properties' field".to_string(),
                 ))?;
 
-                // Serialize json to string
-                ValidGrammar::Json(
-                    serde_json::to_string(&json)
-                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                )
+                // Do compilation in the router for performance. In the future, we
+                // should also move regex -> automaton compilation in the router,
+                // but this is not yet supported in pure Rust by outlines-core.
+                let grammar_regex = json_schema_to_regex(&json, None, &json)
+                    .map_err(ValidationError::RegexFromSchema)?;
+
+                ValidGrammar::Regex(grammar_regex.to_string())
             }
             GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
         };
@@ -810,6 +813,8 @@ pub enum ValidationError {
     Grammar,
     #[error("grammar is not valid: {0}")]
     InvalidGrammar(String),
+    #[error("cannot compile regex from schema: {0}")]
+    RegexFromSchema(anyhow::Error),
     #[error("base64 encoding is invalid: {0}")]
     InvalidBase64(#[from] base64::DecodeError),
     #[error("invalid image: {0}")]
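A sketch of how the new error variant surfaces, assuming the thiserror derive used by `ValidationError` (trimmed to one variant here) and the same `to_regex` signature as above; `compile_grammar` is a hypothetical helper, not a function from the codebase:

use serde_json::Value;

#[derive(Debug, thiserror::Error)]
enum ValidationError {
    #[error("cannot compile regex from schema: {0}")]
    RegexFromSchema(anyhow::Error),
}

// Hypothetical helper mirroring the match arm above: an unsupported schema
// now fails request validation in the router instead of reaching the Python
// server and being silently downgraded to an unconstrained grammar.
fn compile_grammar(schema: &Value) -> Result<String, ValidationError> {
    outlines_core::json_schema::to_regex(schema, None, schema)
        .map_err(ValidationError::RegexFromSchema)
}

fn main() {
    let schema = serde_json::json!({ "type": "string" });
    match compile_grammar(&schema) {
        Ok(regex) => println!("compiled: {regex}"),
        Err(err) => eprintln!("{err}"),
    }
}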
server/text_generation_server/utils/logits_process.py
@@ -1,19 +1,19 @@
+from functools import lru_cache
 import math
+import time
 import torch
+from typing import List, Optional, DefaultDict
+
 from loguru import logger
 from typing import Dict, Union
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
-from functools import lru_cache
-from typing import List, Optional, DefaultDict
-import time
 
 from transformers import (
     LogitsWarper,
     LogitsProcessor,
     PreTrainedTokenizerBase,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexGuide
 
-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(
+        self,
+        tokenizer: Optional[PreTrainedTokenizerBase],
+        device: str,
+        grammar: str,
+        grammar_type: GrammarType,
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
@@ -519,18 +525,20 @@ class GrammarLogitProcessor(LogitsProcessor):
     # TODO: move grammar compilation into the router
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(grammar_type, schema, tokenizer):
+    def _cached_compile_fsm(
+        grammar_type: GrammarType,
+        schema: str,
+        tokenizer: Optional[PreTrainedTokenizerBase],
+    ):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-            try:
-                schema = build_regex_from_schema(schema)
-                # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
-            except Exception as e:
-                logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
-                # allows everything
-                schema = "(.*?)"
+            # JSON schema is compiled by the v3 router.
+            logger.error(
+                "Non-regex grammars must be compiled by the router, grammar won't be enforced"
+            )
+            # allows everything
+            schema = "(.*?)"
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
 
         fsm = RegexGuide.from_regex(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
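For reference, a criterion benchmark along the lines of the numbers quoted in the commit message could look like the sketch below; the schema is a simplified stand-in, the real ones live in outlines-core's bench_json_schema.py linked above:

use criterion::{criterion_group, criterion_main, Criterion};
use outlines_core::json_schema::to_regex;
use serde_json::json;

fn bench_schemas(c: &mut Criterion) {
    let schema = json!({
        "type": "object",
        "properties": { "name": { "type": "string" } }
    });
    c.bench_function("simple schema", |b| {
        // Measures the schema -> regex conversion only; regex -> automaton
        // compilation still happens elsewhere, as the commit notes.
        b.iter(|| to_regex(&schema, None, &schema).unwrap())
    });
}

criterion_group!(benches, bench_schemas);
criterion_main!(benches);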