Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00
feat: prefer precompiled grammar
parent 4f7074ca71
commit c52a0f679e
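In brief: the router flattens each grammar's compiled FSM transition table (states_to_token_maps) into three parallel vectors, ships them over gRPC to the Python shard, and the grammar logit processors experimentally rebuild a RegexFSM from that precompiled table instead of recompiling the grammar server-side. Both sides raise their gRPC message-size limits to 100 MB to make room for the serialized tables. The TODOs and the stray println! below mark this as a work-in-progress draft.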
@@ -19,8 +19,11 @@ impl Client {
     pub async fn connect(uri: Uri) -> Result<Self> {
         let channel = Channel::builder(uri).connect().await?;
 
+        let limit = 100 * 1024 * 1024; // 100MB
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
 
@@ -33,8 +36,12 @@ impl Client {
             }))
             .await?;
 
+        let limit = 100 * 1024 * 1024; // 100MB
+        println!("limit: {}", limit);
         Ok(Self {
-            stub: TextGenerationServiceClient::new(channel),
+            stub: TextGenerationServiceClient::new(channel)
+                .max_decoding_message_size(limit)
+                .max_encoding_message_size(limit),
         })
     }
 
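The 100 MB ceilings replace gRPC's usual 4 MB default because a serialized transition table scales with states × accepted tokens per state. A rough Python estimate; all numbers here are illustrative assumptions, not taken from the commit:

states = 1_000        # FSM states in a moderately complex grammar (assumed)
vocab = 32_000        # tokenizer vocabulary size (assumed)
density = 0.10        # fraction of the vocab accepted per state (assumed)

transitions = int(states * vocab * density)   # (start, token, end) triples
bytes_per_triple = 3 * 4                      # three uint32 values

payload_mib = transitions * bytes_per_triple / (1024 * 1024)
print(f"~{payload_mib:.1f} MiB")  # ~36.6 MiB, well past a 4 MB default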
@@ -397,10 +397,23 @@ impl Validation {
             .await
             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
+        // flatten the BTreeMap<u32, BTreeMap<u32, u32>> into 3 Vec<u32>s (start_states, tokens, end_states)
+        let mut start_states = vec![];
+        let mut tokens = vec![];
+        let mut end_states = vec![];
+
+        for (start_state, token_map) in _states_to_token_maps.iter() {
+            for (token, end_state) in token_map.iter() {
+                start_states.push(*start_state);
+                tokens.push(*token);
+                end_states.push(*end_state);
+            }
+        }
+
         let stm = StatesToTokenMaps {
-            start_states: vec![],
-            tokens: vec![],
-            end_states: vec![],
+            start_states,
+            tokens,
+            end_states,
         };
 
         (
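The router flattens the nested map into three parallel vectors, which encode cheaply as repeated scalar fields in protobuf; the Python side inverts this in the HeterogeneousGrammarLogitProcessor hunk further down. A minimal round-trip sketch in Python, with hypothetical helper names:

from typing import Dict, List, Tuple

def flatten(stm: Dict[int, Dict[int, int]]) -> Tuple[List[int], List[int], List[int]]:
    # {start_state: {token: end_state}} -> three parallel vectors,
    # mirroring the Rust loop above
    start_states, tokens, end_states = [], [], []
    for start_state, token_map in stm.items():
        for token, end_state in token_map.items():
            start_states.append(start_state)
            tokens.append(token)
            end_states.append(end_state)
    return start_states, tokens, end_states

def unflatten(start_states: List[int], tokens: List[int], end_states: List[int]) -> Dict[int, Dict[int, int]]:
    # inverse: zip the vectors back into a nested transition map
    stm: Dict[int, Dict[int, int]] = {}
    for s, t, e in zip(start_states, tokens, end_states):
        stm.setdefault(s, {})[t] = e
    return stm

stm = {0: {5: 1, 7: 2}, 1: {9: 2}}
assert unflatten(*flatten(stm)) == stm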
@@ -206,11 +206,20 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        max_send_message_size = 100 * 1024 * 1024  # 100 MB
+        max_receive_message_size = 100 * 1024 * 1024  # 100 MB
+
+        server_options = [
+            ("grpc.max_send_message_length", max_send_message_size),
+            ("grpc.max_receive_message_length", max_receive_message_size),
+        ]
+
         server = aio.server(
+            options=server_options,
             interceptors=[
                 ExceptionInterceptor(),
                 UDSOpenTelemetryAioServerInterceptor(),
-            ]
+            ],
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), quantize, server_urls), server
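gRPC enforces max_send_message_length / max_receive_message_length per peer, so the server options above only take effect together with matching client-side limits (which the Rust hunks above set via tonic's max_encoding_message_size / max_decoding_message_size). For comparison, a Python aio client would pass the same options; the socket path below is an illustrative assumption:

import grpc

MAX_MESSAGE_SIZE = 100 * 1024 * 1024  # 100 MB, matching the server options

# Illustrative target only: TGI shards listen on unix domain sockets,
# and the real client here is the Rust router above, not Python.
channel = grpc.aio.insecure_channel(
    "unix:///tmp/text-generation-server-0",
    options=[
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
    ],
)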
@@ -475,9 +475,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar, grammar_type):
+    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+
+        # TODO: use the precompiled grammar here
+        self.states_to_token_maps = states_to_token_maps
+        precompiled_grammar = RegexFSM.precompiled(
+            states_to_token_maps=states_to_token_maps,
+            empty_token_ids=None,
+            vocabulary=None,
+            eos_token_id=None,
+        )
+
         self.fsm = GrammarLogitProcessor._cached_compile_fsm(
             grammar_type, grammar, self.tokenizer
         )
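In this hunk precompiled_grammar is built but never used: self.fsm still comes from _cached_compile_fsm, and RegexFSM.precompiled is fed None for the vocabulary and EOS token, consistent with the TODO above it. Whether outlines' RegexFSM actually exposes a precompiled constructor with this signature is not shown in this excerpt. The idea it stands for — masking tokens straight from a shipped transition table, with no regex-to-DFA compilation at serving time — can be sketched with a small hypothetical class:

from typing import Dict, List

class PrecompiledFSM:
    # Hypothetical stand-in, not the outlines API: token filtering driven
    # directly by a precomputed transition table.
    def __init__(self, states_to_token_maps: Dict[int, Dict[int, int]]):
        self.states_to_token_maps = states_to_token_maps

    def allowed_token_ids(self, state: int) -> List[int]:
        # token ids with an outgoing transition from `state`
        return list(self.states_to_token_maps.get(state, {}).keys())

    def next_state(self, state: int, token_id: int) -> int:
        # -1 as a sentinel for "no valid transition"
        return self.states_to_token_maps.get(state, {}).get(token_id, -1)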
@@ -550,14 +560,37 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars, grammar_types):
+    def __init__(
+        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
+    ):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
+
         for grammar, grammar_type in zip(grammars, grammar_types):
-            fsm = GrammarLogitProcessor._cached_compile_fsm(
-                grammar_type, grammar, self.tokenizer
-            )
+            start_states = states_to_token_maps[0].start_states
+            tokens = states_to_token_maps[0].tokens
+            end_states = states_to_token_maps[0].end_states
+
+            _states_to_token_maps = {}
+            for i in range(len(start_states)):
+                if start_states[i] in _states_to_token_maps:
+                    _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+                else:
+                    _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+            # TODO: cleanup how precompiled grammars are handled
+            precompiled_grammar = RegexFSM.precompiled(
+                states_to_token_maps=_states_to_token_maps,
+                empty_token_ids=None,
+                vocabulary=list(tokenizer.get_vocab().values()),
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+            # fsm = GrammarLogitProcessor._cached_compile_fsm(
+            #     grammar_type, grammar, self.tokenizer
+            # )
+
+            fsm = precompiled_grammar
            self.fsms.append(fsm)
 
     def __call__(
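Two things stand out in this hunk: the loop reads states_to_token_maps[0] on every iteration, so each grammar in the batch gets the first request's transition table, and the grammar / grammar_type loop variables go unused — presumably placeholder behavior in this draft. A hypothetical per-request pairing, reusing the unflatten sketch from earlier and the names already in scope in this method (not what the commit does):

for grammar, grammar_type, stm_pb in zip(
    grammars, grammar_types, states_to_token_maps
):
    # one transition table per request instead of always index 0
    nested = unflatten(stm_pb.start_states, stm_pb.tokens, stm_pb.end_states)
    fsm = RegexFSM.precompiled(
        states_to_token_maps=nested,
        empty_token_ids=None,
        vocabulary=list(tokenizer.get_vocab().values()),
        eos_token_id=self.tokenizer.eos_token_id,
    )
    self.fsms.append(fsm)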
@@ -39,6 +39,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        states_to_token_maps: List[List[int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -54,7 +55,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            GrammarLogitProcessor(
+                tokenizer, device, grammar, grammar_type, states_to_token_maps
+            )
             if grammar != ""
             else None
         )
@@ -78,6 +81,7 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
         self.grammar = grammar
+        self.states_to_token_maps = states_to_token_maps
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -126,6 +130,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            states_to_token_maps=pb.states_to_token_maps,
         )
 
 
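pb.states_to_token_maps implies a protobuf change that is not part of this excerpt; presumably the NextTokenChooserParameters message gains a field carrying the three vectors built in the Validation hunk (the List[List[int]] hints in the Python signatures are looser than that shape). A Python dataclass standing in for the assumed wire format:

from dataclasses import dataclass, field
from typing import List

@dataclass
class StatesToTokenMaps:
    # Assumed shape only — the .proto definition is not shown in this
    # diff: three parallel vectors, one entry per FSM transition
    # start_state --token--> end_state.
    start_states: List[int] = field(default_factory=list)
    tokens: List[int] = field(default_factory=list)
    end_states: List[int] = field(default_factory=list)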
@@ -233,7 +238,8 @@ class HeterogeneousNextTokenChooser:
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
         grammar_types: List[int],
-        fsm_grammar_states=List[int],
+        fsm_grammar_states: List[int],
+        states_to_token_maps: List[List[List[int]]],
     ):
         warpers = []
 
@@ -267,7 +273,7 @@ class HeterogeneousNextTokenChooser:
 
         self.grammar_processor = (
             HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types
+                tokenizer, device, grammars, grammar_types, states_to_token_maps
             )
             if any([grammar != "" for grammar in grammars])
             else None
@@ -308,6 +314,7 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.states_to_token_maps = states_to_token_maps
 
     def __call__(
         self,
@@ -486,6 +493,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
+            states_to_token_maps=[pb_.states_to_token_maps for pb_ in pb],
         )
 
 