add cpu support

OlivierDehaene 2023-05-09 16:30:19 +02:00
parent e2727387aa
commit 1df2aa03c5


@@ -36,11 +36,11 @@ class Greedy:
 class StaticWarper:
     def __init__(
         self,
         temperature=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
     ):
         self.warpers = []
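For context on the hunk above: `self.warpers` is the list that the captured CUDA graph below runs through. Its construction is outside the shown lines, so the following is a hedged sketch assuming the standard Hugging Face transformers logits warpers; the exact classes used are an assumption, not confirmed by this diff.

# Hedged sketch (assumption): self.warpers is presumably filled with the
# standard transformers logits warpers matching the constructor arguments.
from transformers import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)

def build_warpers(temperature=1.0, top_k=None, top_p=None, typical_p=None):
    warpers = []
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(float(temperature)))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p))
    if typical_p is not None and typical_p < 1.0:
        warpers.append(TypicalLogitsWarper(mass=typical_p))
    return warpers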
@@ -64,18 +64,14 @@ class StaticWarper:
         self.static_scores = scores
         self.cuda_graph = torch.cuda.CUDAGraph()
 
-        capture_stream = torch.cuda.stream(torch.cuda.Stream())
-        capture_stream.__enter__()
-        self.cuda_graph.capture_begin()
-
-        for warper in self.warpers:
-            self.static_warped_scores = warper(None, self.static_scores)
-
-        # Compute logprobs
-        self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
-
-        self.cuda_graph.capture_end()
-        capture_stream.__exit__(None, None, None)
+        with torch.cuda.graph(self.cuda_graph):
+            for warper in self.warpers:
+                self.static_warped_scores = warper(None, self.static_scores)
+
+            # Compute logprobs
+            self.static_next_logprob = torch.log_softmax(
+                self.static_warped_scores, -1
+            )
 
         self.static_scores.copy_(scores)
         self.cuda_graph.replay()
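The change above replaces hand-rolled capture bookkeeping (`capture_begin()`/`capture_end()` plus a manually entered `torch.cuda.stream(...)` context) with the `torch.cuda.graph` context manager, available since PyTorch 1.10, which handles the side stream and capture lifecycle internally. A minimal self-contained sketch of the new pattern (assumes a CUDA device; the warm-up pass follows the PyTorch docs' recommendation and is not part of this diff):

# Minimal sketch of the pattern the new code uses (requires a CUDA device).
import torch

static_input = torch.randn(8, device="cuda")

# PyTorch docs recommend a warm-up pass on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = torch.log_softmax(static_input, -1)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Everything launched here is recorded into g, against static buffers.
    static_output = torch.log_softmax(static_input, -1)

# Replay: refill the static input buffer in place, then rerun the kernels.
static_input.copy_(torch.randn(8, device="cuda"))
g.replay()
print(static_output)  # reflects the new input after replay

The key property is that `replay()` reruns the captured kernels against the same static buffers, which is why the surrounding code copies fresh scores into `self.static_scores` before each replay.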
@@ -84,34 +80,51 @@ class StaticWarper:
 @lru_cache(10)
-def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float],
-                  typical_p: Optional[float]) -> StaticWarper:
-    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+def static_warper(
+    temperature: Optional[float],
+    top_k: Optional[int],
+    top_p: Optional[float],
+    typical_p: Optional[float],
+) -> StaticWarper:
+    return StaticWarper(
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+    )
 
 
 class NextTokenChooser:
     def __init__(
         self,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
         typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",
     ):
-        self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
-        self.repetition_warper = RepetitionPenaltyLogitsProcessor(
-            penalty=repetition_penalty) if repetition_penalty else None
+        self.watermark_warper = (
+            WatermarkLogitsProcessor(device=device) if watermark else None
+        )
+        self.repetition_warper = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty
+            else None
+        )
 
-        sampling = do_sample or (temperature is not None and temperature != 1.0) or (
-            top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
-            typical_p is not None and typical_p < 1.0)
+        sampling = (
+            do_sample
+            or (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
+        )
 
         if sampling:
             self.choice = Sampling(seed, device)
-            self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+            self.static_warper = static_warper(
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+            )
         else:
             self.choice = Greedy()
             self.static_warper = None
@@ -133,9 +146,9 @@ class NextTokenChooser:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
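Note that `from_pb` threads a `torch.device` through to the constructor, whose `device` parameter defaults to `"cpu"` above; presumably this is part of what the commit title refers to. A hedged usage sketch (the `pb` construction mirrors the fields passed above but is not shown in the diff):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pb = generate_pb2.NextTokenChooserParameters(...)  # fields mirror __init__
# chooser = NextTokenChooser.from_pb(pb, device)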
@@ -163,11 +176,11 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
         eos_token_id: int,
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
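The shown lines of this hunk are unchanged context; they document the three stop conditions the criteria tracks: a token budget (`max_new_tokens`), the EOS token (unless `ignore_eos_token`), and textual stop sequences. The `__call__` that evaluates them is outside the shown hunks, so the following per-step sketch is hypothetical, with illustrative names:

from typing import List, Optional, Tuple

def evaluate_stop(
    eos_token_id: int,
    stop_sequence_criterias: List,  # callables over the generated text
    max_new_tokens: int,
    ignore_eos_token: bool,
    tokens_generated: int,
    last_token_id: int,
    generated_text: str,
) -> Tuple[bool, Optional[str]]:
    # Budget exhausted?
    if tokens_generated >= max_new_tokens:
        return True, "length"
    # Model emitted EOS and we are allowed to honor it?
    if not ignore_eos_token and last_token_id == eos_token_id:
        return True, "eos_token"
    # Any textual stop sequence matched?
    for criteria in stop_sequence_criterias:
        if criteria(generated_text):
            return True, "stop_sequence"
    return False, None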
@@ -193,9 +206,9 @@ class StoppingCriteria:
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
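`StopSequenceCriteria` itself is defined outside the shown hunks. From its usage here (one instance per `pb.stop_sequences` entry, applied to generated text), a plausible minimal implementation looks like the sketch below; the `re.escape` and end-of-text anchoring are assumptions, not confirmed by the diff:

import re

class StopSequenceCriteria:
    # Hypothetical reconstruction: maps one stop string to a predicate
    # over the text generated so far.
    def __init__(self, stop_sequence: str):
        # Match only at the end of the generated text.
        self.regex = re.compile(f"{re.escape(stop_sequence)}$")

    def __call__(self, output: str) -> bool:
        return self.regex.search(output) is not None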