Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Top k attempt
This commit is contained in:
parent 5677540881
commit d5ff681b00
@@ -175,23 +175,31 @@ class VectorizedNextTokenChooser:
         repetition_penalty=self._standardize(repetition_penalty, 1.0)
         if any([x!=1.0 for x in repetition_penalty]):
-            self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1)
+            self.repetition_penalty=torch.tensor(repetition_penalty, dtype=torch.float32, device=device).unsqueeze(1)
         else:
             self.repetition_penalty=None

         temperature=self._standardize(temperature, 1.0)
         if any([x!=1.0 for x in temperature]):
             do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
-            self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1)
+            self.temperature=torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
         else:
             self.temperature=None

         top_k=self._standardize(top_k, 0)
-        if any([x!=0 for x in top_k]):
+        n_top_k=sum([x!=0 for x in top_k])
+        if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
-            self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.max_top_k=max(top_k)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            if n_top_k<self.batch_size:
+                self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
+            else:
+                self.top_k_mask=None
         else:
+            self.max_top_k=None
             self.top_k=None
+            self.top_k_mask=None

         top_p=self._standardize(top_p, 1.0)
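The new constructor block above turns the per-request `top_k` values into three pieces of state: `max_top_k` (the largest k in the batch), a column of 0-based "k-th score" positions, and an optional boolean mask for requests that do not use top-k. A minimal standalone sketch of that logic follows; `build_top_k_state` is an illustrative name, not part of the repository, and it stores the positions as `int64` (what `torch.gather` expects), whereas the diff uses `float32`.

```python
import torch

def build_top_k_state(top_k, batch_size, device):
    # Hypothetical helper mirroring the constructor logic in the hunk above.
    # top_k: per-request top-k values; 0 means "top-k disabled" for that request.
    n_top_k = sum(x != 0 for x in top_k)
    if n_top_k == 0:
        return None, None, None

    max_top_k = max(top_k)
    # 0-based position of each request's k-th best score. Stored as int64 so it
    # can be used directly as a gather index (the diff stores it as float32).
    top_k_index = torch.tensor(
        [max(x - 1, 0) for x in top_k], dtype=torch.int64, device=device
    ).unsqueeze(1)
    # Only build a mask when the batch mixes top-k and non-top-k requests.
    top_k_mask = (
        torch.tensor([x == 0 for x in top_k], dtype=torch.bool, device=device)
        if n_top_k < batch_size
        else None
    )
    return max_top_k, top_k_index, top_k_mask
```

For example, `build_top_k_state([50, 0, 10], batch_size=3, device="cpu")` would give `max_top_k == 50`, positions `[[49], [0], [9]]`, and a mask `[False, True, False]`.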
@@ -231,9 +239,17 @@ class VectorizedNextTokenChooser:
             scores.div_(self.temperature)

         if self.top_k is not None:
-            top_k = min(self.top_k, scores.size(-1)) # Safety check
+            if scores.size(-1)>self.max_top_k: # Safety check
+                max_top_k=scores.size(-1)
+                top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
+            else:
+                max_top_k=self.max_top_k
+                top_k=self.top_k
+            kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+            if self.top_k_mask is not None:
+                kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
             # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
+            indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)

         # Compute logprobs
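This second hunk replaces the single-k path with a batched filter: one `torch.topk` over the largest requested k, a `gather` that pulls each row's own k-th score as a threshold, and a mask so that rows without top-k keep all of their tokens. A self-contained sketch of that filtering step, under the assumptions that the index tensor is `int64` and that the safety check shrinks k when the vocabulary is smaller than the largest requested k, could look like this (`apply_top_k` and `FILTER_VALUE` are illustrative names, not part of the repository):

```python
import torch

FILTER_VALUE = -float("inf")

def apply_top_k(scores, top_k_index, max_top_k, top_k_mask=None):
    # scores:       (batch, vocab) logits.
    # top_k_index:  (batch, 1) int64, each request's k-1 (0 for disabled rows).
    # max_top_k:    largest requested k in the batch.
    # top_k_mask:   (batch,) bool, True where top-k is disabled, or None.
    vocab_size = scores.size(-1)
    if max_top_k > vocab_size:
        # Safety check: cannot take more than vocab_size values per row.
        max_top_k = vocab_size
        top_k_index = torch.clamp_max(top_k_index, max_top_k - 1)

    # One topk over the largest k, then each row's own k-th best score.
    top_scores = torch.topk(scores, max_top_k, dim=-1).values  # (batch, max_top_k)
    kth_scores = torch.gather(top_scores, 1, top_k_index)      # (batch, 1)

    if top_k_mask is not None:
        # Rows with top-k disabled keep every token: drop their threshold to -inf.
        kth_scores = kth_scores.masked_fill(top_k_mask.unsqueeze(1), FILTER_VALUE)

    # Remove every token scoring below its row's threshold.
    return scores.masked_fill(scores < kth_scores, FILTER_VALUE)
```

With `top_k = [2, 0]`, the first row keeps only its two highest logits while the second row is left untouched, because its threshold has been pushed down to `-inf`.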