Hotfix gaudi2 with newer transformers. (#3176)

This commit is contained in:
Nicolas Patry 2025-04-15 12:39:28 +02:00 committed by GitHub
parent ad765cd06b
commit 4645678ff0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import torch
import habana_frameworks.torch.core as htcore
from loguru import logger
from typing import Dict, Union
from typing import Dict
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict
import time
from transformers import (
LogitsWarper,
LogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
class HeterogeneousTemperatureLogitsWarper:
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper:
return None
class HeterogeneousTopPLogitsWarper(LogitsWarper):
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
This version allows for a separate value for each sample and runs inplace when possible.
@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
return None
class HeterogeneousTopKLogitsWarper(LogitsWarper):
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
[`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs.
@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
r"""
A wrapper for logit warpers or processors without heterogeneous parameter support.
Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
processors (`Dict[int, LogitsProcessor]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially.
"""
def __init__(
self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
processors: Dict[int, LogitsProcessor],
):
self.processors = processors