add cpu support

This commit is contained in:
OlivierDehaene 2023-05-09 16:30:19 +02:00
parent e2727387aa
commit 1df2aa03c5

View File

@ -64,18 +64,14 @@ class StaticWarper:
self.static_scores = scores
self.cuda_graph = torch.cuda.CUDAGraph()
capture_stream = torch.cuda.stream(torch.cuda.Stream())
capture_stream.__enter__()
self.cuda_graph.capture_begin()
with torch.cuda.graph(self.cuda_graph):
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
# Compute logprobs
self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
self.cuda_graph.capture_end()
capture_stream.__exit__(None, None, None)
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1
)
self.static_scores.copy_(scores)
self.cuda_graph.replay()
@ -84,9 +80,15 @@ class StaticWarper:
@lru_cache(10)
def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float],
typical_p: Optional[float]) -> StaticWarper:
return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    """Build a ``StaticWarper`` for one combination of sampling settings.

    Each of the four arguments is forwarded, unchanged and by keyword, to
    the ``StaticWarper`` constructor; the new instance is returned.
    """
    # Gather the settings under the constructor's keyword names, then
    # forward them in a single call.
    warper_settings = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "typical_p": typical_p,
    }
    return StaticWarper(**warper_settings)
class NextTokenChooser:
@ -102,16 +104,27 @@ class NextTokenChooser:
seed=0,
device="cpu",
):
self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
self.repetition_warper = RepetitionPenaltyLogitsProcessor(
penalty=repetition_penalty) if repetition_penalty else None
self.watermark_warper = (
WatermarkLogitsProcessor(device=device) if watermark else None
)
self.repetition_warper = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty
else None
)
sampling = do_sample or (temperature is not None and temperature != 1.0) or (
top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
typical_p is not None and typical_p < 1.0)
sampling = (
do_sample
or (temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
or (typical_p is not None and typical_p < 1.0)
)
if sampling:
self.choice = Sampling(seed, device)
self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
self.static_warper = static_warper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
else:
self.choice = Greedy()
self.static_warper = None