feat: support JSON schemas

This commit is contained in:
drbh 2024-02-08 11:24:34 -05:00
parent 0245506718
commit 066d3d4872
3 changed files with 34 additions and 13 deletions

View File

@ -201,10 +201,30 @@ pub(crate) struct GenerateParameters {
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
#[serde(default)] #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
pub grammar: String, pub grammar: String,
} }
mod json_object_or_string_to_string {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
_ => Err(de::Error::custom("expected string or object for grammar")),
}
}
}
fn default_max_new_tokens() -> Option<u32> { fn default_max_new_tokens() -> Option<u32> {
Some(100) Some(100)
} }

View File

@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true } torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1" scipy = "^1.11.1"
pillow = "^10.0.0" pillow = "^10.0.0"
outlines="^0.0.27"
[tool.poetry.extras] [tool.poetry.extras]
torch = ["torch"] torch = ["torch"]

View File

@ -21,6 +21,9 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
# TODO: remove when done debugging
import time import time
class NextTokenChooser: class NextTokenChooser:
@ -70,7 +73,7 @@ class NextTokenChooser:
sampling = do_sample or has_warpers sampling = do_sample or has_warpers
# TODO: is grammar a subset of sampling? If so, we should merge them # TODO: is grammar a subset of sampling? If so, we should merge them
if grammar: if grammar:
self.choice = Grammar(tokenizer, device, grammar) self.choice = Grammar(tokenizer, device, grammar)
else: else:
self.choice = Sampling(seed, device) if sampling else Greedy() self.choice = Sampling(seed, device) if sampling else Greedy()
@ -434,26 +437,22 @@ class Greedy:
def __call__(self, logits): def __call__(self, logits):
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
# TODO: move this whole thing into the logit_process util and make it a Sampler
class Grammar: class Grammar:
fsm_state: DefaultDict[int, int] fsm_state: DefaultDict[int, int]
fsm: RegexFSM fsm: RegexFSM
def __init__(self, tokenizer, device, regex_str): def __init__(self, tokenizer, device, grammar):
# TODO: adapt tokenizer is expensive, we should do it only once
# this is a temporary solution
# TODO: remove debug logs # TODO: remove debug logs
# time this
start_time = time.time() start_time = time.time()
tokenizer = self.adapt_tokenizer(tokenizer) tokenizer = self.adapt_tokenizer(tokenizer)
print(f"Adapt tokenizer: {time.time() - start_time}") print(f"Adapt tokenizer: {time.time() - start_time}")
start_time = time.time() start_time = time.time()
regex_string = build_regex_from_object(grammar)
# TODO: avoid recompiling the FSM every time? print(f"Build regex: {time.time() - start_time}")
fsm = RegexFSM(regex_str, tokenizer) fsm = RegexFSM(regex_string, tokenizer)
print(f"Compile FSM: {time.time() - start_time}") print(f"Compile FSM: {time.time() - start_time}")
self.fsm = fsm self.fsm = fsm
self.fsm_state = defaultdict(int) self.fsm_state = defaultdict(int)
self.device = device self.device = device
@ -504,7 +503,8 @@ class Grammar:
tokenizer.convert_token_to_string = convert_token_to_string tokenizer.convert_token_to_string = convert_token_to_string
return tokenizer return tokenizer
class HeterogeneousSampling: class HeterogeneousSampling:
r""" r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.