mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
clean dtype
This commit is contained in:
parent
e7826855a3
commit
b9ad3acc4e
@ -197,7 +197,6 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
top_k: List[int],
|
top_k: List[int],
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
filter_value: float = -math.inf,
|
filter_value: float = -math.inf,
|
||||||
min_tokens_to_keep: int = 1,
|
min_tokens_to_keep: int = 1,
|
||||||
|
@ -187,7 +187,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
if any([x != 0 for x in top_k]):
|
if any([x != 0 for x in top_k]):
|
||||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device))
|
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||||
|
|
||||||
if any([x < 1.0 for x in top_p]):
|
if any([x < 1.0 for x in top_p]):
|
||||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||||
|
Loading…
Reference in New Issue
Block a user