diff --git a/router/src/lib.rs b/router/src/lib.rs index ecad671f..03b7ce88 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -201,10 +201,30 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] pub top_n_tokens: Option, - #[serde(default)] + #[serde(default, deserialize_with = "json_object_or_string_to_string::deserialize")] pub grammar: String, } +mod json_object_or_string_to_string { + use super::*; + use serde::de; + use serde::Deserializer; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) => Ok(s), + Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()), + _ => Err(de::Error::custom("expected string or object for grammar")), + } + } +} + fn default_max_new_tokens() -> Option { Some(100) } diff --git a/server/pyproject.toml b/server/pyproject.toml index b8ebf2e3..566eda7a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true } torch = { version = "^2.1.1", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" +outlines="^0.0.27" [tool.poetry.extras] torch = ["torch"] diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 82b2e35b..802b5f88 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -21,6 +21,9 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_object + +# TODO: remove when done debugging import time class NextTokenChooser: @@ -70,7 +73,7 @@ class NextTokenChooser: sampling = do_sample or has_warpers # TODO: is grammar a subset of sampling? If so, we should merge them - if grammar: + if grammar: self.choice = Grammar(tokenizer, device, grammar) else: self.choice = Sampling(seed, device) if sampling else Greedy() @@ -434,26 +437,22 @@ class Greedy: def __call__(self, logits): return logits.argmax(dim=-1) -# TODO: move this whole thing into the logit_process util and make it a Sampler + class Grammar: fsm_state: DefaultDict[int, int] fsm: RegexFSM - def __init__(self, tokenizer, device, regex_str): - # TODO: adapt tokenizer is expensive, we should do it only once - # this is a temporary solution - + def __init__(self, tokenizer, device, grammar): # TODO: remove debug logs - # time this start_time = time.time() tokenizer = self.adapt_tokenizer(tokenizer) - print(f"Adapt tokenizer: {time.time() - start_time}") start_time = time.time() - - # TODO: avoid recompiling the FSM every time? - fsm = RegexFSM(regex_str, tokenizer) + regex_string = build_regex_from_object(grammar) + print(f"Build regex: {time.time() - start_time}") + fsm = RegexFSM(regex_string, tokenizer) print(f"Compile FSM: {time.time() - start_time}") + self.fsm = fsm self.fsm_state = defaultdict(int) self.device = device @@ -504,7 +503,8 @@ class Grammar: tokenizer.convert_token_to_string = convert_token_to_string return tokenizer - + + class HeterogeneousSampling: r""" Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.