Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
feat: prefer precompiled grammar
parent 4f7074ca71
commit c52a0f679e
@@ -19,8 +19,11 @@ impl Client {
     pub async fn connect(uri: Uri) -> Result<Self> {
         let channel = Channel::builder(uri).connect().await?;
 
+        let limit = 100 * 1024 * 1024; // 100MB
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
@@ -33,8 +36,12 @@ impl Client {
         }))
         .await?;
 
+        let limit = 100 * 1024 * 1024; // 100MB
+        println!("limit: {}", limit);
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
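Context for the two hunks above: gRPC implementations commonly cap received messages at 4 MiB by default, and a flattened states_to_token_maps for a non-trivial grammar can far exceed that. A rough, purely illustrative estimate (Python; the transition count is a made-up assumption, not from the commit):

    # Back-of-the-envelope payload size for a flattened state map.
    # Assumes a hypothetical grammar with ~2 million FSM transitions, each
    # carried as three u32 fields (start_state, token, end_state).
    transitions = 2_000_000
    bytes_per_field = 4  # u32 upper bound, ignoring varint compression
    payload = transitions * 3 * bytes_per_field
    print(f"~{payload / 2**20:.0f} MiB")  # ~23 MiB, well past a 4 MiB default

which is why both the encoding and decoding limits are raised to 100 MB here.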
@@ -397,10 +397,23 @@ impl Validation {
             .await
             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
+        // flatten the BTreeMap<u32, BTreeMap<u32, u32>> into 3 Vec<u32>s (start_states, tokens, end_states)
+        let mut start_states = vec![];
+        let mut tokens = vec![];
+        let mut end_states = vec![];
+
+        for (start_state, token_map) in _states_to_token_maps.iter() {
+            for (token, end_state) in token_map.iter() {
+                start_states.push(*start_state);
+                tokens.push(*token);
+                end_states.push(*end_state);
+            }
+        }
+
         let stm = StatesToTokenMaps {
-            start_states: vec![],
-            tokens: vec![],
-            end_states: vec![],
+            start_states,
+            tokens,
+            end_states,
         };
 
         (
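The Rust hunk above flattens the nested BTreeMap into three parallel vectors so the structure fits a flat protobuf message. The same transform, sketched in Python with toy data (the token ids are invented for illustration):

    # Nested map from grammar compilation: {start_state: {token: end_state}}.
    states_to_token_maps = {0: {12: 1, 47: 1}, 1: {99: 2}}

    # Flatten into three parallel vectors, one entry per FSM transition.
    start_states, tokens, end_states = [], [], []
    for start_state, token_map in states_to_token_maps.items():
        for token, end_state in token_map.items():
            start_states.append(start_state)
            tokens.append(token)
            end_states.append(end_state)

    assert start_states == [0, 0, 1]
    assert tokens == [12, 47, 99]
    assert end_states == [1, 1, 2]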
@@ -206,11 +206,20 @@ def serve(
         logger.exception("Error when initializing model")
         raise
 
+    max_send_message_size = 100 * 1024 * 1024  # 100 MB
+    max_receive_message_size = 100 * 1024 * 1024  # 100 MB
+
+    server_options = [
+        ("grpc.max_send_message_length", max_send_message_size),
+        ("grpc.max_receive_message_length", max_receive_message_size),
+    ]
+
     server = aio.server(
+        options=server_options,
         interceptors=[
             ExceptionInterceptor(),
             UDSOpenTelemetryAioServerInterceptor(),
-        ]
+        ],
     )
     generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
         TextGenerationService(model, Cache(), quantize, server_urls), server
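Message-size limits apply on both ends of a connection, so a Python client of this server would pass matching channel options. A minimal sketch (the UDS address is a placeholder, not taken from the commit):

    import grpc

    MAX_MESSAGE_SIZE = 100 * 1024 * 1024  # 100 MB, mirroring server_options

    channel = grpc.insecure_channel(
        "unix:///tmp/text-generation-server-0",  # placeholder address
        options=[
            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ],
    )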
@@ -475,9 +475,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+
+        # TODO: use the precompiled grammar here
+        self.states_to_token_maps = states_to_token_maps
+        precompiled_grammar = RegexFSM.precompiled(
+            states_to_token_maps=states_to_token_maps,
+            empty_token_ids=None,
+            vocabulary=None,
+            eos_token_id=None,
+        )
+
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
             grammar_type, grammar, self.tokenizer
         )
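For readers unfamiliar with the structure: states_to_token_maps, as consumed by RegexFSM.precompiled above, maps each FSM state to the token ids that are legal in that state and the state each token leads to. A tiny worked example (token ids invented for illustration):

    # {state: {token_id: next_state}} -- the transition table of the grammar FSM.
    states_to_token_maps = {
        0: {12: 1, 47: 1},  # in state 0, tokens 12 and 47 are legal
        1: {99: 2},         # in state 1, only token 99 is legal
    }

    state = 0
    allowed = set(states_to_token_maps[state])  # keep these logits, mask the rest
    state = states_to_token_maps[state][12]     # advance after sampling token 12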
@@ -550,14 +560,37 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars, grammar_types):
+    def __init__(
+        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
+
         for grammar, grammar_type in zip(grammars, grammar_types):
-            fsm = GrammarLogitProcessor._cached_compile_fsm(
-                grammar_type, grammar, self.tokenizer
-            )
+            start_states = states_to_token_maps[0].start_states
+            tokens = states_to_token_maps[0].tokens
+            end_states = states_to_token_maps[0].end_states
+
+            _states_to_token_maps = {}
+            for i in range(len(start_states)):
+                if start_states[i] in _states_to_token_maps:
+                    _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+                else:
+                    _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+            # TODO: cleanup how precompiled grammars are handled
+            precompiled_grammar = RegexFSM.precompiled(
+                states_to_token_maps=_states_to_token_maps,
+                empty_token_ids=None,
+                vocabulary=list(tokenizer.get_vocab().values()),
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+            # fsm = GrammarLogitProcessor._cached_compile_fsm(
+            #     grammar_type, grammar, self.tokenizer
+            # )
+
+            fsm = precompiled_grammar
             self.fsms.append(fsm)
 
     def __call__(
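The loop above inverts the flattening done in validation.rs, rebuilding the nested dict from the three parallel vectors. An equivalent one-pass sketch using collections.defaultdict (toy vectors matching the earlier example):

    from collections import defaultdict

    # Parallel vectors as produced by the Rust flattening step.
    start_states = [0, 0, 1]
    tokens = [12, 47, 99]
    end_states = [1, 1, 2]

    rebuilt = defaultdict(dict)
    for start, token, end in zip(start_states, tokens, end_states):
        rebuilt[start][token] = end

    assert rebuilt == {0: {12: 1, 47: 1}, 1: {99: 2}}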
@@ -39,6 +39,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        states_to_token_maps: List[List[int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -54,7 +55,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            GrammarLogitProcessor(
+                tokenizer, device, grammar, grammar_type, states_to_token_maps
+            )
             if grammar != ""
             else None
         )
@@ -78,6 +81,7 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
         self.grammar = grammar
+        self.states_to_token_maps = states_to_token_maps
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -126,6 +130,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            states_to_token_maps=pb.states_to_token_maps,
         )
@@ -233,7 +238,8 @@ class HeterogeneousNextTokenChooser:
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
         grammar_types: List[int],
-        fsm_grammar_states=List[int],
+        fsm_grammar_states: List[int],
+        states_to_token_maps: List[List[List[int]]],
     ):
         warpers = []
@@ -267,7 +273,7 @@
 
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types
+                tokenizer, device, grammars, grammar_types, states_to_token_maps
             )
             if any([grammar != "" for grammar in grammars])
             else None
@@ -308,6 +314,7 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.states_to_token_maps = states_to_token_maps
 
     def __call__(
         self,
@@ -486,6 +493,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
+            states_to_token_maps=[pb_.states_to_token_maps for pb_ in pb],
         )