Mirror of https://github.com/huggingface/text-generation-inference.git
feat: support JSON schemas

commit 066d3d4872 (parent 0245506718)
@@ -201,10 +201,30 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(default)]
+    #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
     pub grammar: String,
 }
+
+mod json_object_or_string_to_string {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
+            _ => Err(de::Error::custom("expected string or object for grammar")),
+        }
+    }
+}
+
 fn default_max_new_tokens() -> Option<u32> {
     Some(100)
 }
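In practice the new deserializer means a client may send `grammar` either as a ready-made string or as an inline JSON object; both reach the Python server as a string. A minimal Python sketch of the same normalization (hypothetical helper for illustration only; the real logic is the Rust module above):

    import json

    def normalize_grammar(value):
        # Mirrors json_object_or_string_to_string: a string passes through,
        # a JSON object is re-serialized to a string, anything else is rejected.
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            return json.dumps(value)
        raise ValueError("expected string or object for grammar")

    normalize_grammar('{"type": "object"}')  # already a string -> unchanged
    normalize_grammar({"type": "object"})    # object -> '{"type": "object"}'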
@@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
+outlines="^0.0.27"
 
 [tool.poetry.extras]
 torch = ["torch"]
@@ -21,6 +21,9 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
 
+# TODO: remove when done debugging
+import time
 
 class NextTokenChooser:
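For orientation, here is roughly what the two outlines imports provide, as a hedged sketch against the 0.0.27 API pinned above (the example schema is an assumption; the real call sites are in the Grammar class below):

    from outlines.fsm.json_schema import build_regex_from_object

    # A serialized JSON schema, as the router delivers it, is turned into an
    # equivalent regular expression; RegexFSM then compiles that regex into a
    # token-level finite-state machine for constrained decoding.
    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    regex_string = build_regex_from_object(schema)
    # fsm = RegexFSM(regex_string, adapted_tokenizer)  # needs the adapted tokenizer below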
@@ -70,7 +73,7 @@ class NextTokenChooser:
         sampling = do_sample or has_warpers
 
         # TODO: is grammar a subset of sampling? If so, we should merge them
-        if grammar:
+        if grammar:
             self.choice = Grammar(tokenizer, device, grammar)
         else:
             self.choice = Sampling(seed, device) if sampling else Greedy()
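This dispatch works because all three strategies expose the same call shape. A minimal sketch of that shared interface (Greedy's body is copied from the hunk below; the rest is illustrative):

    import torch

    # Every choice (Greedy, Sampling, Grammar) is a callable mapping logits
    # to next-token ids, so NextTokenChooser can swap them interchangeably.
    class Greedy:
        def __call__(self, logits):
            return logits.argmax(dim=-1)

    choice = Greedy()  # NextTokenChooser picks Grammar(...) when a grammar is set
    next_ids = choice(torch.randn(1, 32000))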
@@ -434,26 +437,22 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
 
 # TODO: move this whole thing into the logit_process util and make it a Sampler
 class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, regex_str):
-        # TODO: adapt tokenizer is expensive, we should do it only once
-        # this is a temporary solution
+    def __init__(self, tokenizer, device, grammar):
+        # TODO: remove debug logs
+        # time this
         start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
         print(f"Adapt tokenizer: {time.time() - start_time}")
         start_time = time.time()
-        # TODO: avoid recompiling the FSM every time?
-        fsm = RegexFSM(regex_str, tokenizer)
+        regex_string = build_regex_from_object(grammar)
+        print(f"Build regex: {time.time() - start_time}")
+        fsm = RegexFSM(regex_string, tokenizer)
         print(f"Compile FSM: {time.time() - start_time}")
 
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
         self.device = device
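For context on what the compiled FSM is used for, here is a hedged sketch of one constrained-decoding step based on the outlines FSM interface (allowed_token_ids / next_state); the class's actual __call__ and per-batch state handling are not part of this hunk:

    import torch

    def constrained_greedy_step(fsm, state, logits):
        # Restrict a 1-D logits vector to the token ids the FSM permits from
        # `state`, pick greedily, then advance the FSM with the chosen token.
        mask = torch.full_like(logits, float("-inf"))
        mask[fsm.allowed_token_ids(state)] = 0
        next_id = int((logits + mask).argmax(dim=-1))
        return next_id, fsm.next_state(state, next_id)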
@@ -504,7 +503,8 @@ class Grammar:
         tokenizer.convert_token_to_string = convert_token_to_string
 
         return tokenizer
 
+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.