mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve grammar advance logic to avoid blocking GPU
This commit is contained in:
parent d0d7cd9e92
commit 8f14019053
@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

                 # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -696,7 +700,7 @@ class CausalLM(Model):

                 if top_n_tokens > 0:
                     all_top_tokens = []
-                    for (top_token_ids, top_token_logprobs) in zip(
+                    for top_token_ids, top_token_logprobs in zip(
                         top_token_ids, top_token_logprobs
                     ):
                         toptoken_texts = self.tokenizer.batch_decode(
@@ -735,6 +739,9 @@ class CausalLM(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
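The hunk above moves the grammar-FSM update out of the chooser's sampling call and into an explicit advance_grammar step that runs after the token has been chosen. A minimal sketch of that call shape, using toy stand-ins rather than the real TGI classes:

class ToyGrammarProcessor:
    def advance(self, next_token_id, state):
        # Toy transition: pretend every token moves the FSM one state forward.
        return state + 1

class ToyNextTokenChooser:
    def __init__(self, grammar_processor=None):
        self.grammar_processor = grammar_processor
        self.fsm_grammar_state = 0

    def __call__(self, scores):
        # Sampling only; no grammar bookkeeping on this path.
        return max(range(len(scores)), key=scores.__getitem__)

    def advance_grammar(self, next_id):
        # Explicit, separate FSM update, returning the chooser itself.
        if self.grammar_processor is not None:
            self.fsm_grammar_state = self.grammar_processor.advance(
                next_id, self.fsm_grammar_state
            )
        return self

choosers = [ToyNextTokenChooser(ToyGrammarProcessor())]
scores = [0.1, 2.0, 0.3]
for i in range(len(choosers)):
    next_id = choosers[i](scores)
    # Same shape as the diff: advance only after the token is chosen.
    choosers[i] = choosers[i].advance_grammar(next_id)
    print(next_id, choosers[i].fsm_grammar_state)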
@@ -870,7 +870,11 @@ class FlashCausalLM(Model):
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

-        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            or batch.speculative_ids is not None
+        ):
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
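For context, the condition above guards the captured-CUDA-graph fast path: prefill requests, batch sizes with no captured graph, and speculative ids all fall back to a regular forward call. A rough sketch of that dispatch shape, with made-up names standing in for the real FlashCausalLM attributes:

# Hypothetical mapping from padded batch size to a captured graph handle.
cuda_graphs = {8: "graph_bs8", 16: "graph_bs16"}

def dispatch(padded_bs, cu_seqlen_prefill, speculative_ids):
    cuda_graph = cuda_graphs.get(padded_bs, None)
    if (
        cu_seqlen_prefill is not None
        or cuda_graph is None
        or speculative_ids is not None
    ):
        # Anything the capture does not cover runs eagerly.
        return "eager_forward"
    # Otherwise replay the pre-captured graph for this batch size.
    return f"replay {cuda_graph}"

print(dispatch(8, None, None))    # replay graph_bs8
print(dispatch(12, None, None))   # eager_forward (no graph for 12)
print(dispatch(8, [0, 5], None))  # eager_forward (prefill)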
@@ -1029,6 +1033,9 @@ class FlashCausalLM(Model):

             cumulative_length += input_length

+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
+            next_input_ids
+        )
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -478,10 +478,8 @@ class GrammarLogitProcessor(LogitsProcessor):

    def __init__(self, tokenizer, device, grammar):
        self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            self, grammar, self.tokenizer
-        )
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)

    def __call__(
        self,
@@ -490,26 +488,26 @@ class GrammarLogitProcessor(LogitsProcessor):
    ):
        if fsm_grammar_state == -1 or self.fsm is None:
            return logits

        allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
        mask[allowed_tokens] = 0
        biased_scores = logits + mask
        return biased_scores

-    def advance(self, next_token_id, fsm_grammar_state, grammar):
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
        if fsm_grammar_state == -1:
            return fsm_grammar_state
-
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-
-        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
        return fsm.next_state(fsm_grammar_state, next_token_id)

    @staticmethod
    @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(schema, tokenizer):
        start_time = time.time()
        try:
            json.loads(schema)  # check if schema is a valid json
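The refactor above compiles the FSM once in __init__ (through the lru-cached compile) and then advances over that precompiled FSM via a static _advance helper, instead of recompiling from the grammar string on every step. A small sketch under those assumptions, with a toy FSM standing in for the Outlines FSM used by TGI:

class ToyFSM:
    def __init__(self, transitions):
        # transitions: {(state, token_id): next_state}
        self.transitions = transitions

    def next_state(self, state, token_id):
        return self.transitions.get((state, token_id), -1)

class ToyGrammarProcessor:
    def __init__(self, fsm):
        # Compiled once up front (the real code memoizes compilation with
        # functools.lru_cache), not rebuilt per generated token.
        self.fsm = fsm

    def advance(self, next_token_id, fsm_grammar_state):
        return ToyGrammarProcessor._advance(
            next_token_id, fsm_grammar_state, self.fsm
        )

    @staticmethod
    def _advance(next_token_id, fsm_grammar_state, fsm):
        # -1 marks a request with no active grammar constraint.
        if fsm_grammar_state == -1:
            return fsm_grammar_state
        return fsm.next_state(fsm_grammar_state, next_token_id)

proc = ToyGrammarProcessor(ToyFSM({(0, 5): 1, (1, 9): 2}))
state = proc.advance(5, 0)      # -> 1
state = proc.advance(9, state)  # -> 2
print(state)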
@@ -522,7 +520,7 @@ class GrammarLogitProcessor(LogitsProcessor):

    @staticmethod
    @lru_cache(maxsize=32, typed=True)
-    def adapt_tokenizer(tokenizer):
+    def _cached_adapt_tokenizer(tokenizer):
        """Adapt tokenizer to work with the FSM.

        The API of Outlines tokenizers is slightly different to that of
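Renaming adapt_tokenizer to _cached_adapt_tokenizer makes explicit that the lru_cache-decorated adaptation is memoized per tokenizer object. A minimal, self-contained sketch of that caching idea (adapt_slowly is invented for illustration; it relies on the argument being hashable, which plain objects are by default):

from functools import lru_cache

CALLS = {"n": 0}

@lru_cache(maxsize=32, typed=True)
def adapt_slowly(obj):
    # Stands in for an expensive, pure adaptation step; runs once per object.
    CALLS["n"] += 1
    return ("adapted", id(obj))

class FakeTokenizer:
    pass

tok = FakeTokenizer()
adapt_slowly(tok)
adapt_slowly(tok)   # served from the cache; the work does not run again
print(CALLS["n"])   # 1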
@@ -560,10 +558,10 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
    def __init__(self, tokenizer, device, grammars):
        self.device = device
-        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
        self.fsms = [
            (
-                GrammarLogitProcessor._cached_compile_fsm(self, g, self.tokenizer)
+                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
                if g
                else None
            )
@@ -586,10 +584,13 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
            logits[i] = biased_scores
        return logits

-    def advance(self, next_token_ids, fsm_grammar_states, grammars):
-        return GrammarLogitProcessor.advance(
-            self, next_token_ids, fsm_grammar_states, grammars
-        )
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]

    def filter(self, indices):
        return GrammarLogitProcessor.filter(self, indices)
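advance_batch above advances each request's grammar state independently against its own precompiled FSM. A toy sketch of the same shape (not the real classes; the None/-1 guard here is the sketch's own simplification for requests without a grammar):

class ToyFSM:
    def next_state(self, state, token_id):
        # Toy transition for illustration only.
        return state + token_id

def _advance(next_token_id, fsm_grammar_state, fsm):
    if fsm_grammar_state == -1 or fsm is None:
        return fsm_grammar_state
    return fsm.next_state(fsm_grammar_state, next_token_id)

def advance_batch(next_token_ids, fsm_grammar_states, fsms):
    # One FSM (or None) per request, advanced independently.
    return [
        _advance(next_token_ids[i], fsm_grammar_states[i], fsms[i])
        for i in range(len(next_token_ids))
    ]

print(advance_batch([3, 7, 1], [0, -1, 2], [ToyFSM(), None, ToyFSM()]))
# [3, -1, 3]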
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple, DefaultDict
+from typing import List, Optional, Tuple

 import torch
 from text_generation_server.pb import generate_pb2
@@ -92,14 +92,15 @@ class NextTokenChooser:

        next_id = self.choice(scores[-1]).view(1, 1)

-        if self.grammar_processor is not None:
-            next_state = self.grammar_processor.advance(
-                next_id.item(), self.fsm_grammar_state, self.grammar
-            )
-            self.fsm_grammar_state = next_state
-
        return next_id, next_logprob

+    def advance_grammar(self, next_id):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_state = self.grammar_processor.advance(
+                next_id, self.fsm_grammar_state
+            )
+        return self
+
    @classmethod
    def from_pb(
        cls,
@@ -385,15 +386,16 @@ class HeterogeneousNextTokenChooser:
        else:
            speculative_ids = None

-        # advance the grammar state
-        if self.grammar_processor is not None:
-            for i in range(len(self.fsm_grammar_states)):
-                self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                )
-
        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

+    def advance_grammar(self, next_ids: torch.Tensor):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids.tolist(), self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)
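The batched advance_grammar converts the sampled ids with a single next_ids.tolist() before looping in Python. On a CUDA tensor, each per-element .item() is a separate device-to-host copy that synchronizes with the GPU, whereas .tolist() transfers the whole (small) tensor once, which is the "avoid blocking GPU" part of this commit. A CPU-runnable sketch of the two call patterns (values are illustrative):

import torch

next_ids = torch.tensor([11, 42, 7])

# Old pattern: one host transfer (and, on CUDA, one sync) per request.
ids_old = [int(next_ids[i].item()) for i in range(len(next_ids))]

# New pattern: one transfer up front, then a pure-Python loop over ints.
ids_new = next_ids.tolist()

assert ids_old == ids_new
print(ids_new)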