diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index 6deae48d..a0ef0fe6 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -45,6 +45,7 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
+        fsm_grammar_state: 0,
     };
 
     // Initialize terminal properties
diff --git a/proto/generate.proto b/proto/generate.proto
index aae0e7a4..015dd6e9 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -71,7 +71,9 @@ message NextTokenChooserParameters {
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
     string grammar = 10;
+    /// fsm_grammar_state
+    uint32 fsm_grammar_state = 11;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 9822ea77..38e6e0e3 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -129,6 +129,7 @@ impl Client {
                     frequency_penalty: 0.1,
                     watermark: true,
                     grammar: String::new(),
+                    fsm_grammar_state: 0,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index 6f3d2023..f3cac17e 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -46,6 +46,7 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 3e4aefa1..0162b906 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -369,6 +369,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index a77995df..0455411d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -304,6 +304,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
+            fsm_grammar_state: 0,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 12be2485..c5ce89be 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -99,6 +99,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
+    # The states for the grammar FSM
+    fsm_states: Dict[int, int] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -137,6 +140,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
+        fsm_states = {}
 
         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -319,6 +323,7 @@
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            fsm_states=fsm_states,
         )
 
     @tracer.start_as_current_span("filter")
@@ -594,7 +599,6 @@
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
-            grammar=batches[0].requests.parameters.grammar,
         )
 
         speculative_ids = (
@@ -1015,9 +1019,9 @@ class FlashCausalLM(Model):
            # Copy batch.input_ids to prefill_token_indices
            if prefill_logprobs:
                if len(batch) > 1:
-                    prefill_tokens_indices[
-                        out_start_index : out_end_index - 1
-                    ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                    prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                        batch.input_ids[start_index + 1 : start_index + out_length]
+                    )
                else:
                    # Set prefill_tokens_indices to the correct slice
                    prefill_tokens_indices = batch.input_ids[
@@ -1168,7 +1172,7 @@ class FlashCausalLM(Model):
            if top_n_tokens > 0:
                all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                    top_token_ids, top_token_logprobs
                ):
                    toptoken_texts = self.tokenizer.batch_decode(
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b4abeb9b..3178c2f2 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -22,11 +22,12 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
 
-from functools import lru_cache
+from functools import lru_cache
 
 # TODO: remove when done debugging
 import time
 
+
 class NextTokenChooser:
     def __init__(
         self,
@@ -42,6 +43,7 @@ class NextTokenChooser:
         device="cpu",
         tokenizer=None,
         grammar=None,
+        fsm_grammar_state=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
         )
@@ -73,6 +75,9 @@
         sampling = do_sample or has_warpers
 
+        self.fsm_grammar_state = fsm_grammar_state
+        self.grammars = grammar
+
         # TODO: is grammar a subset of sampling? If so, we should merge them
         if grammar:
             self.choice = Grammar(tokenizer, device, grammar)
@@ -92,7 +97,9 @@
         else:
             scores, next_logprob = self.static_warper(scores)
 
-        next_id = self.choice(scores[-1]).view(1, 1)
+        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
+            1, 1
+        )
 
         return next_id, next_logprob
 
@@ -116,6 +123,7 @@
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
+            fsm_grammar_state=pb.fsm_grammar_state,
         )
 
 
@@ -222,6 +230,7 @@
         seeds: List[int],
         tokenizer=None,
         grammar=None,
+        fsm_grammar_states=None,
     ):
         warpers = []
 
@@ -275,9 +284,8 @@
         self.warpers = warpers
 
-        first_grammar = grammar[0] if grammar else None
-        if first_grammar:
-            self.choice = Grammar(tokenizer, device, first_grammar)
+        if grammar is not None:
+            self.choice = Grammar(tokenizer, device)
         elif any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
@@ -288,6 +296,9 @@
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer
+        self.fsm_grammar_states = fsm_grammar_states
+        self.grammars = grammar
 
     def __call__(
         self,
@@ -320,7 +331,7 @@
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
 
-            _next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -398,7 +409,7 @@
             self.do_sample = [self.do_sample[i] for i in indices]
 
             if self.use_grammar or any(self.do_sample):
-                self.choice.filter(indices)
+                self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
             else:
                 self.choice = Greedy()
 
@@ -426,6 +437,7 @@
             dtype=dtype,
             tokenizer=tokenizer,
             grammar=[pb_.grammar for pb_ in pb],
+            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )
 
 
@@ -444,51 +456,62 @@ class Sampling:
 
 
 class Greedy:
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         return logits.argmax(dim=-1)
 
+    def filter(self, indices, *args):
+        return self
+
 
 class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar):
-        fsm = self.compile_fsm(grammar, tokenizer)
-        self.fsm = fsm
-        self.fsm_state = defaultdict(int)
+    def __init__(self, tokenizer, device):
         self.device = device
+        self.tokenizer = tokenizer
 
-    def __call__(self, logits):
-        seq_id = 0
+    def __call__(self, logits, fsm_grammar_states, grammars):
+        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
+        try:
+            for i in range(len(fsm_grammar_states)):
+                if fsm_grammar_states[i] == -1:
+                    continue
 
-        if self.fsm_state[seq_id] == -1:
-            return self.fsm_state[seq_id].eos_token_id
+                # this is cached and should be fast after the first time
+                fsm = self.compile_fsm(grammars[i], self.tokenizer)
+                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+                mask[allowed_tokens] = 0
+                biased_scores = logits[i : i + 1] + mask
+                greedy = biased_scores.argmax(dim=-1)
 
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-        mask[allowed_tokens] = 0
-        biased_scores = logits + mask
+                # if greedy is empty, return the eos token
+                if greedy.shape[0] == 0:
+                    continue
 
-        # greedly pick the token with the highest score
-        greedy = biased_scores.argmax(dim=-1)
+                # import ipdb; ipdb.set_trace()
+                fsm_grammar_states[i] = fsm.next_state(
+                    fsm_grammar_states[i], greedy.item()
+                )
+
+                empty[i] = greedy.item()
+        except Exception as e:
+            print(f"Exception: {e}")
+            import ipdb
+
+            ipdb.set_trace()
+        return empty
 
-        # now update the fsm state
-        self.fsm_state[seq_id] = self.fsm.next_state(
-            self.fsm_state[seq_id], greedy.item()
-        )
-        return greedy
-
     @lru_cache(maxsize=32, typed=True)
     def compile_fsm(self, schema, tokenizer):
         start_time = time.time()
         tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}") 
+        is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
         print(f"Compile FSM: {time.time() - start_time}")
         return fsm
-
     def adapt_tokenizer(self, tokenizer):
         """Adapt tokenizer to work with the FSM.
 
@@ -515,6 +538,18 @@ class Grammar:
         return tokenizer
 
+    def filter(self, indices, fsm_grammar_states, grammars):
+        new_fsm_grammar_states = []
+        new_grammars = []
+
+        for i in indices:
+            new_fsm_grammar_states.append(fsm_grammar_states[i])
+            new_grammars.append(grammars[i])
+
+        self.fsm_state = new_fsm_grammar_states
+        self.fsm = new_grammars
+        return self
+
 
 class HeterogeneousSampling:
     r"""
@@ -534,7 +569,7 @@
         self.greedy = Greedy()
 
-    def __call__(self, logits):
+    def __call__(self, logits, fsm_grammar_states, grammars):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing
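Note (not part of the patch): the per-sequence work done in Grammar.__call__ above reduces to a standard FSM-constrained decoding step. The sketch below restates it for a single sequence, reusing only the outlines calls that already appear in this diff (RegexFSM, allowed_token_ids, next_state, build_regex_from_object). The helper name constrained_greedy_step and the adapted_tokenizer argument are illustrative, and the tokenizer is assumed to have already been run through adapt_tokenizer; in the real code the compiled FSM is cached via lru_cache rather than rebuilt each step.

import math

import torch
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def constrained_greedy_step(logits, schema, adapted_tokenizer, state):
    # Build the FSM for this grammar (Grammar.compile_fsm caches this with lru_cache).
    is_json = schema.startswith("{") and schema.endswith("}")
    regex_string = build_regex_from_object(schema) if is_json else schema
    fsm = RegexFSM(regex_string, adapted_tokenizer)

    # Keep only the tokens the FSM allows from the current state, then pick greedily.
    mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
    mask[fsm.allowed_token_ids(state)] = 0
    next_id = (logits + mask).argmax(dim=-1).item()

    # Advance the per-sequence FSM state; the batch code stores this back into
    # fsm_grammar_states[i] so the next decoding step continues from it (-1 means done).
    return next_id, fsm.next_state(state, next_id)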