feat: include grammar type in request, avoid alloc and improve proto types

drbh 2024-02-14 00:03:56 +00:00
parent 8b9430fb68
commit d849641b28
10 changed files with 118 additions and 61 deletions

View File

@@ -51,6 +51,12 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}

+enum GrammarType {
+    GRAMMAR_TYPE_NONE = 0;
+    GRAMMAR_TYPE_JSON = 1;
+    GRAMMAR_TYPE_REGEX = 2;
+}
+
message NextTokenChooserParameters {
    /// exponential scaling output probability distribution
    float temperature = 1;
@@ -72,6 +78,8 @@ message NextTokenChooserParameters {
    bool watermark = 8;
    /// grammar (applied if not empty)
    string grammar = 10;
+    /// grammar type
+    GrammarType grammar_type = 11;
}

message StoppingCriteriaParameters {
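Note: with this enum in the proto, a request now states explicitly whether the grammar payload is a JSON schema or a regex instead of leaving the server to guess. A minimal illustrative sketch using the generated Python stubs (the field values below are made up for illustration):

    from text_generation_server.pb import generate_pb2
    from text_generation_server.pb.generate_pb2 import GrammarType

    # A JSON-schema grammar is still shipped as a plain string; grammar_type tags how
    # the server should interpret it before compiling the FSM.
    params = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        grammar='{"type": "object", "properties": {"name": {"type": "string"}}}',
        grammar_type=GrammarType.GRAMMAR_TYPE_JSON,
    )

    # An unconstrained request keeps the defaults: empty grammar, GRAMMAR_TYPE_NONE.
    unconstrained = generate_pb2.NextTokenChooserParameters(
        grammar="",
        grammar_type=GrammarType.GRAMMAR_TYPE_NONE,
    )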

View File

@@ -129,6 +129,7 @@ impl Client {
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
+                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,

View File

@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse; pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{ pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
Request, StoppingCriteriaParameters, Tokens, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;

View File

@@ -1,5 +1,6 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
+use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
@@ -46,6 +47,7 @@ impl Health {
                frequency_penalty: 0.0,
                watermark: false,
                grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: 1,

View File

@@ -44,6 +44,40 @@ impl HubTokenizerConfig {
        serde_json::from_str(&content).unwrap_or_default()
    }
}

+mod json_object_or_string_to_string {
+    // This custom deserializer is used to handle the fact that the grammar field can be either a
+    // string or an object. In both cases we handle it as a string, but also provide this convenience
+    // to the user to be flexible with the input.
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
+            _ => Err(de::Error::custom("expected string or object for grammar")),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+#[serde(tag = "type", content = "value")]
+pub(crate) enum GrammarType {
+    #[serde(
+        rename = "json",
+        deserialize_with = "json_object_or_string_to_string::deserialize"
+    )]
+    Json(String),
+    #[serde(rename = "regex")]
+    Regex(String),
+}
+
mod token_serde {
    use super::*;
@@ -201,31 +235,8 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
    pub top_n_tokens: Option<u32>,
-    #[serde(
-        default,
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    pub grammar: String,
-}
-
-mod json_object_or_string_to_string {
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
-
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
-        }
-    }
+    #[serde(default)]
+    pub grammar: Option<GrammarType>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -251,7 +262,7 @@ fn default_parameters() -> GenerateParameters {
        decoder_input_details: false,
        seed: None,
        top_n_tokens: None,
-        grammar: String::new(),
+        grammar: None,
    }
}
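Note: with the tagged GrammarType enum plus the custom deserializer, the public grammar parameter is now an object rather than a bare string. A hedged sketch of what a request body could look like after this change (host, port and prompts are assumptions, not part of this diff):

    import requests  # assumes a text-generation-inference server on localhost:8080

    # "type" selects the variant; "value" carries the payload. For the "json" variant,
    # "value" may be either a JSON object or a string containing one, thanks to
    # json_object_or_string_to_string.
    json_payload = {
        "inputs": "Generate a person:",
        "parameters": {
            "max_new_tokens": 64,
            "grammar": {
                "type": "json",
                "value": {"type": "object", "properties": {"name": {"type": "string"}}},
            },
        },
    }

    # The "regex" variant takes a plain string value.
    regex_payload = {
        "inputs": "Is the sky blue?",
        "parameters": {"max_new_tokens": 4, "grammar": {"type": "regex", "value": "yes|no"}},
    }

    print(requests.post("http://localhost:8080/generate", json=json_payload).json())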

View File

@@ -343,7 +343,9 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
    use super::*;
-    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+    use text_generation_client::{
+        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    };
    use tracing::info_span;

    fn default_entry() -> (
@@ -369,6 +371,7 @@ mod tests {
                frequency_penalty: 0.0,
                watermark: false,
                grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
            },
            stopping_parameters: StoppingCriteriaParameters {
                ignore_eos_token: false,

View File

@@ -614,7 +614,7 @@ async fn chat_completions(
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
-            grammar: String::new(),
+            grammar: None,
        },
    };

View File

@ -1,8 +1,10 @@
/// Payload validation logic /// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest, GrammarType};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection; use tokenizers::TruncationDirection;
@ -296,10 +298,27 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
// Ensure that grammar is not set if it's not supported // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
if !grammar.is_empty() && !self.grammar_support { // NOTE: this is currently difficult because we need the tokenizer in Python to build
return Err(ValidationError::Grammar); // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
} // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if !self.grammar_support {
return Err(ValidationError::Grammar);
}
match grammar {
// currently both are handled the same way since compilation is done in Python
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
}
None => (String::new(), ProtoGrammarType::None.into()),
};
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
@ -312,6 +331,7 @@ impl Validation {
seed, seed,
watermark, watermark,
grammar, grammar,
grammar_type,
}; };
let stopping_parameters = StoppingCriteriaParameters { let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens, max_new_tokens,

View File

@ -5,6 +5,7 @@ import json
from loguru import logger from loguru import logger
from functools import lru_cache from functools import lru_cache
from typing import Optional, List, Dict, Union from typing import Optional, List, Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object from outlines.fsm.json_schema import build_regex_from_object
@ -476,10 +477,12 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int] fsm_state: DefaultDict[int, int]
fsm: RegexFSM fsm: RegexFSM
def __init__(self, tokenizer, device, grammar): def __init__(self, tokenizer, device, grammar, grammar_type):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer) self.fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer
)
def __call__( def __call__(
self, self,
@ -508,17 +511,12 @@ class GrammarLogitProcessor(LogitsProcessor):
# TODO: move grammar compilation into the router # TODO: move grammar compilation into the router
@staticmethod @staticmethod
@lru_cache(maxsize=32, typed=True) @lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(schema, tokenizer): def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time() start_time = time.time()
# Detect if schema is a json object before converting it to regex. if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
# We need to check if it's a valid json object before converting it to regex schema = build_regex_from_object(schema)
# and cannot simply test if it starts with '{' and ends with '}' because there elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
# are valid regexes that start and end with curly braces. pass # schema is already a regex just here for clarity
try:
json.loads(schema) # check if schema is a valid json
schema = build_regex_from_object(schema) # convert schema to regex
except json.JSONDecodeError:
pass
fsm = RegexFSM(schema, tokenizer) fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm return fsm
@ -561,32 +559,32 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor): class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device, grammars): def __init__(self, tokenizer, device, grammars, grammar_type):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [ self.fsms = []
( for i in range(len(grammars)):
GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer) fsm = GrammarLogitProcessor._cached_compile_fsm(
if g grammar_type[i], grammars[i], self.tokenizer
else None
) )
for g in grammars self.fsms.append(fsm)
]
def __call__( def __call__(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
fsm_grammar_states: List[int], fsm_grammar_states: List[int],
mask: torch.Tensor,
): ):
for i in range(logits.shape[0]): for i in range(logits.shape[0]):
fsm = self.fsms[i] fsm = self.fsms[i]
if fsm_grammar_states[i] == -1 or fsm is None: if fsm_grammar_states[i] == -1 or fsm is None:
continue continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask[allowed_tokens] = 0 mask[allowed_tokens] = 0
biased_scores = logits[i] + mask biased_scores = logits[i] + mask
mask.fill_(-math.inf)
logits[i] = biased_scores logits[i] = biased_scores
return logits return logits
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars): def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
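Note: the router still ships both grammar variants as strings; the enum only tells _cached_compile_fsm whether to run the JSON-schema-to-regex conversion, replacing the old json.loads probe (which could misclassify a regex that happens to parse as JSON). A minimal sketch of that dispatch, assuming outlines is installed (the helper name is illustrative, not part of the diff):

    from outlines.fsm.json_schema import build_regex_from_object
    from text_generation_server.pb.generate_pb2 import GrammarType

    def grammar_to_regex(grammar: str, grammar_type: int) -> str:
        # JSON-schema grammars are converted to an equivalent regex;
        # regex grammars pass through untouched.
        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
            return build_regex_from_object(grammar)
        return grammar

Because grammar_type is now part of the lru_cache key, a string that happens to be both valid JSON and a valid regex compiles to different FSMs depending on the declared type.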

View File

@ -1,9 +1,10 @@
import re import re
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import math
import torch import torch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import ( from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor, FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor, GrammarLogitProcessor,
@ -36,6 +37,7 @@ class NextTokenChooser:
device: str = "cpu", device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "", grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0, fsm_grammar_state: int = 0,
): ):
self.watermark_processor = ( self.watermark_processor = (
@ -52,7 +54,9 @@ class NextTokenChooser:
else None else None
) )
self.grammar_processor = ( self.grammar_processor = (
GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
if grammar != ""
else None
) )
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -121,6 +125,7 @@ class NextTokenChooser:
device=device, device=device,
tokenizer=tokenizer, tokenizer=tokenizer,
grammar=pb.grammar, grammar=pb.grammar,
grammar_type=pb.grammar_type,
) )
@ -227,6 +232,7 @@ class HeterogeneousNextTokenChooser:
seeds: List[int], seeds: List[int],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
grammars: List[str], grammars: List[str],
grammar_types: List[GrammarType],
fsm_grammar_states=List[int], fsm_grammar_states=List[int],
): ):
warpers = [] warpers = []
@ -260,7 +266,9 @@ class HeterogeneousNextTokenChooser:
) )
self.grammar_processor = ( self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars) HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types
)
if any([grammar != "" for grammar in grammars]) if any([grammar != "" for grammar in grammars])
else None else None
) )
@ -319,6 +327,8 @@ class HeterogeneousNextTokenChooser:
scores = scores.view(B, S, -1) scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
for j in range(S): for j in range(S):
_scores = scores[:, j] _scores = scores[:, j]
if self.watermark_processor is not None: if self.watermark_processor is not None:
@ -330,7 +340,7 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
if self.grammar_processor is not None: if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states) _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
@ -421,12 +431,15 @@ class HeterogeneousNextTokenChooser:
new_grammars = [] new_grammars = []
new_fsm_grammar_states = [] new_fsm_grammar_states = []
new_grammar_types = []
for i in indices: for i in indices:
new_grammars.append(self.grammars[i]) new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i]) new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample): if any(self.do_sample):
self.choice.filter(indices) self.choice.filter(indices)
@ -457,6 +470,7 @@ class HeterogeneousNextTokenChooser:
dtype=dtype, dtype=dtype,
tokenizer=tokenizer, tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb], grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=[0] * len(pb), fsm_grammar_states=[0] * len(pb),
) )
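Note on the "avoid alloc" part of this change: the -inf mask is now allocated once per decode step in HeterogeneousNextTokenChooser and reset in place with fill_() inside the grammar processor, instead of building a fresh tensor for every row of the batch. A standalone sketch of the pattern in plain PyTorch (function and variable names are illustrative):

    import math
    import torch

    def bias_rows_reusing_mask(logits, allowed_per_row):
        # One allocation, shared by every row in the batch.
        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
        for i, allowed in enumerate(allowed_per_row):
            if allowed is None:  # row has no grammar constraint
                continue
            mask[allowed] = 0.0           # open up the allowed token ids
            logits[i] = logits[i] + mask  # everything else is pushed to -inf
            mask.fill_(-math.inf)         # reset in place for the next row
        return logits

    # Tiny usage example: vocab of 5 tokens, batch of 2 rows, first row constrained to ids 1 and 3.
    scores = torch.zeros(2, 5)
    print(bias_rows_reusing_mask(scores, [torch.tensor([1, 3]), None]))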