Top p and typical p

This commit is contained in:
Joel Lamy-Poirier 2023-05-03 14:55:08 -04:00
parent cc929530c2
commit 4554a69b22
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF

View File

@ -166,6 +166,7 @@ class VectorizedNextTokenChooser:
device="cpu",
):
self.batch_size=batch_size
self.filter_value = -float("Inf")
do_sample=self._standardize(do_sample, False)
@ -191,7 +192,7 @@ class VectorizedNextTokenChooser:
if n_top_k>0:
do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
self.max_top_k=max(top_k)
self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
if n_top_k<self.batch_size:
self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
else:
@ -201,14 +202,19 @@ class VectorizedNextTokenChooser:
self.top_k=None
self.top_k_mask=None
top_p=self._standardize(top_p, 1.0)
if any([x<1.0 for x in top_p]):
raise NotImplementedError("Top P not implemented")
do_sample=[sample or x<1.0 for x, sample in zip(temperature, top_p)]
self.top_p_inv=torch.tensor([1.0-x for x in top_p], dtype=torch.float32, device=device).unsqueeze(1)
else:
self.top_p_inv=None
typical_p=self._standardize(typical_p, 1.0)
if any([x<1.0 for x in typical_p]):
raise NotImplementedError("Typical P not implemented")
do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
self.typical_p=torch.tensor(typical_p, dtype=torch.float32, device=device).unsqueeze(1)
else:
self.typical_p=None
self.do_sample = any(do_sample)
if self.do_sample and not all(do_sample):
@ -252,6 +258,33 @@ class VectorizedNextTokenChooser:
indices_to_remove = scores < kth_scores
scores = scores.masked_fill(indices_to_remove, self.filter_value)
if self.top_p_inv is not None:
# TODO: Merge wit top_k
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= self.top_p_inv
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
if self.typical_p is not None:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.typical_p).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
# Compute logprobs
logprobs = torch.log_softmax(scores, dim=-1)