mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
cleanup
This commit is contained in:
parent
3248fdfbd4
commit
a944dd0fd5
@ -29,7 +29,7 @@ class Sampling:
|
|||||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||||
torch.div(probs, q, out=q)
|
torch.div(probs, q, out=q)
|
||||||
|
|
||||||
return torch.argmax(q, dim=-1, keepdim=True)
|
return q.argmax()
|
||||||
|
|
||||||
|
|
||||||
class Greedy:
|
class Greedy:
|
||||||
@ -107,36 +107,36 @@ class NextTokenChooser:
|
|||||||
seed=0,
|
seed=0,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
):
|
):
|
||||||
self.watermark_warper = (
|
self.watermark_processor = (
|
||||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
)
|
)
|
||||||
self.repetition_warper = (
|
self.repetition_processor = (
|
||||||
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||||
if repetition_penalty
|
if repetition_penalty
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling = (
|
has_warpers = (
|
||||||
do_sample
|
(temperature is not None and temperature != 1.0)
|
||||||
or (temperature is not None and temperature != 1.0)
|
|
||||||
or (top_k is not None and top_k != 0)
|
or (top_k is not None and top_k != 0)
|
||||||
or (top_p is not None and top_p < 1.0)
|
or (top_p is not None and top_p < 1.0)
|
||||||
or (typical_p is not None and typical_p < 1.0)
|
or (typical_p is not None and typical_p < 1.0)
|
||||||
)
|
)
|
||||||
if sampling:
|
if has_warpers:
|
||||||
self.choice = Sampling(seed, device)
|
|
||||||
self.static_warper = static_warper(
|
self.static_warper = static_warper(
|
||||||
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.choice = Greedy()
|
|
||||||
self.static_warper = None
|
self.static_warper = None
|
||||||
|
|
||||||
|
sampling = do_sample or has_warpers
|
||||||
|
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
def __call__(self, input_ids, scores):
|
||||||
if self.watermark_warper:
|
if self.watermark_processor:
|
||||||
scores = self.watermark_warper(input_ids, scores)
|
scores = self.watermark_processor(input_ids, scores)
|
||||||
if self.repetition_warper:
|
if self.repetition_processor:
|
||||||
scores = self.repetition_warper(input_ids, scores)
|
scores = self.repetition_processor(input_ids, scores)
|
||||||
|
|
||||||
if self.static_warper is None:
|
if self.static_warper is None:
|
||||||
next_logprob = torch.log_softmax(scores, -1)
|
next_logprob = torch.log_softmax(scores, -1)
|
||||||
|
Loading…
Reference in New Issue
Block a user