Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
add cuda graphs to token warping

commit e2727387aa (parent 745f596c88)
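This commit replaces the per-token launch of the individual logits warpers with a single captured CUDA graph: the temperature/top-k/top-p/typical-p warpers plus the final log_softmax are recorded once into a torch.cuda.CUDAGraph, and every following decode step only copies the fresh logits into a static input buffer and replays the graph. Below is a minimal sketch of that capture-and-replay pattern, not code from this repository: it uses the torch.cuda.graph context manager instead of the manual capture_begin/capture_end calls in the diff, the warp function, vocabulary size, and temperature are placeholders, and it needs a CUDA device to run.

import torch

# Stand-in for the chain of logits warpers captured in the diff; illustrative only.
def warp(scores: torch.Tensor) -> torch.Tensor:
    return torch.log_softmax(scores / 0.7, dim=-1)

# Fixed input buffer: CUDA graphs replay kernels against the same memory addresses.
static_scores = torch.randn(1, 32000, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):  # records the kernels launched inside this block
    static_output = warp(static_scores)

# Per decode step: refresh the static buffer, then replay the recorded kernels.
fresh_logits = torch.randn(1, 32000, device="cuda")
static_scores.copy_(fresh_logits)
graph.replay()
print(static_output[0, :5])  # results land in the captured output tensor

Replaying the graph avoids re-launching each warper kernel every step, which is why the diff keeps static_scores, static_warped_scores and static_next_logprob as persistent tensors on the StaticWarper instance.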
@@ -1,8 +1,8 @@
 import re
 import torch
 
+from functools import lru_cache
 from transformers import (
-    LogitsProcessorList,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -34,62 +34,108 @@ class Greedy:
         return logits.argmax()
 
 
-class NextTokenChooser:
+class StaticWarper:
     def __init__(
         self,
-        watermark=False,
         temperature=1.0,
-        repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
     ):
-        warpers = LogitsProcessorList()
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        sampling = do_sample
+        self.warpers = []
 
-        if watermark:
-            warpers.append(WatermarkLogitsProcessor(device=device))
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
             temperature = float(temperature)
-            warpers.append(TemperatureLogitsWarper(temperature))
-            sampling = True
+            self.warpers.append(TemperatureLogitsWarper(temperature))
         if top_k is not None and top_k != 0:
-            warpers.append(TopKLogitsWarper(top_k=top_k))
-            sampling = True
+            self.warpers.append(TopKLogitsWarper(top_k=top_k))
         if top_p is not None and top_p < 1.0:
-            warpers.append(TopPLogitsWarper(top_p=top_p))
-            sampling = True
+            self.warpers.append(TopPLogitsWarper(top_p=top_p))
         if typical_p is not None and typical_p < 1.0:
-            warpers.append(TypicalLogitsWarper(mass=typical_p))
-            sampling = True
+            self.warpers.append(TypicalLogitsWarper(mass=typical_p))
 
-        self.warpers = warpers
-        self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+        self.static_next_logprob = None
+
+    def __call__(self, scores):
+        if self.cuda_graph is None:
+            self.static_scores = scores
+            self.cuda_graph = torch.cuda.CUDAGraph()
+
+            capture_stream = torch.cuda.stream(torch.cuda.Stream())
+            capture_stream.__enter__()
+            self.cuda_graph.capture_begin()
+
+            for warper in self.warpers:
+                self.static_warped_scores = warper(None, self.static_scores)
+
+            # Compute logprobs
+            self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
+
+            self.cuda_graph.capture_end()
+            capture_stream.__exit__(None, None, None)
+
+        self.static_scores.copy_(scores)
+        self.cuda_graph.replay()
+
+        return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float],
+                  typical_p: Optional[float]) -> StaticWarper:
+    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        watermark=False,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
+    ):
+        self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
+        self.repetition_warper = RepetitionPenaltyLogitsProcessor(
+            penalty=repetition_penalty) if repetition_penalty else None
+
+        sampling = do_sample or (temperature is not None and temperature != 1.0) or (
+            top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
+            typical_p is not None and typical_p < 1.0)
+        if sampling:
+            self.choice = Sampling(seed, device)
+            self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+        else:
+            self.choice = Greedy()
+            self.static_warper = None
 
     def __call__(self, input_ids, scores):
-        # Warp logits
-        scores = self.warpers(input_ids, scores)
+        if self.watermark_warper:
+            scores = self.watermark_warper(input_ids, scores)
+        if self.repetition_warper:
+            scores = self.repetition_warper(input_ids, scores)
 
-        # Compute logprobs
-        logprobs = torch.log_softmax(scores, -1)
+        if self.static_warper is None:
+            next_logprob = torch.log_softmax(scores, -1)
+        else:
+            scores, next_logprob = self.static_warper(scores)
 
-        # Choose tokens
-        next_id = self.choice(scores[-1])
+        next_id = self.choice(scores[-1]).view(1, 1)
 
-        return next_id.view(1, 1), logprobs
+        return next_id, next_logprob
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -117,11 +163,11 @@ class StopSequenceCriteria:
 
 
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
@@ -147,9 +193,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
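For context, a rough usage sketch of the refactored chooser after this change. It is not part of the commit: the parameter values, input shapes, and vocabulary size are invented, and the import path is assumed from the repository layout.

import torch
# Module path assumed from the repository layout; the changed file path is not shown in this diff.
from text_generation_server.utils.tokens import NextTokenChooser

# Sampling parameters are illustrative only.
chooser = NextTokenChooser(
    temperature=0.9, top_k=50, top_p=0.95, do_sample=True, seed=42, device="cuda"
)

input_ids = torch.tensor([[1, 2, 3]], device="cuda")
scores = torch.randn(1, 32000, device="cuda")  # logits for the last position; made-up vocab size

# With sampling enabled, the warpers and log_softmax run inside the replayed CUDA graph;
# with greedy decoding, the chooser falls back to a plain log_softmax on the raw scores.
next_id, logprobs = chooser(input_ids, scores)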