feat: support simple grammars via outlines

Author: drbh
Date: 2024-02-07 19:35:39 -05:00
Parent: b013cb4f4a
Commit: 0245506718
8 changed files with 34 additions and 9 deletions

View File

@@ -70,6 +70,8 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
 }
 
 message StoppingCriteriaParameters {
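Note: the new grammar field is a plain regex string on the existing NextTokenChooserParameters message, so clients only need to set one extra attribute. A minimal sketch using the Python stubs generated from this proto (module path assumed from the repository layout; not part of this commit):

    from text_generation_server.pb import generate_pb2

    # an empty grammar string disables constrained decoding,
    # matching the "applied if not empty" comment on the field
    params = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        watermark=False,
        grammar=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )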

View File

@@ -128,6 +128,7 @@ impl Client {
                 repetition_penalty: 1.2,
                 frequency_penalty: 0.1,
                 watermark: true,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

View File

@@ -45,6 +45,7 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -201,6 +201,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+    #[serde(default)]
+    pub grammar: String,
 }
 
 fn default_max_new_tokens() -> Option<u32> {

@@ -226,6 +228,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
+        grammar: String::new(),
     }
 }
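With the router accepting a default-empty grammar parameter, a request can pass a regex directly. A hedged end-to-end sketch against a local text-generation-inference instance (host and port assumed):

    import requests

    # "grammar" is the raw regex string added to GenerateParameters above;
    # omitting it (or sending "") leaves generation unconstrained
    response = requests.post(
        "http://localhost:8080/generate",
        json={
            "inputs": "What is the IP address of the Google DNS servers?",
            "parameters": {
                "max_new_tokens": 10,
                "grammar": r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
            },
        },
    )
    print(response.json()["generated_text"])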

View File

@@ -354,7 +354,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
-                inputs: "".to_string(),
+                inputs: String::new(),
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,

@@ -368,6 +368,7 @@
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,

View File

@@ -614,6 +614,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
+            grammar: String::new(),
         },
     };

View File

@@ -182,6 +182,7 @@ impl Validation {
             watermark,
             decoder_input_details,
             top_n_tokens,
+            grammar,
             ..
         } = request.parameters;

@@ -302,6 +303,7 @@
             do_sample,
             seed,
             watermark,
+            grammar,
         };
 
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,

View File

@@ -21,6 +21,7 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 from outlines.fsm.fsm import RegexFSM
+import time
 
 class NextTokenChooser:
     def __init__(
@@ -36,6 +37,7 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        grammar=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -66,9 +68,12 @@ class NextTokenChooser:
             self.static_warper = None
 
         sampling = do_sample or has_warpers
-        # TODO toggle grammar
-        # self.choice = Sampling(seed, device) if sampling else Greedy()
-        self.choice = Grammar(tokenizer, device)
+        # TODO: is grammar a subset of sampling? If so, we should merge them
+        if grammar:
+            self.choice = Grammar(tokenizer, device, grammar)
+        else:
+            self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
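A non-empty grammar now routes next-token selection through the FSM-constrained Grammar choice, while an empty or missing grammar falls back to the usual Sampling/Greedy path. A hypothetical usage sketch (assumes a PreTrainedTokenizerBase instance named tokenizer and default values for the remaining constructor arguments):

    # constrained: choice becomes a Grammar wrapping an outlines RegexFSM
    constrained = NextTokenChooser(tokenizer=tokenizer, grammar=r"(yes|no)", device="cpu")
    assert isinstance(constrained.choice, Grammar)

    # unconstrained: the empty grammar string coming over the wire is falsy
    unconstrained = NextTokenChooser(tokenizer=tokenizer, grammar="", device="cpu")
    assert isinstance(unconstrained.choice, Greedy)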
@@ -106,6 +111,7 @@
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
+            grammar=pb.grammar,
         )
@@ -433,16 +439,24 @@ class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device):
-        # TODO: get regex on init not hardcoded
-        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    def __init__(self, tokenizer, device, regex_str):
+        # TODO: adapting the tokenizer is expensive; we should do it only once.
+        # This is a temporary solution.
+        # TODO: remove the debug timing logs below
+        start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
+        print(f"Adapt tokenizer: {time.time() - start_time}")
+        start_time = time.time()
+        # TODO: avoid recompiling the FSM every time?
         fsm = RegexFSM(regex_str, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
+        self.device = device
 
     def __call__(self, logits):
         # TODO: handle seq_id properly
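The TODOs above flag that adapt_tokenizer and RegexFSM compilation currently run on every request. One direction, sketched here with a hypothetical module-level cache (not part of this commit), is to memoize the compiled FSM per regex, assuming a single adapted tokenizer per server process:

    _FSM_CACHE: dict = {}

    def cached_regex_fsm(regex_str: str, adapted_tokenizer) -> RegexFSM:
        # compile each distinct regex once; reuse the FSM on later requests
        if regex_str not in _FSM_CACHE:
            _FSM_CACHE[regex_str] = RegexFSM(regex_str, adapted_tokenizer)
        return _FSM_CACHE[regex_str]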
@@ -452,7 +466,7 @@
             return self.fsm_state[seq_id].eos_token_id
 
         allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
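For reference, the biasing step above is ordinary additive masking: tokens outside the FSM's allowed set receive -inf so softmax assigns them zero probability. A self-contained illustration with toy values:

    import math

    import torch

    logits = torch.randn(1, 8)            # scores over a toy 8-token vocabulary
    allowed_tokens = [2, 5]               # ids the FSM permits in the current state
    mask = torch.full((logits.shape[-1],), -math.inf)
    mask[allowed_tokens] = 0              # allowed ids keep their original scores
    biased_scores = logits + mask         # all other ids are driven to -inf
    next_token_id = biased_scores.argmax(dim=-1)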