feat: improve grammar advance logic to avoid blocking GPU

drbh 2024-02-13 00:01:19 +00:00
parent d0d7cd9e92
commit 8f14019053
4 changed files with 65 additions and 48 deletions
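
Context for the diffs below: previously the grammar FSM was advanced inside the token chooser's __call__, so every decoding step interleaved CPU-side FSM stepping (and, in the old code, a compile-cache lookup per step) with the GPU-bound sampling path. The change splits this into an explicit advance_grammar step that the model loop calls once the next token is known. A minimal pure-Python sketch of that split, using hypothetical classes rather than the real TGI ones:

# Sketch only: hypothetical classes illustrating the before/after split.
class ToyFSM:
    def next_state(self, state, token_id):
        # stand-in for outlines' FSM transition
        return state + 1


class ToyChooser:
    def __init__(self, fsm=None):
        self.fsm = fsm
        self.fsm_state = 0

    def __call__(self, scores):
        # hot path: only token selection happens here now
        return max(range(len(scores)), key=scores.__getitem__)

    def advance_grammar(self, next_id):
        # deferred CPU work: one FSM transition per generated token
        if self.fsm is not None:
            self.fsm_state = self.fsm.next_state(self.fsm_state, next_id)
        return self


chooser = ToyChooser(ToyFSM())
next_id = chooser([0.1, 0.7, 0.2])          # sample first
chooser = chooser.advance_grammar(next_id)  # then advance the grammar state

The hunks below apply this split to CausalLM/CausalLMBatch, FlashCausalLM, the grammar logits processors, and the two chooser classes.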

View File

@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

                 # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -696,7 +700,7 @@ class CausalLM(Model):
                 if top_n_tokens > 0:
                     all_top_tokens = []
-                    for (top_token_ids, top_token_logprobs) in zip(
+                    for top_token_ids, top_token_logprobs in zip(
                         top_token_ids, top_token_logprobs
                     ):
                         toptoken_texts = self.tokenizer.batch_decode(
@@ -735,6 +739,9 @@ class CausalLM(Model):
                 generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
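
Note in the hunk above that the chooser is advanced per request with next_token_id_squeezed.item(), i.e. a plain Python int rather than a tensor, and that advance_grammar returns self so it can be assigned back alongside the other batch bookkeeping. A hedged sketch of that handoff (ToyChooser is hypothetical; assumes torch is installed):

import torch


class ToyChooser:
    """Hypothetical chooser with the same advance_grammar contract as above."""

    def __init__(self):
        self.fsm_grammar_state = 0

    def advance_grammar(self, next_id):
        # CPU-side transition on a plain int; returning self lets callers write
        # next_token_choosers[i] = next_token_choosers[i].advance_grammar(...)
        self.fsm_grammar_state += 1
        return self


next_token_id = torch.tensor([[42]])
next_token_id_squeezed = next_token_id.squeeze()

next_token_choosers = [ToyChooser()]
next_token_choosers[0] = next_token_choosers[0].advance_grammar(
    next_token_id_squeezed.item()
)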

View File

@@ -870,7 +870,11 @@ class FlashCausalLM(Model):
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

-        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            or batch.speculative_ids is not None
+        ):
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1029,6 +1033,9 @@ class FlashCausalLM(Model):
                 cumulative_length += input_length

+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids
+        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
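
In the flash path the batch shares a single next_token_chooser, so the grammar is advanced once per step with the whole next_input_ids tensor, after the forward pass and token selection have already been issued. A hedged sketch of that batched variant (toy class, not the real HeterogeneousNextTokenChooser; assumes torch):

import torch


class ToyBatchChooser:
    """Hypothetical batched chooser mirroring advance_grammar(next_ids)."""

    def __init__(self, num_requests):
        self.fsm_grammar_states = [0] * num_requests

    def advance_grammar(self, next_ids):
        # single host transfer per decoding step, then per-request FSM steps
        for i, token_id in enumerate(next_ids.tolist()):
            self.fsm_grammar_states[i] += 1  # stand-in for fsm.next_state(...)
        return self


batch_chooser = ToyBatchChooser(num_requests=3)
next_input_ids = torch.tensor([5, 9, 2])
batch_chooser = batch_chooser.advance_grammar(next_input_ids)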

View File

@@ -478,10 +478,8 @@ class GrammarLogitProcessor(LogitsProcessor):

     def __init__(self, tokenizer, device, grammar):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            self, grammar, self.tokenizer
-        )
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)

     def __call__(
         self,
@@ -490,26 +488,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     ):
         if fsm_grammar_state == -1 or self.fsm is None:
             return logits

         allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0

         biased_scores = logits + mask
         return biased_scores

-    def advance(self, next_token_id, fsm_grammar_state, grammar):
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
         if fsm_grammar_state == -1:
             return fsm_grammar_state
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)

     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(schema, tokenizer):
         start_time = time.time()
         try:
             json.loads(schema) # check if schema is a valid json
@@ -522,7 +520,7 @@ class GrammarLogitProcessor(LogitsProcessor):
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def adapt_tokenizer(tokenizer):
+    def _cached_adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.

         The API of Outlines tokenizers is slightly different to that of
@@ -560,10 +558,10 @@
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device, grammars):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = [
             (
-                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
+                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
                 if g
                 else None
             )
@@ -586,10 +584,13 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             logits[i] = biased_scores
         return logits

-    def advance(self, next_token_ids, fsm_grammar_states, grammars):
-        return GrammarLogitProcessor.advance(
-            self, next_token_ids, fsm_grammar_states, grammars
-        )
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]

     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
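
Two things change in the processor above: the compile and adapt helpers are cached as plain static functions, and the FSM transition is factored into a static _advance that takes the fsm explicitly. Dropping self matters because functools.lru_cache keys on every positional argument, so including the processor instance meant each new processor recompiled and retained its own FSM instead of sharing one per (grammar, tokenizer) pair. A minimal sketch of that caching behaviour (toy compile function, not outlines):

from functools import lru_cache

compile_calls = 0


@lru_cache(maxsize=32, typed=True)
def cached_compile_fsm(schema, tokenizer_id):
    # stand-in for the expensive outlines FSM compilation
    global compile_calls
    compile_calls += 1
    return ("fsm", schema, tokenizer_id)


# Two different "processor instances" requesting the same grammar now share
# one cache entry, because the key is (schema, tokenizer_id) with no self.
cached_compile_fsm('{"type": "object"}', "llama-tokenizer")
cached_compile_fsm('{"type": "object"}', "llama-tokenizer")
assert compile_calls == 1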

View File

@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, DefaultDict
+from typing import List, Optional, Tuple

 import torch
 from text_generation_server.pb import generate_pb2
@@ -92,14 +92,15 @@ class NextTokenChooser:
         next_id = self.choice(scores[-1]).view(1, 1)
-        if self.grammar_processor is not None:
-            next_state = self.grammar_processor.advance(
-                next_id.item(), self.fsm_grammar_state, self.grammar
-            )
-            self.fsm_grammar_state = next_state

         return next_id, next_logprob

+    def advance_grammar(self, next_id):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_state = self.grammar_processor.advance(
+                next_id, self.fsm_grammar_state
+            )
+        return self
+
     @classmethod
     def from_pb(
         cls,
@@ -385,15 +386,16 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None

-        # advance the grammar state
-        if self.grammar_processor is not None:
-            for i in range(len(self.fsm_grammar_states)):
-                self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                )

         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

+    def advance_grammar(self, next_ids: torch.Tensor):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
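
The net API change in this last file: __call__ no longer mutates grammar state, and both choosers expose an advance_grammar that delegates to the grammar processor and returns self. A hedged end-to-end sketch of the new calling convention (toy processor and scores; the real __call__ signatures are not shown in this diff; assumes torch):

import torch


class ToyGrammarProcessor:
    """Stand-in for GrammarLogitProcessor; advance mirrors the diff's call."""

    def advance(self, next_token_id, fsm_grammar_state):
        return fsm_grammar_state + 1  # real code calls fsm.next_state(...)


class ToyNextTokenChooser:
    """Toy chooser wired the way the diff wires NextTokenChooser."""

    def __init__(self, grammar_processor=None):
        self.grammar_processor = grammar_processor
        self.fsm_grammar_state = 0

    def __call__(self, scores):
        # sampling only; no grammar bookkeeping left in the hot path
        return scores.argmax(dim=-1).view(1, 1)

    def advance_grammar(self, next_id):
        if self.grammar_processor is not None:
            self.fsm_grammar_state = self.grammar_processor.advance(
                next_id, self.fsm_grammar_state
            )
        return self


chooser = ToyNextTokenChooser(ToyGrammarProcessor())
next_id = chooser(torch.tensor([0.1, 2.0, 0.3]))
chooser = chooser.advance_grammar(next_id.item())
assert chooser.fsm_grammar_state == 1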