Top k attempt

Joel Lamy-Poirier 2023-05-03 14:25:35 -04:00
parent 5677540881
commit d5ff681b00


@@ -175,23 +175,31 @@ class VectorizedNextTokenChooser:
         repetition_penalty=self._standardize(repetition_penalty, 1.0)
         if any([x!=1.0 for x in repetition_penalty]):
-            self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1)
+            self.repetition_penalty=torch.tensor(repetition_penalty, dtype=torch.float32, device=device).unsqueeze(1)
         else:
             self.repetition_penalty=None
         temperature=self._standardize(temperature, 1.0)
         if any([x!=1.0 for x in temperature]):
             do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
-            self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1)
+            self.temperature=torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
         else:
             self.temperature=None
         top_k=self._standardize(top_k, 0)
-        if any([x!=0 for x in top_k]):
+        n_top_k=sum([x!=0 for x in top_k])
+        if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
-            self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.max_top_k=max(top_k)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
+            if n_top_k<self.batch_size:
+                self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device).unsqueeze(1)
+            else:
+                self.top_k_mask=None
         else:
+            self.max_top_k=None
             self.top_k=None
+            self.top_k_mask=None
         top_p=self._standardize(top_p, 1.0)
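
The first hunk vectorizes the sampling parameters: repetition_penalty and temperature become per-sequence column tensors, and the scalar top_k branch is replaced by three pieces of state, namely the batch-wide maximum k, a per-sequence gather index (k-1), and a boolean mask for sequences that have top_k disabled. A minimal standalone sketch of that setup, assuming a plain list of per-sequence top_k values (the name build_top_k_state and the free-function form are illustration only; the commit keeps this logic inside __init__):

import torch

def build_top_k_state(top_k, device="cpu"):
    # Hypothetical free-function version of the __init__ logic above.
    n_top_k = sum(x != 0 for x in top_k)
    if n_top_k == 0:
        return None, None, None
    # Largest k in the batch: a single torch.topk at this size serves every row.
    max_top_k = max(top_k)
    # Per-row gather index: the k-th largest score sits at column k-1
    # of the sorted torch.topk output.
    top_k_index = torch.tensor(
        [max(x - 1, 0) for x in top_k], dtype=torch.int64, device=device
    ).unsqueeze(1)
    # Mask the rows where top_k is disabled (top_k == 0), if any.
    top_k_mask = None
    if n_top_k < len(top_k):
        top_k_mask = torch.tensor(
            [x == 0 for x in top_k], dtype=torch.bool, device=device
        ).unsqueeze(1)
    return max_top_k, top_k_index, top_k_mask

Storing k-1 rather than k means the value can be used directly as a column index into the sorted torch.topk output, which is what the second hunk relies on.
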
@@ -231,9 +239,17 @@ class VectorizedNextTokenChooser:
             scores.div_(self.temperature)
         if self.top_k is not None:
-            top_k = min(self.top_k, scores.size(-1)) # Safety check
+            if scores.size(-1)<self.max_top_k: # Safety check
+                max_top_k=scores.size(-1)
+                top_k=torch.clamp_max(self.top_k,max_top_k-1) # Clamp only if needed.
+            else:
+                max_top_k=self.max_top_k
+                top_k=self.top_k
+            kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+            if self.top_k_mask is not None:
+                kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
             # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
+            indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
             # Compute logprobs
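
The second hunk applies the filter with one batched torch.topk call at max_top_k, then gathers each row's own k-th largest score as its cutoff. A sketch of the same step under the assumptions above (apply_top_k and FILTER_VALUE are hypothetical stand-ins for the method body and self.filter_value):

import torch

FILTER_VALUE = -float("inf")  # hypothetical stand-in for self.filter_value

def apply_top_k(scores, max_top_k, top_k_index, top_k_mask):
    # Hypothetical free-function version of the filtering step above.
    # Clamp if some requested k exceeds the vocabulary size.
    if scores.size(-1) < max_top_k:
        max_top_k = scores.size(-1)
        top_k_index = torch.clamp_max(top_k_index, max_top_k - 1)
    # One topk at the batch-wide maximum k; row i, column top_k[i]-1 holds
    # the k-th largest score of row i, read out with a single gather.
    kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k_index)
    if top_k_mask is not None:
        # Rows with top_k disabled get a -inf cutoff: nothing is removed.
        kth_scores.masked_fill_(top_k_mask, FILTER_VALUE)
    # Drop every token scoring below its row's k-th largest score.
    return scores.masked_fill(scores < kth_scores, FILTER_VALUE)

if __name__ == "__main__":
    scores = torch.randn(3, 8)                     # batch of 3, vocab of 8
    top_k_index = torch.tensor([[1], [0], [0]])    # k=2, k=1, and disabled
    top_k_mask = torch.tensor([[False], [False], [True]])
    print(apply_top_k(scores, 2, top_k_index, top_k_mask))

Because masked rows compare against -inf, no token in them scores below the cutoff, so top_k is effectively a no-op for those sequences while the rest of the batch is filtered in one vectorized pass.
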