mirror of https://github.com/huggingface/text-generation-inference.git

Commit 4554a69b22: Top p and typical p
Parent: cc929530c2
@@ -166,6 +166,7 @@ class VectorizedNextTokenChooser:
         device="cpu",
     ):
         self.batch_size=batch_size
+        self.filter_value = -float("Inf")
 
         do_sample=self._standardize(do_sample, False)
 
@@ -191,7 +192,7 @@ class VectorizedNextTokenChooser:
         if n_top_k>0:
             do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
             self.max_top_k=max(top_k)
-            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.float32, device=device).unsqueeze(1)
+            self.top_k=torch.tensor([max(x-1,0) for x in top_k], dtype=torch.int64, device=device).unsqueeze(1)
             if n_top_k<self.batch_size:
                 self.top_k_mask=torch.tensor([x==0 for x in top_k], dtype=torch.bool, device=device)
             else:
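The one-line change in this hunk swaps the dtype of self.top_k from float32 to int64. This matters because the tensor holds per-sequence positions (k-1 for each sequence) that are later used as an index into the sorted scores, and torch.gather only accepts integer index tensors. A minimal sketch of that lookup pattern, with illustrative shapes and names rather than the repository's exact code (it matches the indices_to_remove = scores < kth_scores context line in the last hunk below):

import torch

# Illustrative sketch, assuming self.top_k is used as a gather index:
# each row keeps only the scores at or above its k-th largest value.
scores = torch.randn(3, 10)                               # 3 sequences, vocab of 10
top_k = torch.tensor([[4], [1], [9]], dtype=torch.int64)  # k-1 per sequence

sorted_scores, _ = torch.sort(scores, descending=True, dim=-1)
kth_scores = sorted_scores.gather(1, top_k)               # k-th best score per row
indices_to_remove = scores < kth_scores                   # mask everything below it
scores = scores.masked_fill(indices_to_remove, -float("Inf"))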
@@ -201,14 +202,19 @@ class VectorizedNextTokenChooser:
             self.top_k=None
             self.top_k_mask=None
 
-
         top_p=self._standardize(top_p, 1.0)
         if any([x<1.0 for x in top_p]):
-            raise NotImplementedError("Top P not implemented")
+            do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
+            self.top_p_inv=torch.tensor([1.0-x for x in top_p], dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.top_p_inv=None
 
         typical_p=self._standardize(typical_p, 1.0)
         if any([x<1.0 for x in typical_p]):
-            raise NotImplementedError("Typical P not implemented")
+            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
+            self.typical_p=torch.tensor(typical_p, dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.typical_p=None
 
         self.do_sample = any(do_sample)
         if self.do_sample and not all(do_sample):
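Both new branches follow the same vectorization pattern: the per-sequence parameter is broadcast to one value per batch entry, any sequence whose value departs from the default forces sampling for that sequence, and the batched values are stored as a column tensor so they broadcast against the (batch, vocab) score matrix. _standardize itself is not part of this diff; a plausible sketch of its behavior, given as an assumption for illustration only:

# Hypothetical helper; the real VectorizedNextTokenChooser._standardize
# is not shown in this diff and may differ.
def _standardize(values, default, batch_size):
    # Broadcast a missing or scalar parameter to one value per sequence.
    if values is None:
        return [default] * batch_size
    if not isinstance(values, (list, tuple)):
        return [values] * batch_size
    assert len(values) == batch_size, "expected one value per sequence"
    return list(values)

Storing 1.0 - top_p rather than top_p also pre-computes the threshold used by the filtering code below, which removes the low-probability tail of mass at most 1 - top_p.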
@@ -252,6 +258,33 @@ class VectorizedNextTokenChooser:
             indices_to_remove = scores < kth_scores
             scores = scores.masked_fill(indices_to_remove, self.filter_value)
 
+        if self.top_p_inv is not None:
+            # TODO: Merge with top_k
+            sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            # Remove the ascending-order tail whose cumulative mass stays at or
+            # below 1 - top_p, keeping a nucleus of mass >= top_p
+            sorted_indices_to_remove = cumulative_probs <= self.top_p_inv
+            # Scatter the sorted mask back to the original vocabulary indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
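A quick numeric check of the inverted threshold on a toy distribution (values are illustrative only):

import torch

# With probabilities (0.5, 0.3, 0.15, 0.05) and top_p = 0.9, only the
# ascending-order tokens whose cumulative mass stays <= 0.1 are dropped,
# leaving a nucleus of mass 0.95.
probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
scores = probs.log()
top_p_inv = 1.0 - 0.9

sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= top_p_inv
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(indices_to_remove)  # tensor([[False, False, False,  True]])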
+
+        if self.typical_p is not None:
+            # Calculate the entropy of the distribution
+            normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+            p = torch.exp(normalized)
+            ent = -(normalized * p).nansum(-1, keepdim=True)
+            # Shift scores by their distance to the entropy, then sort by typicality
+            shifted_scores = torch.abs((-normalized) - ent)
+            sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+            sorted_logits = scores.gather(-1, sorted_indices)
+            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            # Remove tokens once the cumulative mass of the most typical tokens exceeds the threshold
+            last_ind = (cumulative_probs < self.typical_p).sum(dim=1)
+            last_ind[last_ind < 0] = 0
+            sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
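Typical decoding keeps the tokens whose information content, -log p(x), lies closest to the distribution's entropy, accumulating them until their mass reaches typical_p. A toy run of the block above under illustrative values:

import torch

# Toy check: tokens are ranked by |(-log p) - entropy| and kept until
# their cumulative mass reaches typical_p.
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
typical_p = torch.tensor([[0.8]])

normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
shifted_scores = torch.abs((-normalized) - ent)        # distance to the entropy
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
last_ind = (cumulative_probs < typical_p).sum(dim=1)   # index of the last kept token
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(scores.masked_fill(indices_to_remove, -float("Inf")))
# Keeps the two most typical tokens (indices 0 and 1), masking the rest.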
+
         # Compute logprobs
         logprobs = torch.log_softmax(scores, dim=-1)