mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 04:14:52 +00:00)
feat: support simple grammars via outlines
This commit is contained in:
parent b013cb4f4a
commit 0245506718
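In short: the commit threads a new `grammar` string from the gRPC protocol definition through the Rust router and into the Python server, where a non-empty grammar is compiled into an outlines `RegexFSM` whose allowed-token sets mask the logits at every decoding step.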
@@ -70,6 +70,8 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
 }

 message StoppingCriteriaParameters {
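The new field takes number 10 (field declarations need not appear in numeric order, so `watermark = 8` sitting after `frequency_penalty = 9` is fine), and the proto3 default of an empty string doubles as "no grammar". On the Python side the regenerated stubs pick the field up automatically; a minimal sketch, assuming the generated module lives at `text_generation_server.pb.generate_pb2` as elsewhere in the server:

    # Hypothetical use of the regenerated stubs; only `grammar` is new here.
    from text_generation_server.pb import generate_pb2

    params = generate_pb2.NextTokenChooserParameters(
        frequency_penalty=0.0,
        watermark=False,
        grammar=r"\d{4}-\d{2}-\d{2}",  # e.g. constrain output to an ISO date
    )
    assert params.grammar  # empty string (the proto3 default) means "no grammar"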
@@ -128,6 +128,7 @@ impl Client {
                 repetition_penalty: 1.2,
                 frequency_penalty: 0.1,
                 watermark: true,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -45,6 +45,7 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -201,6 +201,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+    #[serde(default)]
+    pub grammar: String,
 }

 fn default_max_new_tokens() -> Option<u32> {
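`#[serde(default)]` keeps the change backward compatible: clients that omit `grammar` get an empty string, which the server treats as "no grammar". A minimal sketch of a request exercising the new parameter, assuming a local instance (adjust host and port to your deployment):

    import requests

    # The grammar value is the raw regex the Python server compiles into a
    # RegexFSM; this one is the IP-address pattern that was hardcoded in the
    # parent commit (see the Grammar hunk below).
    resp = requests.post(
        "http://localhost:8080/generate",
        json={
            "inputs": "The IP address of localhost is",
            "parameters": {
                "max_new_tokens": 16,
                "grammar": r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
            },
        },
    )
    print(resp.json())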
@@ -226,6 +228,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
+        grammar: String::new(),
     }
 }

@@ -354,7 +354,7 @@ mod tests {

         let entry = Entry {
             request: ValidGenerateRequest {
-                inputs: "".to_string(),
+                inputs: String::new(),
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,
@@ -368,6 +368,7 @@ mod tests {
                     repetition_penalty: 0.0,
                     frequency_penalty: 0.0,
                     watermark: false,
+                    grammar: String::new(),
                 },
                 stopping_parameters: StoppingCriteriaParameters {
                     ignore_eos_token: false,
@@ -614,6 +614,7 @@ async fn chat_completions(
                 decoder_input_details: !stream,
                 seed,
                 top_n_tokens: None,
+                grammar: String::new(),
             },
         };

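Note that the OpenAI-compatible chat route does not expose the parameter yet: `chat_completions` fills in `grammar: String::new()` unconditionally, so at this point only the plain generate path can carry a grammar.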
@@ -182,6 +182,7 @@ impl Validation {
             watermark,
             decoder_input_details,
             top_n_tokens,
+            grammar,
             ..
         } = request.parameters;

@@ -302,6 +303,7 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
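Validation simply destructures `grammar` from the incoming parameters and forwards it into `NextTokenChooserParameters`; nothing in these hunks checks that the string is a well-formed regex, so a malformed pattern would only surface when the Python server tries to compile it.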
@@ -21,6 +21,7 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

 from outlines.fsm.fsm import RegexFSM
+import time

 class NextTokenChooser:
     def __init__(
@@ -36,6 +37,7 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        grammar=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -66,9 +68,12 @@ class NextTokenChooser:
             self.static_warper = None

         sampling = do_sample or has_warpers
-        # TODO toggle grammar
-        # self.choice = Sampling(seed, device) if sampling else Greedy()
-        self.choice = Grammar(tokenizer, device)
+
+        # TODO: is grammar a subset of sampling? If so, we should merge them
+        if grammar:
+            self.choice = Grammar(tokenizer, device, grammar)
+        else:
+            self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
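Grammar decoding thus takes precedence over the usual choice of `Sampling` vs. `Greedy`: any non-empty grammar routes every step through the FSM-backed `Grammar` chooser, and the surviving TODO asks whether the two paths should eventually be merged rather than kept mutually exclusive.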
@@ -106,6 +111,7 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
+            grammar=pb.grammar,
         )

@@ -433,16 +439,24 @@ class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device):
-        # TODO: get regex on init not hardcoded
-        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-
+    def __init__(self, tokenizer, device, regex_str):
+        # TODO: adapt tokenizer is expensive, we should do it only once
+        # this is a temporary solution
+
+        # TODO: remove debug logs
+        # time this
+        start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)

+        print(f"Adapt tokenizer: {time.time() - start_time}")
+        start_time = time.time()
+
+        # TODO: avoid recompiling the FSM every time?
         fsm = RegexFSM(regex_str, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
         self.device = device

     def __call__(self, logits):
         # TODO: handle seq_id properly
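Stripped of the TGI plumbing and debug timing, the technique is: compile the regex into a token-level FSM once, then at each step add a mask that is -inf everywhere except on the tokens the FSM allows, and advance the FSM state with whatever token gets picked. A minimal sketch; `allowed_token_ids` is the call the diff itself makes, while `next_state` and the FSM construction details are assumptions about the outlines API of the time:

    import math
    from collections import defaultdict

    import torch
    from outlines.fsm.fsm import RegexFSM

    def constrained_step(fsm: RegexFSM, state: int, logits: torch.Tensor):
        """One greedy step under the FSM: mask disallowed tokens, pick, advance."""
        # Same masking as Grammar.__call__ above: -inf everywhere, 0 on allowed ids.
        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
        mask[fsm.allowed_token_ids(state)] = 0
        token_id = int(torch.argmax(logits + mask))
        return token_id, fsm.next_state(state, token_id)  # assumed advance API

    # Per-sequence states live in a defaultdict(int), as in the class above:
    # every new seq_id implicitly starts at FSM state 0, the start state.
    fsm_state = defaultdict(int)

The `adapt_tokenizer` step (and the TODO about its cost) exists because outlines needs the vocabulary in its own wrapper format; the `# TODO: avoid recompiling the FSM` comment flags the other obvious per-request cost.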
@@ -452,7 +466,7 @@ class Grammar:
             return self.fsm_state[seq_id].eos_token_id

         allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask

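The final hunk builds the mask on the device stored at construction time instead of `logits.device`; the two coincide in practice, but the stored device makes the dependency explicit. The early return above it looks like a leftover, though: `self.fsm_state[seq_id]` is an int state out of the defaultdict, so reading `.eos_token_id` off it would raise; presumably the intent is to emit the tokenizer's EOS id once the FSM reaches a final state, which the seq_id TODO suggests is still being worked out.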