Hotfix gaudi2 with newer transformers. (#3176)

This commit is contained in:
Nicolas Patry 2025-04-15 12:39:28 +02:00 committed by GitHub
parent ad765cd06b
commit 4645678ff0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import torch
import habana_frameworks.torch.core as htcore import habana_frameworks.torch.core as htcore
from loguru import logger from loguru import logger
from typing import Dict, Union from typing import Dict
from text_generation_server.pb.generate_pb2 import GrammarType from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict
import time import time
from transformers import ( from transformers import (
LogitsWarper,
LogitsProcessor, LogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
class HeterogeneousTemperatureLogitsWarper: class HeterogeneousTemperatureLogitsWarper:
r""" r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution). [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper:
return None return None
class HeterogeneousTopPLogitsWarper(LogitsWarper): class HeterogeneousTopPLogitsWarper(LogitsProcessor):
""" """
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTopKLogitsWarper(LogitsWarper): class HeterogeneousTopKLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper): class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information. Generation](https://arxiv.org/abs/2202.00666) for more information.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
r""" r"""
A wrapper for logit warpers or processors without heterogeneous parameter support. A wrapper for logit warpers or processors without heterogeneous parameter support.
Args: Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): processors (`Dict[int, LogitsProcessor]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially. A mapping of sample indices to logit warpers or processors, to be run sequentially.
""" """
def __init__( def __init__(
self, self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], processors: Dict[int, LogitsProcessor],
): ):
self.processors = processors self.processors = processors