mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: address syntactically comments
This commit is contained in:
parent
f94fc831f4
commit
a28ba7212c
@ -374,14 +374,16 @@ def generate_load():
|
||||
prompt: str,
|
||||
max_new_tokens: int,
|
||||
n: int,
|
||||
**kwargs,
|
||||
seed: Optional[int] = None,
|
||||
grammar: Optional[str] = None,
|
||||
) -> List[Response]:
|
||||
futures = [
|
||||
client.generate(
|
||||
prompt,
|
||||
max_new_tokens=max_new_tokens,
|
||||
decoder_input_details=True,
|
||||
**kwargs,
|
||||
seed=seed,
|
||||
grammar=grammar,
|
||||
)
|
||||
for _ in range(n)
|
||||
]
|
||||
|
@ -1,6 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
@ -511,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
is_json_string = schema.startswith("{") and schema.endswith("}")
|
||||
regex_string = build_regex_from_object(schema) if is_json_string else schema
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
print(f"Compile FSM: {time.time() - start_time}")
|
||||
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||
return fsm
|
||||
|
||||
def adapt_tokenizer(self, tokenizer):
|
||||
@ -550,7 +551,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
|
||||
for i in range(len(logits)):
|
||||
for i in range(logits.shape[0]):
|
||||
if fsm_grammar_states[i] == -1:
|
||||
# TODO: mask the logits so that only the EOS token can be selected
|
||||
continue
|
||||
@ -584,7 +585,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
is_json_string = schema.startswith("{") and schema.endswith("}")
|
||||
regex_string = build_regex_from_object(schema) if is_json_string else schema
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
# print(f"Compile FSM: {time.time() - start_time}")
|
||||
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||
return fsm
|
||||
|
||||
def adapt_tokenizer(self, tokenizer):
|
||||
|
@ -26,19 +26,19 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
|
||||
class NextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
watermark=False,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
frequency_penalty=0.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
typical_p=None,
|
||||
do_sample=False,
|
||||
seed=0,
|
||||
device="cpu",
|
||||
tokenizer=None,
|
||||
grammar=None,
|
||||
fsm_grammar_state=None,
|
||||
watermark: bool = False,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
do_sample: bool = False,
|
||||
seed: int = 0,
|
||||
device: str = "cpu",
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
grammar: str = "",
|
||||
fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
|
||||
):
|
||||
self.watermark_processor = (
|
||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||
@ -54,9 +54,7 @@ class NextTokenChooser:
|
||||
else None
|
||||
)
|
||||
self.grammar_processor = (
|
||||
GrammarLogitProcessor(tokenizer, device)
|
||||
if grammar and grammar != ""
|
||||
else None
|
||||
GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@ -438,6 +436,9 @@ class HeterogeneousNextTokenChooser:
|
||||
self.grammars = new_grammars
|
||||
self.fsm_grammar_states = new_fsm_grammar_states
|
||||
|
||||
if any(self.do_sample):
|
||||
self.choice.filter(indices)
|
||||
else:
|
||||
self.choice = Greedy()
|
||||
|
||||
return self
|
||||
@ -486,6 +487,7 @@ class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
|
Loading…
Reference in New Issue
Block a user