mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add cpu support
This commit is contained in:
parent
e2727387aa
commit
1df2aa03c5
@ -64,18 +64,14 @@ class StaticWarper:
|
|||||||
self.static_scores = scores
|
self.static_scores = scores
|
||||||
self.cuda_graph = torch.cuda.CUDAGraph()
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
capture_stream = torch.cuda.stream(torch.cuda.Stream())
|
with torch.cuda.graph(self.cuda_graph):
|
||||||
capture_stream.__enter__()
|
|
||||||
self.cuda_graph.capture_begin()
|
|
||||||
|
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
self.static_warped_scores = warper(None, self.static_scores)
|
self.static_warped_scores = warper(None, self.static_scores)
|
||||||
|
|
||||||
# Compute logprobs
|
# Compute logprobs
|
||||||
self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
|
self.static_next_logprob = torch.log_softmax(
|
||||||
|
self.static_warped_scores, -1
|
||||||
self.cuda_graph.capture_end()
|
)
|
||||||
capture_stream.__exit__(None, None, None)
|
|
||||||
|
|
||||||
self.static_scores.copy_(scores)
|
self.static_scores.copy_(scores)
|
||||||
self.cuda_graph.replay()
|
self.cuda_graph.replay()
|
||||||
@ -84,9 +80,15 @@ class StaticWarper:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    """Build (or fetch from the LRU cache) a ``StaticWarper`` for one
    sampling configuration.

    The cache (size 10) lets requests that share identical sampling
    parameters reuse the same warper instance instead of re-creating it.

    Args:
        temperature: softmax temperature, or ``None`` to leave scores unscaled.
        top_k: keep only the ``k`` highest-scoring tokens, or ``None``.
        top_p: nucleus-sampling probability mass cutoff, or ``None``.
        typical_p: typical-decoding mass cutoff, or ``None``.

    Returns:
        A ``StaticWarper`` configured with the given parameters.
    """
    return StaticWarper(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        typical_p=typical_p,
    )
|
|
||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
@ -102,16 +104,27 @@ class NextTokenChooser:
|
|||||||
seed=0,
|
seed=0,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
):
|
):
|
||||||
self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
|
self.watermark_warper = (
|
||||||
self.repetition_warper = RepetitionPenaltyLogitsProcessor(
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
penalty=repetition_penalty) if repetition_penalty else None
|
)
|
||||||
|
self.repetition_warper = (
|
||||||
|
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||||
|
if repetition_penalty
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
sampling = do_sample or (temperature is not None and temperature != 1.0) or (
|
sampling = (
|
||||||
top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
|
do_sample
|
||||||
typical_p is not None and typical_p < 1.0)
|
or (temperature is not None and temperature != 1.0)
|
||||||
|
or (top_k is not None and top_k != 0)
|
||||||
|
or (top_p is not None and top_p < 1.0)
|
||||||
|
or (typical_p is not None and typical_p < 1.0)
|
||||||
|
)
|
||||||
if sampling:
|
if sampling:
|
||||||
self.choice = Sampling(seed, device)
|
self.choice = Sampling(seed, device)
|
||||||
self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
|
self.static_warper = static_warper(
|
||||||
|
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.choice = Greedy()
|
self.choice = Greedy()
|
||||||
self.static_warper = None
|
self.static_warper = None
|
||||||
|
Loading…
Reference in New Issue
Block a user