Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 19:34:53 +00:00

add cpu support

This commit is contained in:
parent e2727387aa
commit 1df2aa03c5
@@ -36,11 +36,11 @@ class Greedy:
 
 class StaticWarper:
     def __init__(
-        self,
-        temperature=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
     ):
         self.warpers = []
 
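StaticWarper pre-captures a fixed chain of logits warpers so the same transforms can be replayed cheaply on every decoding step. The code that fills self.warpers is outside this hunk, so the following is only a sketch of how such a chain is typically assembled from the transformers logits warpers; build_warpers is a hypothetical helper, not part of the diff:

from transformers import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)

def build_warpers(temperature=1.0, top_k=None, top_p=None, typical_p=None):
    # Hypothetical mirror of StaticWarper.__init__: each parameter set to a
    # non-default value appends one warper to the chain.
    warpers = []
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(float(temperature)))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p))
    if typical_p is not None and typical_p < 1.0:
        warpers.append(TypicalLogitsWarper(mass=typical_p))
    return warpers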
@@ -64,18 +64,14 @@ class StaticWarper:
             self.static_scores = scores
             self.cuda_graph = torch.cuda.CUDAGraph()
 
-            capture_stream = torch.cuda.stream(torch.cuda.Stream())
-            capture_stream.__enter__()
-            self.cuda_graph.capture_begin()
-
-            for warper in self.warpers:
-                self.static_warped_scores = warper(None, self.static_scores)
-
-            # Compute logprobs
-            self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
-
-            self.cuda_graph.capture_end()
-            capture_stream.__exit__(None, None, None)
+            with torch.cuda.graph(self.cuda_graph):
+                for warper in self.warpers:
+                    self.static_warped_scores = warper(None, self.static_scores)
+
+                # Compute logprobs
+                self.static_next_logprob = torch.log_softmax(
+                    self.static_warped_scores, -1
+                )
 
         self.static_scores.copy_(scores)
         self.cuda_graph.replay()
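The rewrite above swaps the manual capture_begin/capture_end pairing for the torch.cuda.graph context manager, which manages the capture stream and cleanup itself. A minimal self-contained sketch of the capture-then-replay pattern (tensor names and the vocabulary size are illustrative only):

import torch

if torch.cuda.is_available():
    # Static buffers: CUDA graphs require fixed addresses and shapes.
    static_scores = torch.randn(1, 32000, device="cuda")
    graph = torch.cuda.CUDAGraph()

    # Kernels launched inside this block are recorded, not just executed.
    with torch.cuda.graph(graph):
        static_logprobs = torch.log_softmax(static_scores / 0.7, -1)

    # Replay re-runs the recorded kernels on whatever data currently sits
    # in static_scores, so fresh scores are copied into the buffer first.
    static_scores.copy_(torch.randn(1, 32000, device="cuda"))
    graph.replay()
    print(static_logprobs.sum().item())

This is why the method above ends with self.static_scores.copy_(scores) followed by self.cuda_graph.replay(): the graph always reads from and writes to the same buffers.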
@@ -84,34 +80,51 @@ class StaticWarper:
 
 
 @lru_cache(10)
-def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float],
-                  typical_p: Optional[float]) -> StaticWarper:
-    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+def static_warper(
+    temperature: Optional[float],
+    top_k: Optional[int],
+    top_p: Optional[float],
+    typical_p: Optional[float],
+) -> StaticWarper:
+    return StaticWarper(
+        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+    )
 
 
 class NextTokenChooser:
     def __init__(
-        self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
+        self,
+        watermark=False,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
     ):
-        self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None
-        self.repetition_warper = RepetitionPenaltyLogitsProcessor(
-            penalty=repetition_penalty) if repetition_penalty else None
+        self.watermark_warper = (
+            WatermarkLogitsProcessor(device=device) if watermark else None
+        )
+        self.repetition_warper = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty
+            else None
+        )
 
-        sampling = do_sample or (temperature is not None and temperature != 1.0) or (
-            top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (
-            typical_p is not None and typical_p < 1.0)
+        sampling = (
+            do_sample
+            or (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+            or (typical_p is not None and typical_p < 1.0)
+        )
         if sampling:
             self.choice = Sampling(seed, device)
-            self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
+            self.static_warper = static_warper(
+                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+            )
         else:
             self.choice = Greedy()
             self.static_warper = None
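Two design points in this hunk are worth noting. The reshaped boolean decides the whole decoding path: any warper parameter set away from its neutral value forces the Sampling branch even when do_sample is False. And @lru_cache(10) on static_warper means at most ten distinct (temperature, top_k, top_p, typical_p) combinations keep a captured graph alive at once; repeated requests with the same parameters reuse one cached StaticWarper. A standalone copy of the predicate, for illustration:

from typing import Optional

def needs_sampling(
    do_sample: bool = False,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
) -> bool:
    # Same logic as the `sampling` expression in the diff above.
    return (
        do_sample
        or (temperature is not None and temperature != 1.0)
        or (top_k is not None and top_k != 0)
        or (top_p is not None and top_p < 1.0)
        or (typical_p is not None and typical_p < 1.0)
    )

assert needs_sampling() is False                # nothing set: greedy path
assert needs_sampling(temperature=0.7) is True  # temperature alone flips it
assert needs_sampling(top_p=1.0) is False       # neutral values do not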
@@ -133,9 +146,9 @@ class NextTokenChooser:
 
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.NextTokenChooserParameters,
-        device: torch.device,
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -163,11 +176,11 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
-        self,
-        eos_token_id: int,
-        stop_sequence_criterias: List[StopSequenceCriteria],
-        max_new_tokens: int = 20,
-        ignore_eos_token: bool = False,
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens: int = 20,
+        ignore_eos_token: bool = False,
     ):
         self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
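Only the signature of StoppingCriteria is touched here; its per-token evaluation logic is not part of the hunk. As a rough sketch of how such a criteria object is commonly driven, one call per generated token (the class name and method shape below are assumptions, not the file's actual code):

from typing import List, Optional, Tuple

class StoppingSketch:
    # Hypothetical stand-in for StoppingCriteria's runtime behaviour.
    def __init__(
        self,
        eos_token_id: int,
        stop_sequences: List[str],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequences = stop_sequences
        self.max_new_tokens = max_new_tokens
        self.ignore_eos_token = ignore_eos_token
        self.current_tokens = 0
        self.current_text = ""

    def __call__(self, token_id: int, token_text: str) -> Tuple[bool, Optional[str]]:
        # Stop on budget exhaustion, on EOS (unless ignored), or when the
        # accumulated text ends with any configured stop sequence.
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if not self.ignore_eos_token and token_id == self.eos_token_id:
            return True, "eos_token"
        self.current_text += token_text
        if any(self.current_text.endswith(s) for s in self.stop_sequences):
            return True, "stop_sequence"
        return False, None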
@@ -193,9 +206,9 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.StoppingCriteriaParameters,
-        tokenizer: PreTrainedTokenizerBase,
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
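Taken together with the device="cpu" defaults above, the chooser can run entirely on the host, in line with the commit's "add cpu support" title. A rough CPU-only illustration of the two strategies it selects between (simplified stand-ins, not the Greedy and Sampling classes themselves):

import torch

scores = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # fake logits on CPU

# Greedy: deterministic argmax, identical on CPU and GPU.
greedy_token = torch.argmax(scores, dim=-1)

# Sampling: draw from the softmax distribution with a seeded generator.
# A device="cpu" default lets this generator live on the host rather than
# requiring a CUDA device.
generator = torch.Generator(device="cpu").manual_seed(0)
probs = torch.softmax(scores, dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1, generator=generator)

print(greedy_token.item(), sampled_token.item())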