Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Top p and typical p

commit 4554a69b22 (parent cc929530c2)
@@ -166,6 +166,7 @@ class VectorizedNextTokenChooser:
         device="cpu",
     ):
         self.batch_size=batch_size
+        self.filter_value = -float("Inf")
 
         do_sample=self._standardize(do_sample, False)
 
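The added `self.filter_value = -float("Inf")` gives the chooser a shared sentinel for masked-out logits: a logit of -inf has exactly zero probability after softmax, so a filtered token can never be sampled. A minimal sketch with toy values (not from the repo):

import torch

# Toy sketch: -inf logits become probability 0 after softmax, so tokens
# masked with filter_value are excluded from sampling entirely.
filter_value = -float("Inf")
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
mask = torch.tensor([[False, False, True, True]])  # tokens to filter out
masked = scores.masked_fill(mask, filter_value)
print(torch.softmax(masked, dim=-1))  # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])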
@@ -191,7 +192,7 @@ class VectorizedNextTokenChooser:
         if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
             self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
             if n_top_k<self.batch_size:
                 self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
             else:
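The only behavioral change in this hunk is the dtype of `self.top_k`: it stores `k - 1` per sequence and is later used to gather each row's kth-largest score (the `indices_to_remove = scores < kth_scores` context in the last hunk below), and `torch.gather` requires an int64 index tensor, so `float32` would fail at runtime. A sketch of that gather pattern with hypothetical toy values (the surrounding method body sits outside this hunk):

import torch

# Hypothetical sketch of the per-row top-k cutoff this dtype change enables:
# torch.gather only accepts integer indices, hence int64 rather than float32.
scores = torch.tensor([[0.1, 2.0, 1.0, 3.0],
                       [1.5, 0.2, 0.3, 0.1]])
top_k = torch.tensor([[1], [2]], dtype=torch.int64)  # per-sequence k - 1
sorted_scores, _ = torch.sort(scores, descending=True)
kth_scores = sorted_scores.gather(1, top_k)          # kth-largest score per row
indices_to_remove = scores < kth_scores
print(scores.masked_fill(indices_to_remove, -float("Inf")))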
@@ -201,14 +202,19 @@ class VectorizedNextTokenChooser:
             self.top_k=None
             self.top_k_mask=None
 
-
         top_p=self._standardize(top_p, 1.0)
         if any([x<1.0 for x in top_p]):
-            raise NotImplementedError("Top P not implemented")
+            do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
+            self.top_p_inv=torch.tensor([1.0-x for x in top_p], dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.top_p_inv=None
 
         typical_p=self._standardize(typical_p, 1.0)
         if any([x<1.0 for x in typical_p]):
-            raise NotImplementedError("Typical P not implemented")
+            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
+            self.typical_p=torch.tensor(typical_p, dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.typical_p=None
 
         self.do_sample = any(do_sample)
         if self.do_sample and not all(do_sample):
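Where these branches previously raised `NotImplementedError`, the constructor now flips `do_sample` on for any sequence whose `top_p` or `typical_p` is below 1.0 and precomputes the thresholds as `[batch_size, 1]` column tensors (`top_p_inv` holds `1 - top_p`), so the comparisons in the filtering code below broadcast across the vocabulary dimension. A small shape sketch under assumed toy values:

import torch

# Shape sketch with assumed values: 3 sequences, 5-token vocabulary.
# A [batch_size, 1] column of thresholds broadcasts against [batch_size, vocab].
top_p = [0.9, 1.0, 0.5]
top_p_inv = torch.tensor([1.0 - x for x in top_p], dtype=torch.float32).unsqueeze(1)
sorted_probs = torch.full((3, 5), 0.2)            # stand-in for sorted softmax output
cumulative_probs = sorted_probs.cumsum(dim=-1)    # [0.2, 0.4, ..., 1.0] per row
print(cumulative_probs <= top_p_inv)
# row 0 (top_p=0.9): nothing removed; row 2 (top_p=0.5): first two entries removed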
@@ -252,6 +258,33 @@ class VectorizedNextTokenChooser:
             indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
+        if self.top_p_inv is not None:
+            # TODO: Merge with top_k
+            sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+            sorted_indices_to_remove = cumulative_probs <= self.top_p_inv
+            # scatter sorted tensors to original indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
+        if self.typical_p is not None:
+            # calculate entropy
+            normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+            p = torch.exp(normalized)
+            ent = -(normalized * p).nansum(-1, keepdim=True)
+            # shift and sort
+            shifted_scores = torch.abs((-normalized) - ent)
+            sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+            sorted_logits = scores.gather(-1, sorted_indices)
+            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            # Remove tokens with cumulative mass above the threshold
+            last_ind = (cumulative_probs < self.typical_p).sum(dim=1)
+            last_ind[last_ind < 0] = 0
+            sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)
 
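The two added blocks closely follow the `TopPLogitsWarper` and `TypicalLogitsWarper` logic from `transformers`, vectorized over the batch: top-p sorts ascending and drops the low-probability tail whose cumulative mass stays at or below `1 - top_p`, while typical-p keeps the tokens whose information content `-log p(x)` is closest to the distribution's entropy until `typical_p` of the mass is covered. A self-contained toy run of both filters (the values are assumptions, not a repo test):

import torch

# One batch row, five tokens, top_p = 0.9 and typical_p = 0.9.
scores = torch.tensor([[4.0, 3.0, 1.0, 0.5, 0.1]])
filter_value = -float("Inf")

# Top P: sort ascending, drop the tail whose cumulative mass stays below 1 - top_p.
top_p_inv = torch.tensor([[1.0 - 0.9]])
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= top_p_inv
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)

# Typical P: keep tokens whose -log p(x) is nearest the entropy until
# typical_p mass is covered.
typical_p = torch.tensor([[0.9]])
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
last_ind = (cumulative_probs < typical_p).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)

print(scores)  # -inf entries are excluded from sampling

The `nansum` in the entropy step matters: positions already masked to -inf contribute `0 * (-inf) = nan` to the sum, and `nansum` ignores them where a plain `sum` would poison the result.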