diff --git a/proto/generate.proto b/proto/generate.proto
index 5140fdaa..aae0e7a4 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -70,6 +70,8 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7b9f90fb..9822ea77 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -128,6 +128,7 @@ impl Client {
                         repetition_penalty: 1.2,
                         frequency_penalty: 0.1,
                         watermark: true,
+                        grammar: String::new(),
                     }),
                     stopping_parameters: Some(StoppingCriteriaParameters {
                         max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index e830a3c3..6f3d2023 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -45,6 +45,7 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index a9d783bb..ecad671f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -201,6 +201,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+    #[serde(default)]
+    pub grammar: String,
 }
 
 fn default_max_new_tokens() -> Option<u32> {
@@ -226,6 +228,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
+        grammar: String::new(),
     }
 }
 
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 3675e0f5..3e4aefa1 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -354,7 +354,7 @@ mod tests {
 
         let entry = Entry {
             request: ValidGenerateRequest {
-                inputs: "".to_string(),
+                inputs: String::new(),
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,
@@ -368,6 +368,7 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/server.rs b/router/src/server.rs
index 00b793e3..6e042c4e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -614,6 +614,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
+            grammar: String::new(),
         },
     };
 
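That is the router side of the change: `GenerateParameters` gains an optional `grammar` string (serde-defaulted to empty, which keeps the old behavior) and forwards it to the Python server through the new proto field 10. A minimal smoke test of the new parameter might look like the Python sketch below; the URL/port, prompt, and `max_new_tokens` are illustrative assumptions, and the regex is the IP-address pattern previously hardcoded in `Grammar.__init__`:

import requests

# The regex formerly hardcoded in Grammar.__init__, now supplied per request
# and carried to the shard as NextTokenChooserParameters.grammar (field 10).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

resp = requests.post(
    "http://localhost:3000/generate",  # assumes a local TGI router on port 3000
    json={
        "inputs": "The IP address of Google's public DNS resolver is",
        "parameters": {
            "max_new_tokens": 15,
            "grammar": IP_REGEX,  # empty string (the default) disables grammar
        },
    },
)
print(resp.json()["generated_text"])  # output constrained to match IP_REGEX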
diff --git a/router/src/validation.rs b/router/src/validation.rs
index e6874b11..a77995df 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -182,6 +182,7 @@ impl Validation {
             watermark,
             decoder_input_details,
             top_n_tokens,
+            grammar,
             ..
         } = request.parameters;
 
@@ -302,6 +303,7 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 3d160de7..82b2e35b 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -21,6 +21,7 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 from outlines.fsm.fsm import RegexFSM
+import time
 
 
 class NextTokenChooser:
     def __init__(
@@ -36,6 +37,7 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        grammar=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -66,9 +68,12 @@ class NextTokenChooser:
             self.static_warper = None
 
         sampling = do_sample or has_warpers
-        # TODO toggle grammar
-        # self.choice = Sampling(seed, device) if sampling else Greedy()
-        self.choice = Grammar(tokenizer, device)
+
+        # TODO: is grammar a subset of sampling? If so, we should merge them
+        if grammar:
+            self.choice = Grammar(tokenizer, device, grammar)
+        else:
+            self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -106,6 +111,7 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
+            grammar=pb.grammar,
         )
 
 
@@ -433,16 +439,24 @@ class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device):
-        # TODO: get regex on init not hardcoded
-        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-
+    def __init__(self, tokenizer, device, regex_str):
         # TODO: adapt tokenizer is expensive, we should do it only once
         # this is a temporary solution
+
+        # TODO: remove debug logs
+        # time this
+        start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
+
+        print(f"Adapt tokenizer: {time.time() - start_time}")
+        start_time = time.time()
+
+        # TODO: avoid recompiling the FSM every time?
         fsm = RegexFSM(regex_str, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
+        self.device = device
 
     def __call__(self, logits):
         # TODO: handle seq_id properly
@@ -452,7 +466,7 @@ class Grammar:
             return self.fsm_state[seq_id].eos_token_id
 
         allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
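For context on the Python side: outlines' `RegexFSM` compiles the regex into a token-level automaton, and each decode step masks the logits down to the tokens reachable from the current state. The sketch below isolates that loop, assuming the outlines 0.0.x FSM interface (`allowed_token_ids` is the call the diff already uses; `next_state` and the `-1` terminal state are assumptions from that same interface); `decode_with_regex` and `model_step` are purely illustrative names, not TGI code:

import math
from collections import defaultdict

import torch

def decode_with_regex(model_step, fsm, eos_token_id, max_tokens=32):
    """Greedy decoding under a compiled RegexFSM (sketch).

    model_step() is a stand-in callable returning next-token logits; fsm is
    built on an adapted tokenizer, as in Grammar.__init__ above.
    """
    fsm_state = defaultdict(int)  # state 0 = FSM start state, as in Grammar
    seq_id = 0
    tokens = []
    for _ in range(max_tokens):
        logits = model_step()
        # Same masking as Grammar.__call__: allowed tokens keep their
        # scores, everything else is pushed to -inf.
        allowed = fsm.allowed_token_ids(fsm_state[seq_id])
        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
        mask[allowed] = 0
        next_token = int(torch.argmax(logits + mask))
        tokens.append(next_token)
        # Advance the automaton; once it reaches the terminal state the
        # chooser can only emit EOS, which is what __call__ falls back to.
        fsm_state[seq_id] = fsm.next_state(fsm_state[seq_id], next_token)
        if fsm_state[seq_id] == -1 or next_token == eos_token_id:
            break
    return tokens

Per the debug timings above, FSM compilation and tokenizer adaptation dominate the per-request cost, which is what the TODOs about doing the adaptation once and avoiding recompilation are pointing at.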