mirror of https://github.com/huggingface/text-generation-inference.git

feat: support simple grammars via outlines

parent b013cb4f4a · commit 0245506718
@@ -70,6 +70,8 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
 }

 message StoppingCriteriaParameters {
@@ -128,6 +128,7 @@ impl Client {
                     repetition_penalty: 1.2,
                     frequency_penalty: 0.1,
                     watermark: true,
+                    grammar: String::new(),
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
@@ -45,6 +45,7 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -201,6 +201,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+    #[serde(default)]
+    pub grammar: String,
 }

 fn default_max_new_tokens() -> Option<u32> {
@@ -226,6 +228,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
+        grammar: String::new(),
     }
 }
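Since `grammar` is a plain `String` on `GenerateParameters` with a serde default of empty (empty means disabled), a client can pass a regex alongside the usual sampling parameters. A minimal sketch, assuming the router's standard `/generate` route; the URL and prompt are placeholders, and the regex is the IP-address pattern this commit removes from the hardcoded `Grammar.__init__` further down:

    # Sketch: call /generate with the new grammar parameter.
    # An empty "grammar" string leaves generation unconstrained.
    import requests

    resp = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "The IP address for localhost is",
            "parameters": {
                "max_new_tokens": 20,
                "grammar": r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
            },
        },
    )
    print(resp.json()["generated_text"])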
@@ -354,7 +354,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
-                inputs: "".to_string(),
+                inputs: String::new(),
                 input_length: 0,
                 truncate: 0,
                 decoder_input_details: false,
@@ -368,6 +368,7 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
+                grammar: String::new(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
@@ -614,6 +614,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
+            grammar: String::new(),
         },
     };
@@ -182,6 +182,7 @@ impl Validation {
             watermark,
             decoder_input_details,
             top_n_tokens,
+            grammar,
             ..
         } = request.parameters;
@@ -302,6 +303,7 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -21,6 +21,7 @@ from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

 from outlines.fsm.fsm import RegexFSM
+import time


 class NextTokenChooser:
     def __init__(
@@ -36,6 +37,7 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        grammar=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -66,9 +68,12 @@ class NextTokenChooser:
             self.static_warper = None

         sampling = do_sample or has_warpers
-        # TODO toggle grammar
-        # self.choice = Sampling(seed, device) if sampling else Greedy()
-        self.choice = Grammar(tokenizer, device)
+        # TODO: is grammar a subset of sampling? If so, we should merge them
+        if grammar:
+            self.choice = Grammar(tokenizer, device, grammar)
+        else:
+            self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
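When a grammar is present it replaces the token-selection strategy outright instead of composing with `Sampling`/`Greedy`, hence the TODO above. Conceptually, one constrained decoding step masks the logits down to the tokens the FSM allows and picks among them. A sketch, not the commit's exact control flow: `allowed_token_ids` appears in this diff, while `next_state` is an assumption about the outlines FSM API:

    import math
    import torch

    def constrained_step(fsm, state, logits):
        # Keep only the transitions the FSM allows; everything else gets -inf
        allowed = fsm.allowed_token_ids(state)
        mask = torch.full_like(logits, -math.inf)
        mask[allowed] = 0
        next_token = int(torch.argmax(logits + mask))  # greedy pick among legal tokens
        return next_token, fsm.next_state(state, next_token)  # advance the FSM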
@@ -106,6 +111,7 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
+            grammar=pb.grammar,
         )
@@ -433,16 +439,24 @@ class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device):
-        # TODO: get regex on init not hardcoded
-        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-
+    def __init__(self, tokenizer, device, regex_str):
         # TODO: adapt tokenizer is expensive, we should do it only once
         # this is a temporary solution
+
+        # TODO: remove debug logs
+        # time this
+        start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
+        print(f"Adapt tokenizer: {time.time() - start_time}")
+        start_time = time.time()
+
+        # TODO: avoid recompiling the FSM every time?
         fsm = RegexFSM(regex_str, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
         self.fsm = fsm
         self.fsm_state = defaultdict(int)
+        self.device = device

     def __call__(self, logits):
         # TODO: handle seq_id properly
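One way to discharge the "avoid recompiling the FSM every time" TODO above is to memoize compiled FSMs keyed on the regex and tokenizer. A minimal sketch; the cache shape and keying are assumptions, not part of this commit:

    _FSM_CACHE = {}

    def get_fsm(regex_str, tokenizer):
        # Compile each (regex, tokenizer) pair once and reuse it across requests
        key = (regex_str, id(tokenizer))
        if key not in _FSM_CACHE:
            _FSM_CACHE[key] = RegexFSM(regex_str, tokenizer)
        return _FSM_CACHE[key]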
@@ -452,7 +466,7 @@ class Grammar:
             return self.fsm_state[seq_id].eos_token_id

         allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
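The mask arithmetic above is the whole trick: adding 0 leaves allowed logits untouched, while adding -inf drives disallowed tokens to probability zero after softmax. A tiny self-contained example over a five-token vocabulary:

    import math
    import torch

    logits = torch.tensor([2.0, 0.5, 1.0, 0.2, 3.0])
    mask = torch.full((5,), -math.inf)
    mask[[1, 3]] = 0                      # suppose the FSM allows only tokens 1 and 3
    probs = torch.softmax(logits + mask, dim=-1)
    print(probs)                          # non-zero only at indices 1 and 3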