Mirror of https://github.com/huggingface/text-generation-inference.git
feat: address syntax comments
commit a28ba7212c (parent f94fc831f4)
@@ -374,14 +374,16 @@ def generate_load():
         prompt: str,
         max_new_tokens: int,
         n: int,
-        **kwargs,
+        seed: Optional[int] = None,
+        grammar: Optional[str] = None,
     ) -> List[Response]:
         futures = [
             client.generate(
                 prompt,
                 max_new_tokens=max_new_tokens,
                 decoder_input_details=True,
-                **kwargs,
+                seed=seed,
+                grammar=grammar,
             )
             for _ in range(n)
         ]
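With `**kwargs` replaced by explicit `seed` and `grammar` parameters, call sites in the integration tests become self-documenting. A hypothetical sketch of such a call (the `client` fixture, prompt, and regex grammar are illustrative, not taken from this diff):

async def test_generate_load_with_grammar(client):
    # Hypothetical usage; `client` stands in for whatever AsyncClient fixture
    # the suite passes ahead of the parameters shown in the diff.
    responses = await generate_load(
        client,
        "name: david. email: ",
        max_new_tokens=10,
        n=4,
        seed=42,
        grammar=r"[\w-]+@([\w-]+\.)+[\w-]+",  # regex grammar for an e-mail
    )
    assert len(responses) == 4
    # With a fixed seed, all n generations should agree.
    assert all(r.generated_text == responses[0].generated_text for r in responses)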
@@ -1,6 +1,7 @@
 import math
 import torch
 
+from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
 
@@ -511,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):
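Replacing `print` with `logger.info` routes the FSM-compilation timing through the server's structured logs instead of raw stdout, and `:.2f` keeps the duration readable. A minimal standalone sketch of the timing pattern, with the expensive `RegexFSM` construction faked out (the `lru_cache` import already present in this file suggests compiled FSMs are cached per grammar):

import time
from functools import lru_cache

from loguru import logger


@lru_cache(maxsize=32)
def compile_fsm(regex_string: str):
    # Stand-in for RegexFSM(regex_string, tokenizer): construction is the
    # expensive step, so it is timed once and cached per regex string.
    start_time = time.time()
    fsm = object()
    logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
    return fsm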
@@ -550,7 +551,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
-        for i in range(len(logits)):
+        for i in range(logits.shape[0]):
             if fsm_grammar_states[i] == -1:
                 # todo mask for only eos token
                 continue
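For a 2-D logits tensor, `len(logits)` and `logits.shape[0]` are equivalent, but indexing the shape states the batch-dimension intent explicitly. A quick check:

import torch

logits = torch.randn(4, 32000)  # (batch_size, vocab_size)
assert len(logits) == logits.shape[0] == 4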
@@ -584,7 +585,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         is_json_string = schema.startswith("{") and schema.endswith("}")
         regex_string = build_regex_from_object(schema) if is_json_string else schema
         fsm = RegexFSM(regex_string, tokenizer)
-        # print(f"Compile FSM: {time.time() - start_time}")
+        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
     def adapt_tokenizer(self, tokenizer):
@@ -26,19 +26,19 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 class NextTokenChooser:
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        frequency_penalty=0.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
-        tokenizer=None,
-        grammar=None,
-        fsm_grammar_state=None,
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: float = 0.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        grammar: str = "",
+        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
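With the annotated signature, call sites read unambiguously and a type checker can flag mistakes such as passing `None` where a string is expected. A minimal sketch of a typed call (the values are illustrative only):

chooser = NextTokenChooser(
    watermark=False,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    seed=42,
    device="cuda",
    grammar="",  # the empty string now encodes "no grammar"
)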
@@ -54,9 +54,7 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device)
-            if grammar and grammar != ""
-            else None
+            GrammarLogitProcessor(tokenizer, device) if grammar != "" else None
         )
         self.tokenizer = tokenizer
 
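Because `grammar` is now typed as `str = ""` rather than `Optional[str] = None`, the two-part guard `grammar and grammar != ""` is redundant: once `None` is ruled out, the single comparison is equivalent. A quick illustration:

# For any str value, the old and new conditions agree.
for grammar in ("", "[a-z]+"):
    assert bool(grammar and grammar != "") == (grammar != "")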
@@ -438,6 +436,9 @@ class HeterogeneousNextTokenChooser:
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
 
-        self.choice = Greedy()
+        if any(self.do_sample):
+            self.choice.filter(indices)
+        else:
+            self.choice = Greedy()
 
         return self
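The added branch restores the method's original contract: `HeterogeneousSampling` holds one sampler per request, so after the batch shrinks it must be narrowed to the surviving `indices`; only when no remaining request samples can the chooser collapse to a single shared `Greedy`. A small standalone sketch of that decision (function and variable names are hypothetical):

from typing import List


def narrow_choice(choice, do_sample: List[bool], indices: List[int]):
    # Keep only the sampling flags of the surviving batch slots.
    kept = [do_sample[i] for i in indices]
    if any(kept):
        # At least one request still samples: narrow the per-request samplers.
        choice.filter(indices)
        return choice
    # Every remaining request is greedy; one shared argmax chooser suffices.
    return Greedy()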
@@ -486,6 +487,7 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
+
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
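The docstring's "compute both and pick the right one" strategy can be illustrated in a few lines. A simplified standalone sketch of that dispatch, not the class's actual internals:

import torch

logits = torch.randn(3, 8)                      # (batch, vocab)
do_sample = torch.tensor([False, True, False])  # per-request sampling flags

greedy_ids = logits.argmax(dim=-1)              # what Greedy.__call__ returns
probs = logits.softmax(dim=-1)
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

# Take the sampled token where a request opted into sampling, else greedy.
next_ids = torch.where(do_sample, sampled_ids, greedy_ids)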