feat: support grammars in batch

drbh 2024-02-09 15:58:00 +00:00
parent 8fd2664a3c
commit ff6e8d9e23
8 changed files with 83 additions and 37 deletions

View File

@@ -45,6 +45,7 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
+        fsm_grammar_state: 0,
     };

     // Initialize terminal properties

View File

@@ -71,7 +71,9 @@ message NextTokenChooserParameters {
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
-    string grammar = 10;
+    repeated string grammar = 10;
+    /// fsm_grammar_state
+    repeated uint32 fsm_grammar_state = 11;
 }

 message StoppingCriteriaParameters {
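Note: with these two fields the chooser parameters carry both the grammar source and the request's current FSM state. A rough sketch of how the message could be populated from Python, based only on the proto as declared above (the generate_pb2 import path and the literal values are assumptions for illustration, not taken from this commit):

# Sketch only: per-request parameters with the new repeated fields.
from text_generation_server.pb import generate_pb2  # path assumed

params = generate_pb2.NextTokenChooserParameters(
    watermark=False,
    grammar=["<regex or JSON schema>"],  # repeated string grammar = 10
    fsm_grammar_state=[0],               # repeated uint32 fsm_grammar_state = 11
)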

View File

@@ -129,6 +129,7 @@ impl Client {
                     frequency_penalty: 0.1,
                     watermark: true,
                     grammar: String::new(),
+                    fsm_grammar_state: 0,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,

View File

@@ -46,6 +46,7 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -369,6 +369,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                fsm_grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,

View File

@@ -304,6 +304,7 @@ impl Validation {
                 seed,
                 watermark,
                 grammar,
+                fsm_grammar_state: 0,
             };
             let stopping_parameters = StoppingCriteriaParameters {
                 max_new_tokens,

View File

@@ -99,6 +99,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int

+    # The states for the grammar FSM
+    fsm_states: Dict[int, int] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -137,6 +140,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
+        fsm_states = {}

         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -319,6 +323,7 @@ class FlashCausalLMBatch(Batch):
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            fsm_states=fsm_states,
         )

     @tracer.start_as_current_span("filter")
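Note: fsm_states travels with the batch so each request keeps its own grammar FSM position across prefill, decode, filter and concatenate. A minimal, simplified sketch of that bookkeeping (illustrative only, not the actual FlashCausalLMBatch):

from typing import Dict, List

class BatchFsmSketch:
    def __init__(self, num_requests: int):
        # request index -> current grammar FSM state (0 = FSM start state)
        self.fsm_states: Dict[int, int] = {i: 0 for i in range(num_requests)}

    def filter(self, kept_indices: List[int]) -> "BatchFsmSketch":
        # keep only the surviving requests' states, reindexed to the new batch order
        self.fsm_states = {
            new_i: self.fsm_states[old_i] for new_i, old_i in enumerate(kept_indices)
        }
        return self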
@@ -594,7 +599,6 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
-            grammar=batches[0].requests.parameters.grammar,
         )

         speculative_ids = (
@@ -1015,9 +1019,9 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
                     if len(batch) > 1:
-                        prefill_tokens_indices[
-                            out_start_index : out_end_index - 1
-                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        )
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
@@ -1168,7 +1172,7 @@ class FlashCausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(

View File

@@ -27,6 +27,7 @@ from functools import lru_cache

 # TODO: remove when done debugging
 import time


 class NextTokenChooser:
     def __init__(
         self,
@@ -42,6 +43,7 @@ class NextTokenChooser:
         device="cpu",
         tokenizer=None,
         grammar=None,
+        fsm_grammar_state=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -73,6 +75,9 @@ class NextTokenChooser:
         sampling = do_sample or has_warpers

+        self.fsm_grammar_state = fsm_grammar_state
+        self.grammars = grammar
+
         # TODO: is grammar a subset of sampling? If so, we should merge them
         if grammar:
             self.choice = Grammar(tokenizer, device, grammar)
@@ -92,7 +97,9 @@ class NextTokenChooser:
         else:
             scores, next_logprob = self.static_warper(scores)

-        next_id = self.choice(scores[-1]).view(1, 1)
+        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
+            1, 1
+        )

         return next_id, next_logprob
@@ -116,6 +123,7 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
+            fsm_grammar_state=pb.fsm_grammar_state,
         )
@@ -222,6 +230,7 @@ class HeterogeneousNextTokenChooser:
         seeds: List[int],
         tokenizer=None,
         grammar=None,
+        fsm_grammar_states=None,
     ):
         warpers = []
@@ -275,9 +284,8 @@ class HeterogeneousNextTokenChooser:
         self.warpers = warpers

-        first_grammar = grammar[0] if grammar else None
-        if first_grammar:
-            self.choice = Grammar(tokenizer, device, first_grammar)
+        if grammar is not None:
+            self.choice = Grammar(tokenizer, device)
         elif any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
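Note: this is the core change for batching. The chooser no longer compiles a single FSM for the first grammar at construction time; Grammar is now built with only the tokenizer and device, and each request's FSM is compiled (and cached) inside __call__, so one batch can mix different grammars. In outline (sketch, using the names from this diff):

# Before: one grammar per batch, compiled eagerly
#   self.choice = Grammar(tokenizer, device, grammar[0])
# After: per-request grammars, compiled lazily per call
#   self.choice = Grammar(tokenizer, device)
#   next_ids = self.choice(scores, self.fsm_grammar_states, self.grammars)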
@@ -288,6 +296,9 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer
+        self.fsm_grammar_states = fsm_grammar_states
+        self.grammars = grammar

     def __call__(
         self,
@@ -320,7 +331,7 @@ class HeterogeneousNextTokenChooser:
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)

-            _next_ids = self.choice(_scores)
+            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -398,7 +409,7 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = [self.do_sample[i] for i in indices]

         if self.use_grammar or any(self.do_sample):
-            self.choice.filter(indices)
+            self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
         else:
             self.choice = Greedy()
@@ -426,6 +437,7 @@ class HeterogeneousNextTokenChooser:
             dtype=dtype,
             tokenizer=tokenizer,
             grammar=[pb_.grammar for pb_ in pb],
+            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )
@@ -444,39 +456,51 @@ class Sampling:


 class Greedy:
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         return logits.argmax(dim=-1)

+    def filter(self, indices, *args):
+        return self
+

 class Grammar:
-    fsm_state: DefaultDict[int, int]
-    fsm: RegexFSM
-
-    def __init__(self, tokenizer, device, grammar):
-        fsm = self.compile_fsm(grammar, tokenizer)
-        self.fsm = fsm
-        self.fsm_state = defaultdict(int)
+    def __init__(self, tokenizer, device):
         self.device = device
+        self.tokenizer = tokenizer

-    def __call__(self, logits):
-        seq_id = 0
-
-        if self.fsm_state[seq_id] == -1:
-            return self.fsm_state[seq_id].eos_token_id
-
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-        mask[allowed_tokens] = 0
-        biased_scores = logits + mask
-
-        # greedly pick the token with the highest score
-        greedy = biased_scores.argmax(dim=-1)
-
-        # now update the fsm state
-        self.fsm_state[seq_id] = self.fsm.next_state(
-            self.fsm_state[seq_id], greedy.item()
-        )
-        return greedy
+    def __call__(self, logits, fsm_grammar_states, grammars):
+        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
+        try:
+            for i in range(len(fsm_grammar_states)):
+                if fsm_grammar_states[i] == -1:
+                    continue
+
+                # this is cached and should be fast after the first time
+                fsm = self.compile_fsm(grammars[i], self.tokenizer)
+                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+                mask[allowed_tokens] = 0
+                biased_scores = logits[i : i + 1] + mask
+                greedy = biased_scores.argmax(dim=-1)
+
+                # if greedy is empty, return the eos token
+                if greedy.shape[0] == 0:
+                    continue
+
+                # import ipdb; ipdb.set_trace()
+                fsm_grammar_states[i] = fsm.next_state(
+                    fsm_grammar_states[i], greedy.item()
+                )
+                empty[i] = greedy.item()
+        except Exception as e:
+            print(f"Exception: {e}")
+            import ipdb

+            ipdb.set_trace()
+        return empty

     @lru_cache(maxsize=32, typed=True)
     def compile_fsm(self, schema, tokenizer):
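Note: the batched Grammar.__call__ works by masking: for each row it compiles (or fetches from cache) the FSM for that row's grammar, keeps the logits of tokens the FSM allows, pushes every other logit to -inf, takes the argmax, then advances that row's FSM state; rows whose state is -1 are skipped. The @lru_cache(maxsize=32, typed=True) on compile_fsm is what keeps this cheap after the first call per grammar (since it decorates a method, the cache key also includes self and the tokenizer, which must therefore be hashable). A stripped-down, standalone sketch of the single masking step (hypothetical helper, not the code from this commit):

import math
import torch

def constrained_greedy_step(logits_row, allowed_token_ids, fsm_state, next_state_fn):
    # logits_row: 1D tensor of vocabulary logits for one sequence
    # allowed_token_ids: token ids the FSM permits from fsm_state
    # next_state_fn: callable (state, token_id) -> new state, e.g. fsm.next_state
    mask = torch.full_like(logits_row, -math.inf)
    mask[allowed_token_ids] = 0  # allowed tokens keep their original score
    token_id = int((logits_row + mask).argmax())
    return token_id, next_state_fn(fsm_state, token_id)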
@@ -488,7 +512,6 @@ class Grammar:
         print(f"Compile FSM: {time.time() - start_time}")
         return fsm

     def adapt_tokenizer(self, tokenizer):
         """Adapt tokenizer to work with the FSM.
@@ -515,6 +538,18 @@ class Grammar:

         return tokenizer

+    def filter(self, indices, fsm_grammar_states, grammars):
+        new_fsm_grammar_states = []
+        new_grammars = []
+
+        for i in indices:
+            new_fsm_grammar_states.append(fsm_grammar_states[i])
+            new_grammars.append(grammars[i])
+
+        self.fsm_state = new_fsm_grammar_states
+        self.fsm = new_grammars
+
+        return self

 class HeterogeneousSampling:
     r"""
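Note: Grammar.filter mirrors the batch's own filter step: when requests finish and the batch shrinks, only the surviving requests' FSM states and grammar strings are kept, in the order given by indices; Greedy also gained a filter(indices, *args) stub so a chooser can be filtered without checking its type. Rough usage sketch (values illustrative):

# 3-request batch; requests 0 and 2 survive filtering.
fsm_grammar_states = [5, -1, 7]
grammars = ["grammar_a", "", "grammar_c"]
indices = [0, 2]

new_states = [fsm_grammar_states[i] for i in indices]  # [5, 7]
new_grammars = [grammars[i] for i in indices]          # ["grammar_a", "grammar_c"]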
@@ -534,7 +569,7 @@ class HeterogeneousSampling:
         self.greedy = Greedy()

-    def __call__(self, logits):
+    def __call__(self, logits, fsm_grammar_states, grammars):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing