feat: prefer precompiled grammar

drbh 2024-03-07 17:12:46 +00:00
parent 4f7074ca71
commit c52a0f679e
5 changed files with 83 additions and 13 deletions

View File

@@ -19,8 +19,11 @@ impl Client {
     pub async fn connect(uri: Uri) -> Result<Self> {
         let channel = Channel::builder(uri).connect().await?;
+        let limit = 100 * 1024 * 1024; // 100MB

         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
@@ -33,8 +36,12 @@ impl Client {
             }))
             .await?;
+        let limit = 100 * 1024 * 1024; // 100MB
+        println!("limit: {}", limit);

         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
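
Both sides of the wire have to agree on the larger cap: the flattened `states_to_token_maps` payload can easily exceed gRPC's 4 MB default, so the Rust client above raises its encode/decode limits and the Python server below raises the matching server options. For reference, a minimal sketch of the same knobs on a Python `grpc.aio` client channel; the UDS address is a placeholder, not a path from this repo:

import grpc

LIMIT = 100 * 1024 * 1024  # 100 MB, matching the limits set above

channel = grpc.aio.insecure_channel(
    "unix:///tmp/text-generation-server-0",  # placeholder UDS address
    options=[
        ("grpc.max_send_message_length", LIMIT),
        ("grpc.max_receive_message_length", LIMIT),
    ],
)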

View File

@@ -397,10 +397,23 @@ impl Validation {
             .await
             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

+            // flatten the BTreeMap<u32, BTreeMap<u32, u32>> into 3 parallel Vec<u32>s (start_states, tokens, end_states)
+            let mut start_states = vec![];
+            let mut tokens = vec![];
+            let mut end_states = vec![];
+            for (start_state, token_map) in _states_to_token_maps.iter() {
+                for (token, end_state) in token_map.iter() {
+                    start_states.push(*start_state);
+                    tokens.push(*token);
+                    end_states.push(*end_state);
+                }
+            }
+
             let stm = StatesToTokenMaps {
-                start_states: vec![],
-                tokens: vec![],
-                end_states: vec![],
+                start_states,
+                tokens,
+                end_states,
             };
             (
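
The flattening emits one (start_state, token, end_state) triple per FSM transition into three parallel vectors, which is straightforward to carry as repeated scalar fields in a protobuf message. The same encoding in a few lines of Python, on toy data (names mirror the Rust locals):

# Flatten {start_state: {token: end_state}} into three parallel lists,
# one (start_state, token, end_state) triple per transition.
states_to_token_maps = {0: {5: 1, 7: 2}, 1: {9: 2}}  # toy data

start_states, tokens, end_states = [], [], []
for start_state, token_map in states_to_token_maps.items():
    for token, end_state in token_map.items():
        start_states.append(start_state)
        tokens.append(token)
        end_states.append(end_state)

assert start_states == [0, 0, 1]
assert tokens == [5, 7, 9]
assert end_states == [1, 2, 2]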

View File

@@ -206,11 +206,20 @@ def serve(
         logger.exception("Error when initializing model")
         raise

+    max_send_message_size = 100 * 1024 * 1024  # 100 MB
+    max_receive_message_size = 100 * 1024 * 1024  # 100 MB
+    server_options = [
+        ("grpc.max_send_message_length", max_send_message_size),
+        ("grpc.max_receive_message_length", max_receive_message_size),
+    ]
     server = aio.server(
+        options=server_options,
         interceptors=[
             ExceptionInterceptor(),
             UDSOpenTelemetryAioServerInterceptor(),
-        ]
+        ],
     )
     generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
         TextGenerationService(model, Cache(), quantize, server_urls), server
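
For a sense of why 100 MB rather than gRPC's 4 MB default: every transition in the flattened table costs three integers, so large grammars blow past 4 MB quickly. A back-of-the-envelope estimate, assuming a hypothetical million-transition FSM and 4 bytes per value:

# Rough payload estimate: 3 values per transition, 4 bytes per value,
# ignoring protobuf varint/field overhead. The FSM size is hypothetical.
transitions = 1_000_000
payload_bytes = transitions * 3 * 4
print(f"{payload_bytes / 1024 / 1024:.1f} MiB")  # ~11.4 MiB, over the 4 MB default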

View File

@@ -475,9 +475,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        # TODO: use the precompiled grammar here
+        self.states_to_token_maps = states_to_token_maps
+        precompiled_grammar = RegexFSM.precompiled(
+            states_to_token_maps=states_to_token_maps,
+            empty_token_ids=None,
+            vocabulary=None,
+            eos_token_id=None,
+        )
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
             grammar_type, grammar, self.tokenizer
         )
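
Note that in this hunk the precompiled FSM is constructed but not yet used; `self.fsm` still comes from `_cached_compile_fsm`, as the TODO flags. For orientation, a hedged sketch of a fully-populated call, reusing only the keyword arguments visible in this diff (`RegexFSM.precompiled` is the constructor this commit depends on, not a stock outlines API; the toy map and an in-scope `tokenizer` are assumptions):

# Sketch only; assumes RegexFSM is imported at the top of this module,
# as the `fsm: RegexFSM` annotation above implies.
toy_states_to_token_maps = {0: {5: 1}, 1: {9: 2}}  # {start_state: {token_id: end_state}}
fsm = RegexFSM.precompiled(
    states_to_token_maps=toy_states_to_token_maps,
    empty_token_ids=None,
    vocabulary=list(tokenizer.get_vocab().values()),  # as in the heterogeneous hunk below
    eos_token_id=tokenizer.eos_token_id,
)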
@@ -550,14 +560,37 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars, grammar_types):
+    def __init__(
+        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
         for grammar, grammar_type in zip(grammars, grammar_types):
-            fsm = GrammarLogitProcessor._cached_compile_fsm(
-                grammar_type, grammar, self.tokenizer
-            )
+            start_states = states_to_token_maps[0].start_states
+            tokens = states_to_token_maps[0].tokens
+            end_states = states_to_token_maps[0].end_states
+            _states_to_token_maps = {}
+            for i in range(len(start_states)):
+                if start_states[i] in _states_to_token_maps:
+                    _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+                else:
+                    _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+            # TODO: cleanup how precompiled grammars are handled
+            precompiled_grammar = RegexFSM.precompiled(
+                states_to_token_maps=_states_to_token_maps,
+                empty_token_ids=None,
+                vocabulary=list(tokenizer.get_vocab().values()),
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+            # fsm = GrammarLogitProcessor._cached_compile_fsm(
+            #     grammar_type, grammar, self.tokenizer
+            # )
+            fsm = precompiled_grammar
             self.fsms.append(fsm)

     def __call__(
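
The rebuild loop above is the inverse of the router-side flatten: it zips the three parallel lists back into a nested {start_state: {token: end_state}} dict. As a side note, `collections.defaultdict` makes the membership test unnecessary; a sketch, not the committed code:

from collections import defaultdict

def rebuild(start_states, tokens, end_states):
    # Inverse of the router-side flatten: parallel lists -> nested dict.
    mapping = defaultdict(dict)
    for start, token, end in zip(start_states, tokens, end_states):
        mapping[start][token] = end
    return dict(mapping)

assert rebuild([0, 0, 1], [5, 7, 9], [1, 2, 2]) == {0: {5: 1, 7: 2}, 1: {9: 2}}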

View File

@@ -39,6 +39,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        states_to_token_maps: List[List[int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -54,7 +55,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            GrammarLogitProcessor(
+                tokenizer, device, grammar, grammar_type, states_to_token_maps
+            )
             if grammar != ""
             else None
         )
@@ -78,6 +81,7 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
         self.grammar = grammar
+        self.states_to_token_maps = states_to_token_maps

     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -126,6 +130,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            states_to_token_maps=pb.states_to_token_maps,
         )
@@ -233,7 +238,8 @@ class HeterogeneousNextTokenChooser:
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
         grammar_types: List[int],
-        fsm_grammar_states=List[int],
+        fsm_grammar_states: List[int],
+        states_to_token_maps: List[List[List[int]]],
     ):
         warpers = []
@@ -267,7 +273,7 @@ class HeterogeneousNextTokenChooser:
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types
+                tokenizer, device, grammars, grammar_types, states_to_token_maps
             )
             if any([grammar != "" for grammar in grammars])
             else None
@@ -308,6 +314,7 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.states_to_token_maps = states_to_token_maps

     def __call__(
         self,
@@ -486,6 +493,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
+            states_to_token_maps=[pb_.states_to_token_maps for pb_ in pb],
        )
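
End to end, the flattened maps now travel router → protobuf → `NextTokenChooser` → grammar processor, with each element of `states_to_token_maps` exposing `start_states`, `tokens`, and `end_states` as repeated fields. A hypothetical stand-in for that message, handy for exercising the processors without a live router (pure test scaffolding, not part of this commit):

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeStatesToTokenMaps:
    # Mirrors the repeated uint32 fields of the StatesToTokenMaps message.
    start_states: List[int] = field(default_factory=list)
    tokens: List[int] = field(default_factory=list)
    end_states: List[int] = field(default_factory=list)

stm = FakeStatesToTokenMaps([0, 0, 1], [5, 7, 9], [1, 2, 2])
# Matches the attribute-access pattern in HeterogeneousGrammarLogitProcessor:
assert stm.start_states[0] == 0 and stm.end_states[-1] == 2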