feat: improve grammar advance logic to avoid blocking GPU

drbh 2024-02-13 00:01:19 +00:00
parent d0d7cd9e92
commit 8f14019053
4 changed files with 65 additions and 48 deletions
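The diffs below split grammar-FSM advancement out of the token choosers' sampling path and into an explicit `advance_grammar` step. The reason, as the diff reads, is that advancing the FSM needs the sampled token id on the host, and the `.item()`/`.tolist()` calls that fetch it synchronize with the GPU; moving them out of the sampling call keeps that stall out of the hot path. A minimal sketch of the resulting flow, with illustrative names only (this is not the TGI API):

```python
import torch

def dummy_next_state(state: int, token_id: int) -> int:
    # Stand-in for the grammar FSM transition (e.g. fsm.next_state).
    return state + 1

def sample(logits: torch.Tensor) -> torch.Tensor:
    # Pure tensor path: nothing here copies data back to the host.
    return torch.argmax(logits, dim=-1)

def advance_grammar(fsm_states, next_ids: torch.Tensor):
    # Separate, explicit step: one bulk device-to-host copy, then CPU-only work.
    ids = next_ids.tolist()
    return [dummy_next_state(s, t) for s, t in zip(fsm_states, ids)]

logits = torch.randn(2, 128)
next_ids = sample(logits)                   # stays a tensor
states = advance_grammar([0, 0], next_ids)  # the host sync happens here, once
```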

View File

@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
             # We slice the keys to remove the padding from previous batches
             past_seq_len = batch.max_input_length - 1
             if batch.keys_head_dim_last:
-                padded_past_keys[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_keys[:, :, -past_seq_len:, :]
+                padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_keys[:, :, -past_seq_len:, :]
+                )
             else:
                 # BLOOM case
-                padded_past_keys[
-                    start_index:end_index, :, :, -past_seq_len:
-                ] = past_keys[:, :, :, -past_seq_len:]
+                padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                    past_keys[:, :, :, -past_seq_len:]
+                )
             del past_keys
 
             start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
             end_index = start_index + len(batch)
             # We slice the past values to remove the padding from previous batches
             past_seq_len = batch.max_input_length - 1
-            padded_past_values[
-                start_index:end_index, :, -past_seq_len:, :
-            ] = past_values[:, :, -past_seq_len:, :]
+            padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                past_values[:, :, -past_seq_len:, :]
+            )
             del past_values
 
             # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -696,7 +700,7 @@ class CausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
@@ -735,6 +739,9 @@ class CausalLM(Model):
                 generations.append(generation)
 
             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
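In the per-request loop above, each request's chooser is advanced with the already-materialized scalar id, and `advance_grammar` returns the chooser so the batch slot can be rebound in one line. A stub with illustrative names showing that call shape:

```python
class StubChooser:
    """Illustrative stand-in for the per-request NextTokenChooser above."""

    def __init__(self):
        self.fsm_grammar_state = 0

    def advance_grammar(self, next_id: int) -> "StubChooser":
        # Plain CPU bookkeeping; the caller already materialized the scalar id
        # (e.g. via next_token_id_squeezed.item()).
        self.fsm_grammar_state += 1  # placeholder for fsm.next_state(state, next_id)
        return self

choosers = [StubChooser(), StubChooser()]
sampled_ids = [42, 7]  # one already-materialized token id per request
for i, token_id in enumerate(sampled_ids):
    choosers[i] = choosers[i].advance_grammar(token_id)
```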

View File

@@ -870,7 +870,11 @@ class FlashCausalLM(Model):
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
-        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            or batch.speculative_ids is not None
+        ):
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1029,6 +1033,9 @@ class FlashCausalLM(Model):
                 cumulative_length += input_length
 
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids
+        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
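The flash path uses one heterogeneous chooser for the whole batch, so the grammar advance is done once with the full `next_input_ids` tensor. The payoff, as the diff reads, is a single bulk device-to-host transfer instead of one synchronizing read per token; roughly (illustrative values):

```python
import torch

next_input_ids = torch.tensor([42, 7, 1999])  # sampled ids, normally on the GPU

# Per-token reads: every .item() is its own synchronizing device-to-host copy.
ids_item = [next_input_ids[i].item() for i in range(next_input_ids.shape[0])]

# Bulk read: a single .tolist() transfers the whole batch at once; the FSM
# updates that follow are plain Python on the CPU.
ids_list = next_input_ids.tolist()

assert ids_item == ids_list
```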

View File

@@ -478,10 +478,8 @@ class GrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device, grammar):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            self, grammar, self.tokenizer
-        )
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)
 
     def __call__(
         self,
@@ -490,26 +488,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     ):
         if fsm_grammar_state == -1 or self.fsm is None:
             return logits
         allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
         return biased_scores
 
-    def advance(self, next_token_id, fsm_grammar_state, grammar):
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
         if fsm_grammar_state == -1:
             return fsm_grammar_state
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)
 
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(schema, tokenizer):
         start_time = time.time()
         try:
             json.loads(schema)  # check if schema is a valid json
@@ -522,7 +520,7 @@ class GrammarLogitProcessor(LogitsProcessor):
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def adapt_tokenizer(tokenizer):
+    def _cached_adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.
 
         The API of Outlines tokenizers is slightly different to that of
@@ -560,10 +558,10 @@ class GrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device, grammars):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = [
             (
-                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
+                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
                 if g
                 else None
             )
@@ -586,10 +584,13 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             logits[i] = biased_scores
         return logits
 
-    def advance(self, next_token_ids, fsm_grammar_states, grammars):
-        return GrammarLogitProcessor.advance(
-            self, next_token_ids, fsm_grammar_states, grammars
-        )
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]
 
     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
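One more effect of the signature changes above: the old `_cached_compile_fsm(self, grammar, self.tokenizer)` call included the processor instance in the `lru_cache` key, so two processors with the same grammar would each compile their own FSM; the new static signature keys the cache on the grammar and tokenizer alone. A toy illustration of that keying difference (these functions are placeholders, not the TGI implementations):

```python
from functools import lru_cache

@lru_cache(maxsize=32)
def compile_keyed_on_schema(schema: str, tokenizer_name: str) -> str:
    # Cache key is (schema, tokenizer_name) only, as in the new signature.
    return f"fsm for {schema} / {tokenizer_name}"

@lru_cache(maxsize=32)
def compile_keyed_on_instance(instance, schema: str, tokenizer_name: str) -> str:
    # The instance is part of the key, as in the old
    # _cached_compile_fsm(self, grammar, self.tokenizer) call shape.
    return f"fsm for {schema} / {tokenizer_name}"

class Processor:
    pass

a, b = Processor(), Processor()
compile_keyed_on_schema("{}", "tok")
compile_keyed_on_schema("{}", "tok")
compile_keyed_on_instance(a, "{}", "tok")
compile_keyed_on_instance(b, "{}", "tok")
print(compile_keyed_on_schema.cache_info().hits)    # 1: the compiled FSM is reused
print(compile_keyed_on_instance.cache_info().hits)  # 0: each instance recompiles
```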

View File

@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, DefaultDict
+from typing import List, Optional, Tuple
 import torch
 
 from text_generation_server.pb import generate_pb2
@@ -92,14 +92,15 @@ class NextTokenChooser:
         next_id = self.choice(scores[-1]).view(1, 1)
 
-        if self.grammar_processor is not None:
-            next_state = self.grammar_processor.advance(
-                next_id.item(), self.fsm_grammar_state, self.grammar
-            )
-            self.fsm_grammar_state = next_state
-
         return next_id, next_logprob
 
+    def advance_grammar(self, next_id):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_state = self.grammar_processor.advance(
+                next_id, self.fsm_grammar_state
+            )
+        return self
+
     @classmethod
     def from_pb(
         cls,
@@ -385,15 +386,16 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
-        # advance the grammar state
-        if self.grammar_processor is not None:
-            for i in range(len(self.fsm_grammar_states)):
-                self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                )
-
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
+    def advance_grammar(self, next_ids: torch.Tensor):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
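Putting the new pieces together, the heterogeneous chooser's `advance_grammar` converts the sampled ids once with `.tolist()` and hands them to `advance_batch`, which walks the per-request FSMs on the CPU. A self-contained stub mirroring that shape (stand-in classes, not the real ones):

```python
import torch

def dummy_next_state(state: int, token_id: int) -> int:
    # Stand-in for the grammar FSM transition used by the real processor.
    return state + 1

class StubHeterogeneousProcessor:
    """Illustrative stand-in for HeterogeneousGrammarLogitProcessor.advance_batch."""

    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
        return [
            dummy_next_state(fsm_grammar_states[i], next_token_ids[i])
            for i in range(len(next_token_ids))
        ]

class StubHeterogeneousChooser:
    """Illustrative stand-in for HeterogeneousNextTokenChooser.advance_grammar."""

    def __init__(self, n: int):
        self.grammar_processor = StubHeterogeneousProcessor()
        self.fsm_grammar_states = [0] * n
        self.grammars = [""] * n

    def advance_grammar(self, next_ids: torch.Tensor) -> "StubHeterogeneousChooser":
        if self.grammar_processor is not None:
            self.fsm_grammar_states = self.grammar_processor.advance_batch(
                next_ids.tolist(), self.fsm_grammar_states, self.grammars
            )
        return self

chooser = StubHeterogeneousChooser(3)
chooser = chooser.advance_grammar(torch.tensor([5, 9, 11]))
print(chooser.fsm_grammar_states)  # [1, 1, 1]
```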