Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00

feat: support JSON schemas
parent 0245506718
commit 066d3d4872
@@ -201,10 +201,30 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(default)]
+    #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")]
     pub grammar: String,
 }
 
+mod json_object_or_string_to_string {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
+            _ => Err(de::Error::custom("expected string or object for grammar")),
+        }
+    }
+}
+
 fn default_max_new_tokens() -> Option<u32> {
     Some(100)
 }
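With the new deserializer, the router's `grammar` field accepts either a plain string or an inline JSON object, which it re-serializes into a string. A minimal illustration of the two accepted request shapes, assuming the standard TGI `/generate` route (the payloads are examples, not taken from the commit):

import requests

# 1) grammar supplied as a string: Value::String(s) is passed through unchanged
requests.post(
    "http://localhost:8080/generate",
    json={"inputs": "name:", "parameters": {"grammar": "[A-Z][a-z]+"}},
)

# 2) grammar supplied as a JSON schema object: Value::Object(o) is
#    serialized back into a string before reaching the Python server
requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "name:",
        "parameters": {
            "grammar": {"type": "object", "properties": {"name": {"type": "string"}}}
        },
    },
)

Anything other than a string or an object (a number, an array) is rejected with the custom "expected string or object for grammar" error.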
@@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
+outlines="^0.0.27"
 
 [tool.poetry.extras]
 torch = ["torch"]
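The new `outlines` pin supplies both the FSM compiler and the JSON-schema-to-regex converter used by the server changes below. A quick import check (an assumed verification step, not part of the commit; these module paths match outlines 0.0.27 and moved in later releases):

# Assumes outlines 0.0.27 is installed in the server environment.
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object

print(RegexFSM.__name__, build_regex_from_object.__name__)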
@@ -21,6 +21,9 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+
+# TODO: remove when done debugging
 import time
 
 class NextTokenChooser:
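`build_regex_from_object` is the piece that turns a JSON schema into a single regular expression, which `RegexFSM` can then compile into a token-level automaton. A small sketch of that conversion (assuming outlines 0.0.27; the exact regex it emits is version-dependent):

from outlines.fsm.json_schema import build_regex_from_object

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

# JSON schema in, one (large) regex out; any string matching the regex
# is a JSON document valid under the schema.
regex_string = build_regex_from_object(schema)
print(regex_string[:80], "...")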
@@ -70,7 +73,7 @@ class NextTokenChooser:
         sampling = do_sample or has_warpers
 
         # TODO: is grammar a subset of sampling? If so, we should merge them
         if grammar:
             self.choice = Grammar(tokenizer, device, grammar)
         else:
             self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -434,26 +437,22 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
-# TODO: move this whole thing into the logit_process util and make it a Sampler
 class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, regex_str):
-        # TODO: adapt tokenizer is expensive, we should do it only once
-        # this is a temporary solution
-
+    def __init__(self, tokenizer, device, grammar):
         # TODO: remove debug logs
-        # time this
         start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
 
         print(f"Adapt tokenizer: {time.time() - start_time}")
         start_time = time.time()
-        # TODO: avoid recompiling the FSM every time?
-        fsm = RegexFSM(regex_str, tokenizer)
+        regex_string = build_regex_from_object(grammar)
+        print(f"Build regex: {time.time() - start_time}")
+        fsm = RegexFSM(regex_string, tokenizer)
         print(f"Compile FSM: {time.time() - start_time}")
 
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
         self.device = device
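For orientation, the compiled `RegexFSM` is consumed one decoding step at a time: the current FSM state determines which token ids are legal, all other logits are masked out, and the chosen token advances the state. A hypothetical sketch of that loop (`grammar_step` is an illustrative name; the method calls follow outlines 0.0.27's `RegexFSM` API, and this is not the commit's exact `__call__`):

import math
from collections import defaultdict

import torch

def grammar_step(fsm, fsm_state, logits, seq_id=0):
    # Only tokens the automaton allows in the current state survive masking.
    allowed = fsm.allowed_token_ids(fsm_state[seq_id])
    mask = torch.full_like(logits, -math.inf)
    mask[allowed] = 0
    next_token_id = (logits + mask).argmax(dim=-1)
    # Advance the automaton so the next step sees the new state.
    fsm_state[seq_id] = fsm.next_state(fsm_state[seq_id], int(next_token_id))
    return next_token_id

# Usage mirrors the class above: fsm_state = defaultdict(int), so every
# sequence starts in the FSM's initial state 0.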
@@ -504,7 +503,8 @@ class Grammar:
         tokenizer.convert_token_to_string = convert_token_to_string
 
         return tokenizer
 
+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.