diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 86944c8e..f6c239f2 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -374,14 +374,16 @@ def generate_load():
         prompt: str,
         max_new_tokens: int,
         n: int,
-        **kwargs,
+        seed: Optional[int] = None,
+        grammar: Optional[str] = None,
     ) -> List[Response]:
         futures = [
             client.generate(
                 prompt,
                 max_new_tokens=max_new_tokens,
                 decoder_input_details=True,
-                **kwargs,
+                seed=seed,
+                grammar=grammar,
             )
             for _ in range(n)
         ]
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 3116064c..97086928 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -1,6 +1,7 @@
 import math
 import torch
 
+from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
 
@@ -511,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):
@@ -550,7 +551,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
-        for i in range(len(logits)):
+        for i in range(logits.shape[0]):
             if fsm_grammar_states[i] == -1:
                 # todo mask for only eos token
                 continue
@@ -584,7 +585,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        # print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index e06afcb6..b056a522 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -26,19 +26,19 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 class NextTokenChooser:
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        frequency_penalty=0.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-        tokenizer=None,
-        grammar=None,
-        fsm_grammar_state=None,
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: float = 0.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        grammar: str = "",
+        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -54,9 +54,7 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device)
-            if grammar and grammar != ""
-            else None
+            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
         )
         self.tokenizer = tokenizer
 
@@ -438,7 +436,10 @@ class HeterogeneousNextTokenChooser:
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
 
-        self.choice = Greedy()
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
 
         return self
 
@@ -486,6 +487,7 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
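Reviewer note: the loop changed in `HeterogeneousGrammarLogitProcessor.__call__` walks the batch rows and, for each sequence whose grammar FSM is still live, adds a mask so that only FSM-allowed token ids keep finite scores. Below is a minimal, self-contained sketch of that masking step; `mask_to_allowed` and `allowed_ids` are stand-ins for illustration (the real code gets the allowed set from the compiled `RegexFSM` for the current grammar state), not APIs introduced by this diff.

```python
import math
import torch

def mask_to_allowed(logits: torch.Tensor, allowed_ids: list) -> torch.Tensor:
    # Build an additive mask that is -inf everywhere, then re-open the
    # positions the grammar FSM permits in its current state.
    mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
    mask[allowed_ids] = 0.0
    # After softmax, every forbidden token has probability exactly 0,
    # so both greedy and sampled decoding stay inside the grammar.
    return logits + mask

# Toy batch: 2 sequences over a 6-token vocabulary, each row with its
# own allowed set (mirroring the per-row loop over logits.shape[0]).
logits = torch.randn(2, 6)
for i, allowed in enumerate([[1, 4], [0, 2, 5]]):
    logits[i] = mask_to_allowed(logits[i], allowed)
assert logits.softmax(dim=-1)[0, [0, 2, 3, 5]].sum() == 0
```

The diff's version does this in place per row and skips rows whose `fsm_grammar_states[i] == -1` (a finished grammar), which the sketch omits.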