add cuda graphs to token warping

This commit is contained in:
OlivierDehaene 2023-05-09 16:30:19 +02:00
parent 745f596c88
commit e2727387aa

View File

@ -1,8 +1,8 @@
import re
import torch
from functools import lru_cache
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@ -34,6 +34,61 @@ class Greedy:
return logits.argmax()
class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
self.warpers = []
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
self.warpers.append(TopPLogitsWarper(top_p=top_p))
if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None
self.static_scores = None
self.static_warped_scores = None
self.static_next_logprob = None
def __call__(self, scores):
if self.cuda_graph is None:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
capture_stream = torch.cuda.stream(torch.cuda.Stream())
capture_stream.__enter__()
self.cuda_graph.capture_begin()
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
# Compute logprobs
self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
self.cuda_graph.capture_end()
capture_stream.__exit__(None, None, None)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float],
typical_p: Optional[float]) -> StaticWarper:
return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
class NextTokenChooser:
def __init__(
self,
@ -47,43 +102,34 @@ class NextTokenChooser:
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
self.repetition_warper = RepetitionPenaltyLogitsProcessor(
penalty=repetition_penalty) if repetition_penalty else None
if watermark:
warpers.append(WatermarkLogitsProcessor(device=device))
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p))
sampling = True
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
sampling = do_sample or (temperature is not None and temperature != 1.0) or (
top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
typical_p is not None and typical_p < 1.0)
if sampling:
self.choice = Sampling(seed, device)
self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
else:
self.choice = Greedy()
self.static_warper = None
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
if self.watermark_warper:
scores = self.watermark_warper(input_ids, scores)
if self.repetition_warper:
scores = self.repetition_warper(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
else:
scores, next_logprob = self.static_warper(scores)
# Choose tokens
next_id = self.choice(scores[-1])
next_id = self.choice(scores[-1]).view(1, 1)
return next_id.view(1, 1), logprobs
return next_id, next_logprob
@classmethod
def from_pb(