clean dtype

This commit is contained in:
OlivierDehaene 2023-05-12 15:53:56 +02:00
parent e7826855a3
commit b9ad3acc4e
2 changed files with 1 additions and 2 deletions

View File

@ -197,7 +197,6 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
def __init__( def __init__(
self, self,
top_k: List[int], top_k: List[int],
dtype: torch.dtype,
device: torch.device, device: torch.device,
filter_value: float = -math.inf, filter_value: float = -math.inf,
min_tokens_to_keep: int = 1, min_tokens_to_keep: int = 1,

View File

@ -187,7 +187,7 @@ class HeterogeneousNextTokenChooser:
if any([x != 0 for x in top_k]): if any([x != 0 for x in top_k]):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device)) warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]): if any([x < 1.0 for x in top_p]):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]