feat: address review comments on syntax

drbh 2024-02-12 11:08:35 -05:00
parent f94fc831f4
commit a28ba7212c
3 changed files with 27 additions and 22 deletions

Changed file 1 of 3

@@ -374,14 +374,16 @@ def generate_load():
     prompt: str,
     max_new_tokens: int,
     n: int,
-    **kwargs,
+    seed: Optional[int] = None,
+    grammar: Optional[str] = None,
 ) -> List[Response]:
     futures = [
         client.generate(
             prompt,
             max_new_tokens=max_new_tokens,
             decoder_input_details=True,
-            **kwargs,
+            seed=seed,
+            grammar=grammar,
         )
         for _ in range(n)
     ]
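
Replacing the opaque **kwargs with explicit seed and grammar parameters makes the helper's contract visible to callers and type checkers. A minimal usage sketch, assuming the helper is awaited with an async test client as in the surrounding fixture; the prompt and schema below are illustrative placeholders, not from the diff:

# Hypothetical call site: seed and grammar are now explicit keyword
# arguments instead of travelling through **kwargs.
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'

responses = await generate_load(
    client,                # assumed: the async test client from the fixture
    prompt="Create a person:",
    max_new_tokens=32,
    n=4,
    seed=42,               # reproducible across the n concurrent requests
    grammar=schema,        # JSON-schema string routed to the grammar FSM
)
assert len(responses) == 4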

Changed file 2 of 3

@@ -1,6 +1,7 @@
 import math
 import torch
+from loguru import logger

 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -511,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm

     def adapt_tokenizer(self, tokenizer):
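
Swapping print for logger.info ties FSM-compilation timing into the server's structured logs. For context, a sketch of the compile path this method implements, assuming the outlines import paths and an lru_cache wrapper (functools is already imported above); any name not in the diff is an assumption:

import time
from functools import lru_cache

from loguru import logger
from outlines.fsm.fsm import RegexFSM                         # assumed import path
from outlines.fsm.json_schema import build_regex_from_object  # assumed import path

@lru_cache(maxsize=32)
def compile_fsm(schema: str, tokenizer) -> RegexFSM:
    # JSON schemas are lowered to a regex; plain strings are treated as
    # regexes directly. Compilation dominates latency, hence the timing log
    # and the cache (the tokenizer must be hashable for lru_cache to apply).
    start_time = time.time()
    is_json_string = schema.startswith("{") and schema.endswith("}")
    regex_string = build_regex_from_object(schema) if is_json_string else schema
    fsm = RegexFSM(regex_string, tokenizer)
    logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
    return fsm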
@@ -550,7 +551,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         self.tokenizer = tokenizer

     def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
-        for i in range(len(logits)):
+        for i in range(logits.shape[0]):
             if fsm_grammar_states[i] == -1:
                 # todo mask for only eos token
                 continue
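
len(logits) and logits.shape[0] agree for a 2D tensor, but the latter states the intent: iterate over the batch dimension. A self-contained sketch of the per-row masking idea, with the FSM replaced by a plain allowed-token list so the example runs standalone:

import torch

def mask_batch(logits: torch.Tensor, states: list, allowed: list) -> torch.Tensor:
    # logits: (batch, vocab). Rows whose grammar state is -1 (finished or
    # ungoverned) are left untouched; every other row keeps only the token
    # ids its grammar currently allows.
    for i in range(logits.shape[0]):        # explicit batch dimension
        if states[i] == -1:
            continue
        mask = torch.full_like(logits[i], float("-inf"))
        mask[allowed[i]] = 0
        logits[i] = logits[i] + mask
    return logits

logits = torch.randn(2, 5)
out = mask_batch(logits, states=[-1, 0], allowed=[None, [1, 3]])
assert torch.isinf(out[1][0]) and not torch.isinf(out[1][1])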
@@ -584,7 +585,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        # print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm

     def adapt_tokenizer(self, tokenizer):

Changed file 3 of 3

@@ -26,19 +26,19 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 class NextTokenChooser:
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        frequency_penalty=0.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-        tokenizer=None,
-        grammar=None,
-        fsm_grammar_state=None,
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: float = 0.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        grammar: str = "",
+        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -54,9 +54,7 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device)
-            if grammar and grammar != ""
-            else None
+            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
         )

         self.tokenizer = tokenizer
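
Typing grammar as str with an empty-string default gives "no grammar" exactly one representation, so the old grammar and grammar != "" guard collapses to a single comparison. A usage sketch; the tokenizer and schema are illustrative assumptions, not from the diff:

# Assumed setup for illustration only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

chooser_plain = NextTokenChooser(tokenizer=tokenizer)  # grammar="" -> processor is None
chooser_grammar = NextTokenChooser(
    tokenizer=tokenizer,
    grammar='{"type": "object"}',  # any non-empty string builds the FSM processor
)
assert chooser_plain.grammar_processor is None
assert chooser_grammar.grammar_processor is not None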
@@ -438,7 +436,10 @@ class HeterogeneousNextTokenChooser:
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states

-        self.choice = Greedy()
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()

         return self
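
Unconditionally resetting self.choice to Greedy() discarded per-request sampling state whenever the batch shrank; the new branch narrows the existing choice when any surviving request still samples. A toy sketch of the filter contract (names are illustrative, not the repo's):

import torch

class ToySampling:
    # Stand-in for a batched sampling choice that owns one torch.Generator
    # per request, so seeds and RNG state survive batch filtering.
    def __init__(self, seeds):
        self.generators = []
        for seed in seeds:
            g = torch.Generator()
            g.manual_seed(seed)
            self.generators.append(g)

    def filter(self, indices):
        # Keep only the generators for requests still in the batch; re-seeding
        # instead would silently change the remaining requests' outputs.
        self.generators = [self.generators[i] for i in indices]
        return self

batch = ToySampling(seeds=[0, 1, 2, 3])
batch.filter([1, 3])        # requests 0 and 2 finished and left the batch
assert len(batch.generators) == 2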
@@ -486,6 +487,7 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)

+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
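
As the docstring says, both paths are computed for every row and the right one is picked per sample. A minimal illustration of that select step (a sketch, not the repo's implementation):

import torch

def pick_next_tokens(logits: torch.Tensor, do_sample: torch.Tensor) -> torch.Tensor:
    # logits: (batch, vocab); do_sample: (batch,) bool mask.
    greedy = logits.argmax(dim=-1)
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return torch.where(do_sample, sampled, greedy)

logits = torch.randn(4, 10)
do_sample = torch.tensor([True, False, True, False])
assert pick_next_tokens(logits, do_sample).shape == (4,)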