mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
clean dtype
This commit is contained in:
parent
e7826855a3
commit
b9ad3acc4e
@ -197,7 +197,6 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: List[int],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
filter_value: float = -math.inf,
|
||||
min_tokens_to_keep: int = 1,
|
||||
|
@ -187,7 +187,7 @@ class HeterogeneousNextTokenChooser:
|
||||
|
||||
if any([x != 0 for x in top_k]):
|
||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device))
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||
|
||||
if any([x < 1.0 for x in top_p]):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||
|
Loading…
Reference in New Issue
Block a user