Mirror of https://github.com/huggingface/text-generation-inference.git
feat: support grammars in batch

commit ff6e8d9e23
parent 8fd2664a3c
@@ -45,6 +45,7 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
+        fsm_grammar_state: 0,
     };

     // Initialize terminal properties
@@ -71,7 +71,9 @@ message NextTokenChooserParameters {
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
-    string grammar = 10;
+    repeated string grammar = 10;
+    /// fsm_grammar_state
+    repeated uint32 fsm_grammar_state = 11;
 }

 message StoppingCriteriaParameters {
@@ -129,6 +129,7 @@ impl Client {
                 frequency_penalty: 0.1,
                 watermark: true,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -46,6 +46,7 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -369,6 +369,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
@@ -304,6 +304,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
+            fsm_grammar_state: 0,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -99,6 +99,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int

+    # The states for the grammar FSM
+    fsm_states: Dict[int, int] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -137,6 +140,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
+        fsm_states = {}

         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -319,6 +323,7 @@ class FlashCausalLMBatch(Batch):
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            fsm_states=fsm_states,
         )

     @tracer.start_as_current_span("filter")
@@ -594,7 +599,6 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
-            grammar=batches[0].requests.parameters.grammar,
         )

         speculative_ids = (
@@ -1015,9 +1019,9 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
                     if len(batch) > 1:
-                        prefill_tokens_indices[
-                            out_start_index : out_end_index - 1
-                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        )
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
@@ -1168,7 +1172,7 @@ class FlashCausalLM(Model):

             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
@@ -27,6 +27,7 @@ from functools import lru_cache
 # TODO: remove when done debugging
 import time

+
 class NextTokenChooser:
     def __init__(
         self,
@@ -42,6 +43,7 @@ class NextTokenChooser:
         device="cpu",
         tokenizer=None,
         grammar=None,
+        fsm_grammar_state=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -73,6 +75,9 @@ class NextTokenChooser:

         sampling = do_sample or has_warpers

+        self.fsm_grammar_state = fsm_grammar_state
+        self.grammars = grammar
+
         # TODO: is grammar a subset of sampling? If so, we should merge them
         if grammar:
             self.choice = Grammar(tokenizer, device, grammar)
@@ -92,7 +97,9 @@ class NextTokenChooser:
         else:
             scores, next_logprob = self.static_warper(scores)

-        next_id = self.choice(scores[-1]).view(1, 1)
+        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
+            1, 1
+        )

         return next_id, next_logprob

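In the single-request path, `NextTokenChooser.__call__` now threads the per-request FSM state and grammar into `self.choice`, so every choice callable has to accept the two extra positional arguments even when it ignores them; elsewhere in this commit `Greedy.__call__` absorbs them via `*args`. A minimal sketch of that calling convention follows (the `GreedySketch` name is illustrative, not the repository's class):

import torch

class GreedySketch:
    def __call__(self, logits, *args):
        # *args swallows fsm_grammar_state and grammars for non-grammar requests
        return logits.argmax(dim=-1)

choice = GreedySketch()
scores = torch.randn(1, 8)
next_id = choice(scores[-1], 0, None).view(1, 1)  # extra args are no-ops here
print(next_id.shape)  # torch.Size([1, 1])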
@@ -116,6 +123,7 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
+            fsm_grammar_state=pb.fsm_grammar_state,
         )


@@ -222,6 +230,7 @@ class HeterogeneousNextTokenChooser:
         seeds: List[int],
         tokenizer=None,
         grammar=None,
+        fsm_grammar_states=None,
     ):
         warpers = []

@@ -275,9 +284,8 @@ class HeterogeneousNextTokenChooser:

         self.warpers = warpers

-        first_grammar = grammar[0] if grammar else None
-        if first_grammar:
-            self.choice = Grammar(tokenizer, device, first_grammar)
+        if grammar is not None:
+            self.choice = Grammar(tokenizer, device)
         elif any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
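With grammars carried per request, the heterogeneous chooser no longer picks `grammar[0]` and compiles it up front; it builds a single `Grammar(tokenizer, device)` helper and compiles each request's grammar lazily inside `__call__` through the `lru_cache`-decorated `compile_fsm`, so a grammar shared by many requests is only compiled once ("this is cached and should be fast after the first time"). A sketch of that memoization pattern, with a hypothetical `build_fsm` standing in for the real FSM construction:

from functools import lru_cache

def build_fsm(schema: str, vocab_size: int) -> dict:
    # Hypothetical stand-in for the expensive regex/JSON-schema FSM build.
    print(f"compiling FSM for {schema!r}")
    return {"schema": schema, "vocab_size": vocab_size}

@lru_cache(maxsize=32, typed=True)
def compile_fsm(schema: str, vocab_size: int) -> dict:
    return build_fsm(schema, vocab_size)

# Two requests sharing a grammar hit the cache on the second call.
compile_fsm('{"type": "object"}', 32000)
compile_fsm('{"type": "object"}', 32000)  # cached, no recompilation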
@@ -288,6 +296,9 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer
+        self.fsm_grammar_states = fsm_grammar_states
+        self.grammars = grammar

     def __call__(
         self,
@@ -320,7 +331,7 @@ class HeterogeneousNextTokenChooser:
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)

-            _next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -398,7 +409,7 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = [self.do_sample[i] for i in indices]

         if self.use_grammar or any(self.do_sample):
-            self.choice.filter(indices)
+            self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
         else:
             self.choice = Greedy()

@@ -426,6 +437,7 @@ class HeterogeneousNextTokenChooser:
             dtype=dtype,
             tokenizer=tokenizer,
             grammar=[pb_.grammar for pb_ in pb],
+            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )


@@ -444,39 +456,51 @@ class Sampling:


 class Greedy:
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         return logits.argmax(dim=-1)

+    def filter(self, indices, *args):
+        return self
+

 class Grammar:
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device, grammar):
-        fsm = self.compile_fsm(grammar, tokenizer)
-        self.fsm = fsm
-        self.fsm_state = defaultdict(int)
+    def __init__(self, tokenizer, device):
         self.device = device
+        self.tokenizer = tokenizer

-    def __call__(self, logits):
-        seq_id = 0
+    def __call__(self, logits, fsm_grammar_states, grammars):
+        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
+        try:
+            for i in range(len(fsm_grammar_states)):
+                if fsm_grammar_states[i] == -1:
+                    continue

-        if self.fsm_state[seq_id] == -1:
-            return self.fsm_state[seq_id].eos_token_id
+                # this is cached and should be fast after the first time
+                fsm = self.compile_fsm(grammars[i], self.tokenizer)
+                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])

-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-        mask[allowed_tokens] = 0
-        biased_scores = logits + mask
+                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+                mask[allowed_tokens] = 0
+                biased_scores = logits[i : i + 1] + mask

-        # greedly pick the token with the highest score
-        greedy = biased_scores.argmax(dim=-1)
+                greedy = biased_scores.argmax(dim=-1)

-        # now update the fsm state
-        self.fsm_state[seq_id] = self.fsm.next_state(
-            self.fsm_state[seq_id], greedy.item()
-        )
-        return greedy
+                # if greedy is empty, return the eos token
+                if greedy.shape[0] == 0:
+                    continue
+
+                # import ipdb; ipdb.set_trace()
+                fsm_grammar_states[i] = fsm.next_state(
+                    fsm_grammar_states[i], greedy.item()
+                )
+
+                empty[i] = greedy.item()
+        except Exception as e:
+            print(f"Exception: {e}")
+            import ipdb
+
+            ipdb.set_trace()
+        return empty

     @lru_cache(maxsize=32, typed=True)
     def compile_fsm(self, schema, tokenizer):
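The rewritten `Grammar.__call__` walks the batch row by row: for each request with a live FSM state it compiles (or fetches from cache) that request's FSM, masks the corresponding logits row to the allowed token ids, takes the greedy argmax, advances the stored state, and writes the chosen id into the output; rows whose state is -1 keep the default fill value. The following self-contained sketch illustrates that per-row masking loop; `ToyFSM` and `grammar_constrained_step` are illustrative stand-ins (the real code uses outlines' compiled FSM), relying only on the two methods the diff calls, allowed_token_ids and next_state:

import math
import torch

class ToyFSM:
    """Hypothetical stand-in for a compiled grammar FSM."""

    def __init__(self, transitions):
        # transitions: {state: {token_id: next_state}}
        self.transitions = transitions

    def allowed_token_ids(self, state):
        return list(self.transitions.get(state, {}).keys())

    def next_state(self, state, token_id):
        return self.transitions.get(state, {}).get(token_id, -1)

def grammar_constrained_step(logits, fsm_grammar_states, fsms):
    """One decoding step over a batch: mask each row to its FSM's allowed
    tokens, pick greedily, and advance that row's FSM state in place."""
    next_ids = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
    for i, fsm in enumerate(fsms):
        if fsm_grammar_states[i] == -1:  # finished / unconstrained request
            continue
        allowed = fsm.allowed_token_ids(fsm_grammar_states[i])
        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
        mask[allowed] = 0
        greedy = (logits[i : i + 1] + mask).argmax(dim=-1)
        fsm_grammar_states[i] = fsm.next_state(fsm_grammar_states[i], greedy.item())
        next_ids[i] = greedy.item()
    return next_ids

# Two requests, a vocabulary of 5 tokens, each with its own tiny grammar.
fsms = [ToyFSM({0: {2: 1}, 1: {4: -1}}), ToyFSM({0: {3: -1}})]
states = [0, 0]
logits = torch.randn(2, 5)
print(grammar_constrained_step(logits, states, fsms), states)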
@@ -488,7 +512,6 @@ class Grammar:
         print(f"Compile FSM: {time.time() - start_time}")
         return fsm

-
     def adapt_tokenizer(self, tokenizer):
         """Adapt tokenizer to work with the FSM.

@@ -515,6 +538,18 @@ class Grammar:

         return tokenizer

+    def filter(self, indices, fsm_grammar_states, grammars):
+        new_fsm_grammar_states = []
+        new_grammars = []
+
+        for i in indices:
+            new_fsm_grammar_states.append(fsm_grammar_states[i])
+            new_grammars.append(grammars[i])
+
+        self.fsm_state = new_fsm_grammar_states
+        self.fsm = new_grammars
+        return self
+

 class HeterogeneousSampling:
     r"""
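`Grammar.filter` mirrors the batch-filtering pass that runs when some requests finish: the surviving indices select the matching FSM states and grammar strings so the parallel lists stay aligned with the requests that remain. A short sketch of that index-based selection (the function name is illustrative; the real method also reassigns the chooser's fields):

def filter_grammar_state(indices, fsm_grammar_states, grammars):
    kept_states = [fsm_grammar_states[i] for i in indices]
    kept_grammars = [grammars[i] for i in indices]
    return kept_states, kept_grammars

states, grammars = filter_grammar_state([0, 2], [5, 9, 7], ["a", "b", "c"])
print(states, grammars)  # [5, 7] ['a', 'c']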
@@ -534,7 +569,7 @@ class HeterogeneousSampling:

         self.greedy = Greedy()

-    def __call__(self, logits):
+    def __call__(self, logits, fsm_grammar_states, grammars):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing