feat: include grammar type in request, avoid alloc and improve proto types

drbh 2024-02-14 00:03:56 +00:00
parent 8b9430fb68
commit d849641b28
10 changed files with 118 additions and 61 deletions

View File

@@ -51,6 +51,12 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
@@ -72,6 +78,8 @@ message NextTokenChooserParameters {
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
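To make the new field concrete, here is a minimal sketch of building the generated message from Python; it assumes the generated text_generation_server.pb.generate_pb2 module (imported elsewhere in this commit) is importable, and the temperature and schema values are placeholders:

    from text_generation_server.pb import generate_pb2

    # grammar carries the raw schema/regex string; grammar_type says how to interpret it
    params = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        grammar='{"type": "object"}',
        grammar_type=generate_pb2.GrammarType.GRAMMAR_TYPE_JSON,
    )
    assert params.grammar_type == 1  # GRAMMAR_TYPE_JSON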

View File

@@ -129,6 +129,7 @@ impl Client {
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Request, StoppingCriteriaParameters, Tokens,
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@@ -1,5 +1,6 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
@@ -46,6 +47,7 @@ impl Health {
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@@ -44,6 +44,40 @@ impl HubTokenizerConfig {
serde_json::from_str(&content).unwrap_or_default()
}
}
mod json_object_or_string_to_string {
// This custom deserializer is used to handle the fact that the grammar field can be either a
// string or an object. In both cases we handle it as a string, but we also provide this
// convenience so the user can be flexible with the input.
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
_ => Err(de::Error::custom("expected string or object for grammar")),
}
}
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
Json(String),
#[serde(rename = "regex")]
Regex(String),
}
mod token_serde {
use super::*;
@@ -201,31 +235,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
#[serde(
default,
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
pub grammar: String,
}
mod json_object_or_string_to_string {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
_ => Err(de::Error::custom("expected string or object for grammar")),
}
}
#[serde(default)]
pub grammar: Option<GrammarType>,
}
fn default_max_new_tokens() -> Option<u32> {
@@ -251,7 +262,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
seed: None,
top_n_tokens: None,
grammar: String::new(),
grammar: None,
}
}
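With the grammar field now an Option<GrammarType>, a request can pass either a JSON schema (object or string) or a regex. A minimal sketch of the accepted payload, assuming a server listening on localhost:3000 (host, port, prompt and schema are placeholders):

    import requests

    payload = {
        "inputs": "Name a city in France and answer as JSON.",
        "parameters": {
            "max_new_tokens": 32,
            # tagged enum: "type" selects the Json or Regex variant, "value" holds the grammar;
            # for "json", the value may be an object or a string thanks to the custom deserializer
            "grammar": {"type": "json", "value": {"properties": {"city": {"type": "string"}}}},
            # or: "grammar": {"type": "regex", "value": "[A-Z][a-z]+"},
        },
    }
    requests.post("http://localhost:3000/generate", json=payload)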

View File

@@ -343,7 +343,9 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use tracing::info_span;
fn default_entry() -> (
@@ -369,6 +371,7 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,

View File

@@ -614,7 +614,7 @@ async fn chat_completions(
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: String::new(),
grammar: None,
},
};

View File

@@ -1,8 +1,10 @@
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
@@ -296,10 +298,27 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
// NOTE: this is currently difficult because we need the tokenizer in Python to build
// the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
// may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if !grammar.is_empty() && !self.grammar_support {
if !self.grammar_support {
return Err(ValidationError::Grammar);
}
match grammar {
// currently both are handled the same way since compilation is done in Python
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
}
None => (String::new(), ProtoGrammarType::None.into()),
};
let parameters = NextTokenChooserParameters {
temperature,
@@ -312,6 +331,7 @@ impl Validation {
seed,
watermark,
grammar,
grammar_type,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
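For illustration, the unpacking above yields the following (grammar, grammar_type) pairs for the proto message; the enum numbers come from the proto definition and the regex is a placeholder:

    # no grammar                            -> ("",               GRAMMAR_TYPE_NONE  = 0)
    # {"type": "json",  "value": {...}}     -> (schema as string, GRAMMAR_TYPE_JSON  = 1)
    # {"type": "regex", "value": "\\d{4}"}  -> ("\\d{4}",         GRAMMAR_TYPE_REGEX = 2)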

View File

@@ -5,6 +5,7 @@ import json
from loguru import logger
from functools import lru_cache
from typing import Optional, List, Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
@@ -476,10 +477,12 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int]
fsm: RegexFSM
def __init__(self, tokenizer, device, grammar):
def __init__(self, tokenizer, device, grammar, grammar_type):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer
)
def __call__(
self,
@@ -508,17 +511,12 @@
# TODO: move grammar compilation into the router
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(schema, tokenizer):
def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time()
# Detect if schema is a json object before converting it to regex.
# We need to check if it's a valid json object before converting it to regex
# and cannot simply test if it starts with '{' and ends with '}' because there
# are valid regexes that start and end with curly braces.
try:
json.loads(schema) # check if schema is a valid json
schema = build_regex_from_object(schema) # convert schema to regex
except json.JSONDecodeError:
pass
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass  # schema is already a regex; this branch is only here for clarity
fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
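A minimal sketch of what the JSON branch does, assuming outlines is installed: build_regex_from_object turns a JSON-Schema string into a regular expression that only matches conforming JSON, while the regex branch uses the schema string as-is:

    from outlines.fsm.json_schema import build_regex_from_object

    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    regex = build_regex_from_object(schema)  # regex matching JSON documents that satisfy the schema
    print(regex[:80])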
@@ -561,32 +559,32 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device, grammars):
def __init__(self, tokenizer, device, grammars, grammar_type):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [
(
GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
if g
else None
self.fsms = []
for i in range(len(grammars)):
fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type[i], grammars[i], self.tokenizer
)
for g in grammars
]
self.fsms.append(fsm)
def __call__(
self,
logits: torch.Tensor,
fsm_grammar_states: List[int],
mask: torch.Tensor,
):
for i in range(logits.shape[0]):
fsm = self.fsms[i]
if fsm_grammar_states[i] == -1 or fsm is None:
continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask[allowed_tokens] = 0
biased_scores = logits[i] + mask
mask.fill_(-math.inf)
logits[i] = biased_scores
return logits
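This is the "avoid alloc" part of the commit: instead of allocating a fresh torch.full(..., -inf) tensor on every loop iteration, one mask is allocated by the caller, reused for each row, and reset in place with fill_(). A self-contained sketch of the pattern (shapes and allowed ids are placeholders):

    import math
    import torch

    rows, vocab = 3, 8
    logits = torch.zeros((rows, vocab))
    mask = torch.full((vocab,), -math.inf)   # allocated once, outside the loop

    for i in range(rows):
        allowed = [0, 1, i + 2]              # stand-in for fsm.allowed_token_ids(state)
        mask[allowed] = 0                    # open only the allowed token ids
        logits[i] = logits[i] + mask         # bias the row
        mask.fill_(-math.inf)                # reset the shared buffer for the next row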
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):

View File

@@ -1,9 +1,10 @@
import re
from typing import List, Optional, Tuple
import math
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
@@ -36,6 +37,7 @@ class NextTokenChooser:
device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
):
self.watermark_processor = (
@@ -52,7 +54,9 @@
else None
)
self.grammar_processor = (
GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
if grammar != ""
else None
)
self.tokenizer = tokenizer
@@ -121,6 +125,7 @@ class NextTokenChooser:
device=device,
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
)
@@ -227,6 +232,7 @@ class HeterogeneousNextTokenChooser:
seeds: List[int],
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[GrammarType],
fsm_grammar_states: List[int],
):
warpers = []
@@ -260,7 +266,9 @@
)
self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types
)
if any([grammar != "" for grammar in grammars])
else None
)
@@ -319,6 +327,8 @@
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
@@ -330,7 +340,7 @@
for warper in self.warpers:
_scores = warper(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
@@ -421,12 +431,15 @@ class HeterogeneousNextTokenChooser:
new_grammars = []
new_fsm_grammar_states = []
new_grammar_types = []
for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample):
self.choice.filter(indices)
@@ -457,6 +470,7 @@ class HeterogeneousNextTokenChooser:
dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=[0] * len(pb),
)