From d849641b28ca3b828902ebebd88d83db975736f5 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 14 Feb 2024 00:03:56 +0000
Subject: [PATCH] feat: include grammar type in request, avoid alloc and
 improve proto types

---
 proto/generate.proto                          |  8 +++
 router/client/src/client.rs                   |  1 +
 router/client/src/lib.rs                      |  4 +-
 router/src/health.rs                          |  2 +
 router/src/lib.rs                             | 63 +++++++++++--------
 router/src/queue.rs                           |  5 +-
 router/src/server.rs                          |  2 +-
 router/src/validation.rs                      | 32 ++++++++--
 .../utils/logits_process.py                   | 40 ++++++------
 server/text_generation_server/utils/tokens.py | 22 +++++--
 10 files changed, 118 insertions(+), 61 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index aae0e7a4..0490029f 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -51,6 +51,12 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
+enum GrammarType {
+  GRAMMAR_TYPE_NONE = 0;
+  GRAMMAR_TYPE_JSON = 1;
+  GRAMMAR_TYPE_REGEX = 2;
+}
+
 message NextTokenChooserParameters {
   /// exponential scaling output probability distribution
   float temperature = 1;
@@ -72,6 +78,8 @@ message NextTokenChooserParameters {
   bool watermark = 8;
   /// grammar (applied if not empty)
   string grammar = 10;
+  /// grammar type
+  GrammarType grammar_type = 11;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 9822ea77..f8658318 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -129,6 +129,7 @@ impl Client {
                     frequency_penalty: 0.1,
                     watermark: true,
                     grammar: String::new(),
+                    grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index c38b931b..6782d9ff 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    Request, StoppingCriteriaParameters, Tokens,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/health.rs b/router/src/health.rs
index 6f3d2023..b05b3094 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,5 +1,6 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
+use text_generation_client::GrammarType as ProtoGrammarType;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
@@ -46,6 +47,7 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
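The explicit `GRAMMAR_TYPE_NONE = 0` variant matters: proto3 treats the zero value as the field default, so a request that never sets `grammar_type` decodes as "no grammar". A minimal sketch of how the new field surfaces in the generated Python stubs (the enum and message names follow the proto above; `generate_pb2` is the import path the server code below already uses):

```python
from text_generation_server.pb import generate_pb2

# proto3 zero-value default: an unset grammar_type decodes as GRAMMAR_TYPE_NONE
params = generate_pb2.NextTokenChooserParameters(grammar="")
assert params.grammar_type == generate_pb2.GrammarType.GRAMMAR_TYPE_NONE

# a JSON grammar travels as the schema string plus an explicit enum tag
params = generate_pb2.NextTokenChooserParameters(
    grammar='{"type": "object"}',
    grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON,
)
```
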
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b13d84ee..d3602e24 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -44,6 +44,40 @@ impl HubTokenizerConfig {
         serde_json::from_str(&content).unwrap_or_default()
     }
 }
+mod json_object_or_string_to_string {
+    // This custom deserializer is used to handle the fact that the grammar field can be either a
+    // string or an object. In both cases we handle it as a string, but also provide this convenience
+    // to the user to be flexible with the input.
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
+            _ => Err(de::Error::custom("expected string or object for grammar")),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+#[serde(tag = "type", content = "value")]
+pub(crate) enum GrammarType {
+    #[serde(
+        rename = "json",
+        deserialize_with = "json_object_or_string_to_string::deserialize"
+    )]
+    Json(String),
+    #[serde(rename = "regex")]
+    Regex(String),
+}
 
 mod token_serde {
     use super::*;
@@ -201,31 +235,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(
-        default,
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    pub grammar: String,
-}
-
-mod json_object_or_string_to_string {
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
-
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
-        }
-    }
+    #[serde(default)]
+    pub grammar: Option<GrammarType>,
 }
 
 fn default_max_new_tokens() -> Option<u32> {
@@ -251,7 +262,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
-        grammar: String::new(),
+        grammar: None,
     }
 }
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 3e4aefa1..52ea16ca 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -343,7 +343,9 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+    use text_generation_client::{
+        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    };
     use tracing::info_span;
 
     fn default_entry() -> (
@@ -369,6 +371,7 @@ fn default_entry() -> (
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/server.rs b/router/src/server.rs
index bcf17f46..0fc76916 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -614,7 +614,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar: String::new(),
+            grammar: None,
         },
     };
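On the HTTP side, `grammar` is now an optional internally tagged object (`type` plus `value`) rather than a bare string, and the deserializer above lets a JSON grammar's `value` arrive either as an embedded schema object or as a pre-serialized string. An illustrative client call (the parameter names come from the router code above; the URL, schema, and use of `requests` are made-up examples):

```python
import requests  # any HTTP client works; requests is just for illustration

# "value" as an object: json_object_or_string_to_string re-serializes it
json_grammar = {"type": "json", "value": {"properties": {"name": {"type": "string"}}}}
# equivalent form with "value" as a pre-serialized string
json_grammar_str = {"type": "json", "value": '{"properties": {"name": {"type": "string"}}}'}
# regex grammars take a plain pattern string
regex_grammar = {"type": "regex", "value": r"\d{4}-\d{2}-\d{2}"}

for grammar in (json_grammar, json_grammar_str, regex_grammar):
    requests.post(
        "http://localhost:3000/generate",  # assumed local TGI router
        json={"inputs": "Name a person:", "parameters": {"grammar": grammar}},
    )
```
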
diff --git a/router/src/validation.rs b/router/src/validation.rs
index f2cb6efd..dc79c61a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,8 +1,10 @@
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest};
+use crate::{GenerateParameters, GenerateRequest, GrammarType};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
@@ -296,10 +298,27 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
-        // Ensure that grammar is not set if it's not supported
-        if !grammar.is_empty() && !self.grammar_support {
-            return Err(ValidationError::Grammar);
-        }
+        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
+        // NOTE: this is currently difficult because we need the tokenizer in Python to build
+        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
+        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
+        // compiler and use that to build the FSM here.
+
+        // Validate grammar and unpack the grammar and type for the proto message
+        let (grammar, grammar_type) = match grammar {
+            Some(grammar) => {
+                // Ensure that grammar is not set if it's not supported
+                if !self.grammar_support {
+                    return Err(ValidationError::Grammar);
+                }
+                match grammar {
+                    // currently both are handled the same way since compilation is done in Python
+                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                }
+            }
+            None => (String::new(), ProtoGrammarType::None.into()),
+        };
 
         let parameters = NextTokenChooserParameters {
             temperature,
@@ -312,6 +331,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
+            grammar_type,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
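The router now rejects unsupported grammars before anything reaches the model server, then flattens `Option<GrammarType>` into the flat `(string, enum)` pair the proto message carries. A rough Python analogue of that `match`, for illustration only (the real logic is the Rust code above; `unpack_grammar` is a made-up helper):

```python
from text_generation_server.pb.generate_pb2 import GrammarType

def unpack_grammar(grammar, grammar_support=True):
    """Mirror the router's (grammar, grammar_type) unpacking."""
    if grammar is None:
        return "", GrammarType.GRAMMAR_TYPE_NONE
    if not grammar_support:
        raise ValueError("grammar is not supported")
    kind, value = grammar  # e.g. ("json", '{"properties": ...}')
    # unknown kinds raise KeyError, like a non-exhaustive match would fail
    return value, {
        "json": GrammarType.GRAMMAR_TYPE_JSON,
        "regex": GrammarType.GRAMMAR_TYPE_REGEX,
    }[kind]
```
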
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index e88e8a74..718fc12e 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -5,6 +5,7 @@ import json
 from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
+from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -476,10 +477,12 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar):
+    def __init__(self, tokenizer, device, grammar, grammar_type):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+            grammar_type, grammar, self.tokenizer
+        )
 
     def __call__(
         self,
@@ -508,17 +511,12 @@ class GrammarLogitProcessor(LogitsProcessor):
     # TODO: move grammar compilation into the router
     @staticmethod
    @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(schema, tokenizer):
+    def _cached_compile_fsm(grammar_type, schema, tokenizer):
         start_time = time.time()
-        # Detect if schema is a json object before converting it to regex.
-        # We need to check if it's a valid json object before converting it to regex
-        # and cannot simply test if it starts with '{' and ends with '}' because there
-        # are valid regexes that start and end with curly braces.
-        try:
-            json.loads(schema)  # check if schema is a valid json
-            schema = build_regex_from_object(schema)  # convert schema to regex
-        except json.JSONDecodeError:
-            pass
+        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
+            schema = build_regex_from_object(schema)
+        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
+            pass  # schema is already a regex; this branch is only here for clarity
         fsm = RegexFSM(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
@@ -561,32 +559,32 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars):
+    def __init__(self, tokenizer, device, grammars, grammar_type):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
-        self.fsms = [
-            (
-                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
-                if g
-                else None
+        self.fsms = []
+        for i in range(len(grammars)):
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                grammar_type[i], grammars[i], self.tokenizer
             )
-            for g in grammars
-        ]
+            self.fsms.append(fsm)
 
     def __call__(
         self,
         logits: torch.Tensor,
         fsm_grammar_states: List[int],
+        mask: torch.Tensor,
     ):
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
             if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
+            mask.fill_(-math.inf)
             logits[i] = biased_scores
+        return logits
 
     def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
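The `__call__` rewrite is the "avoid alloc" part of this commit: instead of allocating a fresh `-inf` tensor for every row of the batch, the caller hands in one reusable mask buffer that is reset after each row. A minimal, self-contained sketch of the pattern (toy sizes; `allowed_per_row` stands in for `fsm.allowed_token_ids(state)`):

```python
import math
import torch

vocab_size = 8
logits = torch.zeros((2, vocab_size))
mask = torch.full((vocab_size,), -math.inf)  # allocated once, reused per row

allowed_per_row = [[1, 3], [0, 2, 5]]  # stand-in for fsm.allowed_token_ids(state)
for i, allowed in enumerate(allowed_per_row):
    mask[allowed] = 0            # open only the grammar-legal token ids
    logits[i] += mask            # every other id is pushed to -inf
    mask.fill_(-math.inf)        # restore the buffer for the next row
```
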
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 8d268c65..8f2ae4f6 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,9 +1,10 @@
 import re
 from typing import List, Optional, Tuple
+import math
 
 import torch
 from text_generation_server.pb import generate_pb2
-from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
@@ -36,6 +37,7 @@ class NextTokenChooser:
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         grammar: str = "",
+        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
@@ -52,7 +54,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
+            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            if grammar != ""
+            else None
         )
         self.tokenizer = tokenizer
 
@@ -121,6 +125,7 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
+            grammar_type=pb.grammar_type,
         )
 
 
@@ -227,6 +232,7 @@ class HeterogeneousNextTokenChooser:
         seeds: List[int],
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
+        grammar_types: List[GrammarType],
         fsm_grammar_states=List[int],
     ):
         warpers = []
@@ -260,7 +266,9 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
+            HeterogeneousGrammarLogitProcessor(
+                tokenizer, device, grammars, grammar_types
+            )
             if any([grammar != "" for grammar in grammars])
             else None
         )
@@ -319,6 +327,8 @@ class HeterogeneousNextTokenChooser:
         scores = scores.view(B, S, -1)
 
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
+
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -330,7 +340,7 @@
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
-                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+                _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
@@ -421,12 +431,15 @@
 
         new_grammars = []
         new_fsm_grammar_states = []
+        new_grammar_types = []
         for i in indices:
             new_grammars.append(self.grammars[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+            new_grammar_types.append(self.grammar_types[i])
 
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
+        self.grammar_types = new_grammar_types
 
         if any(self.do_sample):
             self.choice.filter(indices)
@@ -457,6 +470,7 @@
             dtype=dtype,
             tokenizer=tokenizer,
            grammars=[pb_.grammar for pb_ in pb],
+            grammar_types=[pb_.grammar_type for pb_ in pb],
            fsm_grammar_states=[0] * len(pb),
         )
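Taken together, a grammar now flows end to end as an explicit `(string, GrammarType)` pair instead of a string whose type is sniffed at compile time. A hedged server-side sketch of that flow (assumes the generated stubs and a loaded HF tokenizer; the schema is a toy example and `from_pb` arguments follow the diff above):

```python
import json
from text_generation_server.pb import generate_pb2

schema = json.dumps({"type": "object", "properties": {"name": {"type": "string"}}})
pb = generate_pb2.NextTokenChooserParameters(
    grammar=schema,
    grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON,
)
# NextTokenChooser.from_pb(pb, device="cpu", tokenizer=tokenizer) would then build
# a GrammarLogitProcessor whose FSM is compiled, and lru_cached, keyed on
# (grammar_type, schema, tokenizer); no JSON sniffing remains on the hot path.
```
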