feat: address syntax comments

drbh 2024-02-12 11:08:35 -05:00
parent f94fc831f4
commit a28ba7212c
3 changed files with 27 additions and 22 deletions

View File

@@ -374,14 +374,16 @@ def generate_load():
         prompt: str,
         max_new_tokens: int,
         n: int,
-        **kwargs,
+        seed: Optional[int] = None,
+        grammar: Optional[str] = None,
     ) -> List[Response]:
         futures = [
             client.generate(
                 prompt,
                 max_new_tokens=max_new_tokens,
                 decoder_input_details=True,
-                **kwargs,
+                seed=seed,
+                grammar=grammar,
             )
             for _ in range(n)
         ]
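Replacing **kwargs with explicit seed and grammar parameters makes the test helper's surface self-documenting. A minimal sketch of how a test might now call it; the fixture names, prompt, and grammar value below are illustrative assumptions, not part of this diff:

import pytest

@pytest.mark.asyncio
async def test_grammar_load(generate_load, flash_llama_grammar):
    # `generate_load` and `flash_llama_grammar` are assumed pytest fixtures
    # provided by the test harness.
    responses = await generate_load(
        flash_llama_grammar,
        "name: david. email:",
        max_new_tokens=10,
        n=4,
        seed=42,
        grammar=r"\w+@\w+\.com",  # illustrative regex grammar
    )
    assert len(responses) == 4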

View File

@@ -1,6 +1,7 @@
 import math
 import torch
+from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -511,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):
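The switch from print to logger.info keeps the FSM-compilation timing visible without polluting stdout. A rough, self-contained sketch of the timing-plus-caching pattern in play here; the function name is hypothetical and re.compile merely stands in for the RegexFSM build:

import re
import time
from functools import lru_cache
from loguru import logger

@lru_cache(maxsize=32)
def compile_fsm_cached(regex_string: str):
    # Stand-in for RegexFSM(regex_string, tokenizer); cached so repeated
    # grammars only pay the compilation cost once.
    start_time = time.time()
    fsm = re.compile(regex_string)
    logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
    return fsm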
@@ -550,7 +551,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
-        for i in range(len(logits)):
+        for i in range(logits.shape[0]):
             if fsm_grammar_states[i] == -1:
                 # todo mask for only eos token
                 continue
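logits.shape[0] and len(logits) are equivalent for a 2-D batch tensor; the explicit form just makes it clear the loop runs once per sequence in the batch. A small sketch of that per-row masking idea, where the allowed-token lists stand in for the FSM's allowed token ids:

from typing import List

import torch

def mask_disallowed(logits: torch.Tensor, allowed: List[List[int]]) -> torch.Tensor:
    # One row of logits per sequence; everything outside `allowed[i]` gets -inf.
    for i in range(logits.shape[0]):
        mask = torch.full_like(logits[i], float("-inf"))
        mask[allowed[i]] = 0.0
        logits[i] = logits[i] + mask
    return logits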
@@ -584,7 +585,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        # print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):

View File

@@ -26,19 +26,19 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 class NextTokenChooser:
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        frequency_penalty=0.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-        tokenizer=None,
-        grammar=None,
-        fsm_grammar_state=None,
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: float = 0.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        grammar: str = "",
+        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
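With the annotated signature above, a grammar-constrained chooser can be built roughly as below; the tokenizer choice and the grammar string are illustrative, and NextTokenChooser is assumed to be importable from the module this hunk edits:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works for the sketch
chooser = NextTokenChooser(
    temperature=0.7,
    do_sample=True,
    seed=42,
    device="cpu",
    tokenizer=tokenizer,
    grammar=r"\d{3}-\d{4}",  # a non-empty grammar activates the grammar processor
)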
@@ -54,9 +54,7 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device)
-            if grammar and grammar != ""
-            else None
+            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
         )
 
         self.tokenizer = tokenizer
@@ -438,7 +436,10 @@ class HeterogeneousNextTokenChooser:
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
 
-        self.choice = Greedy()
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
 
         return self
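The new branch follows the usual batch-filter pattern: when any surviving request still samples, keep (and filter) the existing choice; when none do, reset to a plain Greedy. A toy, assumption-labeled sketch of the same control flow:

from typing import List

class ToyBatchChooser:
    # Simplified stand-in for the chooser's filter() logic; one do_sample flag per request.
    def __init__(self, do_sample: List[bool]):
        self.do_sample = do_sample
        self.choice = "sampling" if any(do_sample) else "greedy"

    def filter(self, indices: List[int]) -> "ToyBatchChooser":
        # Keep only the requests that remain in the batch.
        self.do_sample = [self.do_sample[i] for i in indices]
        if any(self.do_sample):
            self.choice = "sampling"  # the real code filters the existing sampler instead
        else:
            self.choice = "greedy"
        return self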
@@ -486,6 +487,7 @@
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
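For the docstring's "compute both and pick the right one" idea, a rough functional sketch could look like the following; this is an assumption-based simplification, and the real class also tracks per-request seeds and samplers:

import torch

def heterogeneous_choice(logits: torch.Tensor, do_sample: torch.Tensor) -> torch.Tensor:
    # Greedy pick and a multinomial sample for every row, then a per-row switch.
    greedy = logits.argmax(dim=-1)
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return torch.where(do_sample, sampled, greedy)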