diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f692a7ec..a0f0c9e8 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys
 
                 start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values
 
                 # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -696,7 +700,7 @@ class CausalLM(Model):
 
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
@@ -735,6 +739,9 @@ class CausalLM(Model):
             generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bd91b0e2..8b2206dd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -870,7 +870,11 @@ class FlashCausalLM(Model):
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
-        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            or batch.speculative_ids is not None
+        ):
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1029,6 +1033,9 @@ class FlashCausalLM(Model):
 
             cumulative_length += input_length
 
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids
+        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 613ec8b9..cd43e2a4 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -478,10 +478,8 @@ class GrammarLogitProcessor(LogitsProcessor):
 
     def __init__(self, tokenizer, device, grammar):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            self, grammar, self.tokenizer
-        )
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)
 
     def __call__(
         self,
@@ -490,26 +488,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     ):
         if fsm_grammar_state == -1 or self.fsm is None:
             return logits
-
         allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
         return biased_scores
 
-    def advance(self, next_token_id, fsm_grammar_state, grammar):
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
         if fsm_grammar_state == -1:
             return fsm_grammar_state
-
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)
 
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(schema, tokenizer):
         start_time = time.time()
         try:
             json.loads(schema)  # check if schema is a valid json
@@ -522,7 +520,7 @@ class GrammarLogitProcessor(LogitsProcessor):
 
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def adapt_tokenizer(tokenizer):
+    def _cached_adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.
 
         The API of Outlines tokenizers is slightly different to that of
@@ -560,10 +558,10 @@ class GrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device, grammars):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = [
             (
-                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
+                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
                 if g
                 else None
             )
@@ -586,10 +584,13 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             logits[i] = biased_scores
         return logits
 
-    def advance(self, next_token_ids, fsm_grammar_states, grammars):
-        return GrammarLogitProcessor.advance(
-            self, next_token_ids, fsm_grammar_states, grammars
-        )
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]
 
     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 360e4fe3..70bde1d1 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, DefaultDict
+from typing import List, Optional, Tuple
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -92,14 +92,15 @@ class NextTokenChooser:
 
         next_id = self.choice(scores[-1]).view(1, 1)
 
-        if self.grammar_processor is not None:
-            next_state = self.grammar_processor.advance(
-                next_id.item(), self.fsm_grammar_state, self.grammar
-            )
-            self.fsm_grammar_state = next_state
-
         return next_id, next_logprob
 
+    def advance_grammar(self, next_id):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_state = self.grammar_processor.advance(
+                next_id, self.fsm_grammar_state
+            )
+        return self
+
     @classmethod
     def from_pb(
         cls,
@@ -385,15 +386,16 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
-        # advance the grammar state
-        if self.grammar_processor is not None:
-            for i in range(len(self.fsm_grammar_states)):
-                self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                )
-
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
+    def advance_grammar(self, next_ids: torch.Tensor):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
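
For reference, the net effect on callers is that grammar-FSM advancement becomes a separate, explicit step after sampling instead of a side effect of calling the chooser. A minimal sketch of that calling pattern follows; it assumes the existing `__call__` signatures of `NextTokenChooser` and `HeterogeneousNextTokenChooser` elsewhere in `tokens.py`, and `decode_step` / `advance_batched` are hypothetical helpers, not part of this patch:

    import torch
    from text_generation_server.utils.tokens import (
        HeterogeneousNextTokenChooser,
        NextTokenChooser,
    )


    def decode_step(chooser: NextTokenChooser, input_ids: torch.Tensor, scores: torch.Tensor):
        # Sampling no longer mutates the grammar FSM; the caller advances it
        # explicitly with the token id it actually kept. advance_grammar()
        # returns the chooser, so its slot can simply be reassigned (this
        # mirrors the new update loop in causal_lm.py).
        next_id, next_logprob = chooser(input_ids, scores)
        chooser = chooser.advance_grammar(next_id.item())
        return chooser, next_id, next_logprob


    def advance_batched(chooser: HeterogeneousNextTokenChooser, next_ids: torch.Tensor):
        # Batched variant: advance_grammar() takes the tensor of accepted token ids
        # and fans out to HeterogeneousGrammarLogitProcessor.advance_batch, one FSM
        # per row (this mirrors the new call site in flash_causal_lm.py).
        return chooser.advance_grammar(next_ids)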