Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
feat: improve grammar advance logic to avoid blocking GPU
This commit is contained in:
parent d0d7cd9e92
commit 8f14019053
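
The hunks below all implement one pattern: the grammar FSM is no longer advanced inside the token-selection call itself. NextTokenChooser and HeterogeneousNextTokenChooser gain an explicit advance_grammar step, and the CausalLM / FlashCausalLM decode paths invoke it after the next token ids are already known, keeping the per-request, CPU-side FSM transition off the sampling path. A minimal, self-contained sketch of that calling pattern; ToyFSM and ToyChooser are illustrative stand-ins, not TGI classes, and only the advance_grammar name comes from the diff:

# Minimal sketch of the "sample first, advance the grammar afterwards" pattern.
# ToyFSM and ToyChooser are illustrative stand-ins, not TGI classes.


class ToyFSM:
    """Tiny deterministic FSM: (state, token_id) -> next state."""

    def __init__(self, transitions):
        self.transitions = transitions

    def next_state(self, state, token_id):
        return self.transitions.get((state, token_id), -1)


class ToyChooser:
    def __init__(self, fsm, fsm_grammar_state=0):
        self.fsm = fsm
        self.fsm_grammar_state = fsm_grammar_state

    def __call__(self, logits):
        # Pure token selection; no FSM bookkeeping here (previously the
        # grammar state was also advanced inside this call).
        return max(range(len(logits)), key=lambda i: logits[i])

    def advance_grammar(self, next_token_id):
        # Explicit, separate step run once the chosen token is known.
        self.fsm_grammar_state = self.fsm.next_state(
            self.fsm_grammar_state, next_token_id
        )
        return self


chooser = ToyChooser(ToyFSM({(0, 2): 1, (1, 3): 2}))
next_id = chooser([0.1, 0.2, 0.9, 0.0])     # token selection (argmax here)
chooser = chooser.advance_grammar(next_id)  # grammar bookkeeping afterwards
print(next_id, chooser.fsm_grammar_state)   # 2 1

Both advance_grammar methods in the diff return self, so the call sites can simply reassign the chooser, which is what the CausalLM and FlashCausalLM hunks below do.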
@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

         # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@ -696,7 +700,7 @@ class CausalLM(Model):
|
||||
|
||||
if top_n_tokens > 0:
|
||||
all_top_tokens = []
|
||||
for (top_token_ids, top_token_logprobs) in zip(
|
||||
for top_token_ids, top_token_logprobs in zip(
|
||||
top_token_ids, top_token_logprobs
|
||||
):
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
@@ -735,6 +739,9 @@ class CausalLM(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
@ -870,7 +870,11 @@ class FlashCausalLM(Model):
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
|
||||
if (
|
||||
cu_seqlen_prefill is not None
|
||||
or cuda_graph is None
|
||||
or batch.speculative_ids is not None
|
||||
):
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1029,6 +1033,9 @@ class FlashCausalLM(Model):

             cumulative_length += input_length

+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids
+        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -478,10 +478,8 @@ class GrammarLogitProcessor(LogitsProcessor):

     def __init__(self, tokenizer, device, grammar):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            self, grammar, self.tokenizer
-        )
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)

     def __call__(
         self,
@@ -490,26 +488,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     ):
         if fsm_grammar_state == -1 or self.fsm is None:
             return logits

         allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
         biased_scores = logits + mask
         return biased_scores

-    def advance(self, next_token_id, fsm_grammar_state, grammar):
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
         if fsm_grammar_state == -1:
             return fsm_grammar_state
-
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)

     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(schema, tokenizer):
         start_time = time.time()
         try:
             json.loads(schema)  # check if schema is a valid json
@ -522,7 +520,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=32, typed=True)
|
||||
def adapt_tokenizer(tokenizer):
|
||||
def _cached_adapt_tokenizer(tokenizer):
|
||||
"""Adapt tokenizer to work with the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
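
Both cached helpers above are @staticmethod + @lru_cache and no longer take self: _cached_compile_fsm is keyed on (schema, tokenizer) and _cached_adapt_tokenizer on the tokenizer alone, so identical grammars and tokenizers share one compiled FSM and one adapted tokenizer across processor instances, and no processor instance is pinned in the cache key. A small sketch of the caching behaviour this relies on; cached_compile is a hypothetical stand-in for the expensive FSM compilation, not the TGI function:

# functools.lru_cache keys on the argument tuple, so caching on the schema
# string alone lets repeated identical grammars skip recompilation entirely.
import time
from functools import lru_cache


@lru_cache(maxsize=32, typed=True)
def cached_compile(schema: str) -> str:
    time.sleep(0.1)  # stand-in for the expensive FSM compilation
    return f"fsm-for:{schema}"


start = time.time()
cached_compile('{"type": "object"}')  # compiles (slow path)
cached_compile('{"type": "object"}')  # cache hit, no recompilation
print(f"two calls took {time.time() - start:.2f}s")  # ~0.10s, not ~0.20s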
@ -560,10 +558,10 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
def __init__(self, tokenizer, device, grammars):
|
||||
self.device = device
|
||||
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
|
||||
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
||||
self.fsms = [
|
||||
(
|
||||
GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
|
||||
GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
|
||||
if g
|
||||
else None
|
||||
)
|
||||
@@ -586,10 +584,13 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             logits[i] = biased_scores
         return logits

-    def advance(self, next_token_ids, fsm_grammar_states, grammars):
-        return GrammarLogitProcessor.advance(
-            self, next_token_ids, fsm_grammar_states, grammars
-        )
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]

     def filter(self, indices):
         return GrammarLogitProcessor.filter(self, indices)
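
advance_batch replaces the old delegation to GrammarLogitProcessor.advance: request i advances its own precompiled self.fsms[i] from its own fsm_grammar_states[i], with no per-step recompilation. A tiny runnable sketch of that shape, using plain dicts as stand-in transition tables rather than the Outlines FSM API:

# Per-request FSM advance, mirroring the list comprehension in advance_batch.
def advance_batch(next_token_ids, fsm_grammar_states, fsms):
    return [
        # Request i only ever consults its own table and its own state.
        fsms[i].get((fsm_grammar_states[i], next_token_ids[i]), -1)
        for i in range(len(next_token_ids))
    ]


fsms = [{(0, 7): 1}, {(3, 9): 4}]  # one transition table per request
print(advance_batch([7, 9], [0, 3], fsms))  # [1, 4]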
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, DefaultDict
+from typing import List, Optional, Tuple

 import torch
 from text_generation_server.pb import generate_pb2
@ -92,14 +92,15 @@ class NextTokenChooser:
|
||||
|
||||
next_id = self.choice(scores[-1]).view(1, 1)
|
||||
|
||||
if self.grammar_processor is not None:
|
||||
next_state = self.grammar_processor.advance(
|
||||
next_id.item(), self.fsm_grammar_state, self.grammar
|
||||
)
|
||||
self.fsm_grammar_state = next_state
|
||||
|
||||
return next_id, next_logprob
|
||||
|
||||
def advance_grammar(self, next_id):
|
||||
if self.grammar_processor is not None:
|
||||
self.fsm_grammar_state = self.grammar_processor.advance(
|
||||
next_id, self.fsm_grammar_state
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
@@ -385,15 +386,16 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None

-        # advance the grammar state
-        if self.grammar_processor is not None:
-            for i in range(len(self.fsm_grammar_states)):
-                self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                )
-
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

+    def advance_grammar(self, next_ids: torch.Tensor):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
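
On the flash path, the whole tensor of chosen ids is handed to advance_grammar, which does a single tolist() and then advances every request's state in plain Python, outside the forward and sampling calls. A self-contained sketch of that batched call shape; ToyHeterogeneousChooser is an illustrative stand-in for HeterogeneousNextTokenChooser, and the dict-based transitions are not the Outlines FSM API:

# Batched grammar advance after sampling, keyed off one tensor of chosen ids.
import torch


class ToyHeterogeneousChooser:
    def __init__(self, transitions_per_request):
        self.fsm_grammar_states = [0] * len(transitions_per_request)
        self.transitions = transitions_per_request

    def advance_grammar(self, next_ids: torch.Tensor):
        ids = next_ids.tolist()  # one tensor-to-list conversion per decode step
        self.fsm_grammar_states = [
            self.transitions[i].get((state, ids[i]), -1)
            for i, state in enumerate(self.fsm_grammar_states)
        ]
        return self


chooser = ToyHeterogeneousChooser([{(0, 5): 1}, {(0, 6): 2}])
chooser = chooser.advance_grammar(torch.tensor([5, 6]))
print(chooser.fsm_grammar_states)  # [1, 2]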