diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index ce04bcca0..2fc06630b 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -19,8 +19,11 @@ impl Client {
     pub async fn connect(uri: Uri) -> Result<Self> {
         let channel = Channel::builder(uri).connect().await?;

+        let limit = 100 * 1024 * 1024; // 100MB
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }

@@ -33,8 +36,12 @@ impl Client {
             }))
             .await?;

+        let limit = 100 * 1024 * 1024; // 100MB
+        println!("limit: {}", limit);
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }

diff --git a/router/src/validation.rs b/router/src/validation.rs
index 8ac86136d..b16b5048d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -397,10 +397,23 @@ impl Validation {
                     .await
                     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

+                // flatten the BTreeMap<u32, BTreeMap<u32, u32>> into 3 parallel Vecs (start_states, tokens, end_states)
+                let mut start_states = vec![];
+                let mut tokens = vec![];
+                let mut end_states = vec![];
+
+                for (start_state, token_map) in _states_to_token_maps.iter() {
+                    for (token, end_state) in token_map.iter() {
+                        start_states.push(*start_state);
+                        tokens.push(*token);
+                        end_states.push(*end_state);
+                    }
+                }
+
                 let stm = StatesToTokenMaps {
-                    start_states: vec![],
-                    tokens: vec![],
-                    end_states: vec![],
+                    start_states,
+                    tokens,
+                    end_states,
                 };

                 (
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d5adbd32a..f0c17f678 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -206,11 +206,20 @@ def serve(
             logger.exception("Error when initializing model")
             raise

+        max_send_message_size = 100 * 1024 * 1024  # 100 MB
+        max_receive_message_size = 100 * 1024 * 1024  # 100 MB
+
+        server_options = [
+            ("grpc.max_send_message_length", max_send_message_size),
+            ("grpc.max_receive_message_length", max_receive_message_size),
+        ]
+
         server = aio.server(
+            options=server_options,
             interceptors=[
                 ExceptionInterceptor(),
                 UDSOpenTelemetryAioServerInterceptor(),
-            ]
+            ],
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), quantize, server_urls), server
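Why the limits in `client.rs` and `server.py` move together: gRPC rejects incoming messages larger than roughly 4 MB by default on each peer, and a serialized `states_to_token_maps` for a realistic vocabulary can easily exceed that, so the Rust client (tonic's `max_decoding_message_size`/`max_encoding_message_size`) and the Python shard (the `grpc.max_*_message_length` channel options) are raised in lockstep. A minimal Python sketch of both ends of the link; the client half stands in here for the tonic side, and the socket path and helper names are illustrative, not from the patch:

```python
import grpc
from grpc import aio

MAX_MESSAGE_SIZE = 100 * 1024 * 1024  # 100 MB, matching the patch

GRPC_OPTIONS = [
    # send and receive caps are enforced independently on each peer,
    # so both sides of the router<->shard link must be raised
    ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
    ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
]

def make_server() -> aio.Server:
    # mirrors the aio.server(options=...) call in server.py
    return aio.server(options=GRPC_OPTIONS)

def make_channel(uds_path: str = "/tmp/text-generation-server-0") -> aio.Channel:
    # client-side equivalent of the tonic message-size limits
    return aio.insecure_channel(f"unix://{uds_path}", options=GRPC_OPTIONS)
```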
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index b4ffb863f..b3bfb3f30 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -475,9 +475,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+
+        # TODO: use the precompiled grammar here
+        self.states_to_token_maps = states_to_token_maps
+        precompiled_grammar = RegexFSM.precompiled(
+            states_to_token_maps=states_to_token_maps,
+            empty_token_ids=None,
+            vocabulary=None,
+            eos_token_id=None,
+        )
+
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
             grammar_type, grammar, self.tokenizer
         )
@@ -550,14 +560,37 @@
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars, grammar_types):
+    def __init__(
+        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
+
         for grammar, grammar_type in zip(grammars, grammar_types):
-            fsm = GrammarLogitProcessor._cached_compile_fsm(
-                grammar_type, grammar, self.tokenizer
+            start_states = states_to_token_maps[0].start_states
+            tokens = states_to_token_maps[0].tokens
+            end_states = states_to_token_maps[0].end_states
+
+            _states_to_token_maps = {}
+            for i in range(len(start_states)):
+                if start_states[i] in _states_to_token_maps:
+                    _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+                else:
+                    _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+            # TODO: cleanup how precompiled grammars are handled
+            precompiled_grammar = RegexFSM.precompiled(
+                states_to_token_maps=_states_to_token_maps,
+                empty_token_ids=None,
+                vocabulary=list(tokenizer.get_vocab().values()),
+                eos_token_id=self.tokenizer.eos_token_id,
             )
+            # fsm = GrammarLogitProcessor._cached_compile_fsm(
+            #     grammar_type, grammar, self.tokenizer
+            # )
+
+            fsm = precompiled_grammar
             self.fsms.append(fsm)

     def __call__(
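The rebuild loop in `HeterogeneousGrammarLogitProcessor.__init__` above is the inverse of the flattening added to `validation.rs`: three parallel repeated-integer arrays travel over the wire instead of a map of maps, and the nested `{start_state: {token_id: end_state}}` table is reconstructed per request. A self-contained sketch of the round trip, with plain dicts and hypothetical helper names standing in for the protobuf types:

```python
from typing import Dict, List, Tuple

# {start_state: {token_id: end_state}}
StatesToTokenMaps = Dict[int, Dict[int, int]]

def flatten(stm: StatesToTokenMaps) -> Tuple[List[int], List[int], List[int]]:
    # what validation.rs does before filling the protobuf message
    start_states, tokens, end_states = [], [], []
    for start_state, token_map in stm.items():
        for token, end_state in token_map.items():
            start_states.append(start_state)
            tokens.append(token)
            end_states.append(end_state)
    return start_states, tokens, end_states

def unflatten(
    start_states: List[int], tokens: List[int], end_states: List[int]
) -> StatesToTokenMaps:
    # what HeterogeneousGrammarLogitProcessor.__init__ does on arrival
    stm: StatesToTokenMaps = {}
    for s, t, e in zip(start_states, tokens, end_states):
        stm.setdefault(s, {})[t] = e
    return stm

table = {0: {5: 1, 7: 2}, 1: {9: 0}}
assert unflatten(*flatten(table)) == table
```

Note that the Python side currently always reads `states_to_token_maps[0]`, so every grammar in the batch is rebuilt from the first request's table; the `TODO: cleanup` comment marks this as interim.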
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 7c8a18f02..ea8fed864 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -39,6 +39,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        states_to_token_maps: List[List[int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
         )
@@ -54,7 +55,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            GrammarLogitProcessor(
+                tokenizer, device, grammar, grammar_type, states_to_token_maps
+            )
             if grammar != ""
             else None
         )
@@ -78,6 +81,7 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
         self.grammar = grammar
+        self.states_to_token_maps = states_to_token_maps

     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -126,6 +130,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            states_to_token_maps=pb.states_to_token_maps,
         )

@@ -233,7 +238,8 @@ class HeterogeneousNextTokenChooser:
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
         grammar_types: List[int],
-        fsm_grammar_states=List[int],
+        fsm_grammar_states: List[int],
+        states_to_token_maps: List[List[List[int]]],
     ):
         warpers = []

@@ -267,7 +273,7 @@ class HeterogeneousNextTokenChooser:
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types
+                tokenizer, device, grammars, grammar_types, states_to_token_maps
             )
             if any([grammar != "" for grammar in grammars])
             else None
         )
@@ -308,6 +314,7 @@
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.states_to_token_maps = states_to_token_maps

     def __call__(
         self,
@@ -486,6 +493,7 @@
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
+            states_to_token_maps=[pb_.states_to_token_maps for pb_ in pb],
         )
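For context on what the shipped table ultimately drives: at each decoding step the grammar processor keeps only the logits of tokens that have an outgoing FSM edge from the current `fsm_grammar_state`, and the sampled token then selects the next state. An illustrative sketch of that masking step, not TGI's actual implementation:

```python
import torch

# toy {start_state: {token_id: end_state}} table over a 6-token vocab:
# state 0 accepts tokens 2 or 3, state 1 accepts token 5
stm = {0: {2: 1, 3: 1}, 1: {5: 0}}

def apply_grammar_mask(logits: torch.Tensor, stm: dict, state: int) -> torch.Tensor:
    # keep only tokens with an outgoing edge from `state`; everything else -> -inf
    allowed = list(stm.get(state, {}).keys())
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed] = 0.0
    return logits + mask

def advance_state(stm: dict, state: int, token_id: int) -> int:
    # follow the edge taken by the sampled token
    return stm[state][token_id]

logits = torch.randn(6)
masked = apply_grammar_mask(logits, stm, state=0)
next_token = int(masked.argmax())  # guaranteed to be 2 or 3
new_state = advance_state(stm, 0, next_token)
```

Because disallowed logits are forced to -inf, any downstream choice strategy (greedy or multinomial) can only emit grammar-legal tokens.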