Mirror of https://github.com/huggingface/text-generation-inference.git
Hotfix gaudi2 with newer transformers.
parent ad765cd06b
commit cedb5f07c0

Newer transformers releases removed the LogitsWarper base class, so the Gaudi server's heterogeneous warpers now subclass and reference LogitsProcessor directly. The diff below drops the LogitsWarper import and updates the affected class definitions, docstrings, and type hints.
@@ -3,7 +3,7 @@ import torch
 import habana_frameworks.torch.core as htcore

 from loguru import logger
-from typing import Dict, Union
+from typing import Dict
 from text_generation_server.pb.generate_pb2 import GrammarType

 from outlines.fsm.fsm import RegexFSM
@@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict
 import time

 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
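
The import hunk above is the root of the fix: recent transformers releases removed the LogitsWarper base class, leaving LogitsProcessor as the single interface (callable as processor(input_ids, scores)). For code that must still run against older transformers, a small import guard is one option. This is a hedged sketch of such a guard, not something this commit adds:

try:
    # Older transformers still ship the separate warper base class.
    from transformers import LogitsWarper
except ImportError:
    # Newer transformers: warpers are plain LogitsProcessor subclasses.
    from transformers import LogitsProcessor as LogitsWarper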
@@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):

 class HeterogeneousTemperatureLogitsWarper:
     r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.

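
This class only needs its docstring retargeted because it never subclassed LogitsWarper. Functionally, heterogeneous temperature warping is an in-place row-wise division of the logits. A minimal, self-contained sketch of that idea (illustrative names, not the repository's implementation):

import torch

class PerSampleTemperature:
    """Sketch: one temperature per batch row, applied in place."""

    def __init__(self, temperature, dtype=torch.float32, device="cpu"):
        # Column vector so division broadcasts across the vocab dimension.
        self.temperature = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.div_(self.temperature)  # in-place, no extra allocation
        return scores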
@@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper:
         return None


-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
     This version allows for a separate value for each sample and runs inplace when possible.
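
Top-p (nucleus) filtering keeps the smallest set of highest-probability tokens whose combined mass reaches the cutoff. A standalone sketch of the standard algorithm this class's docstring describes (not the file's actual body):

import torch

def top_p_filter(scores: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort ascending so the cumulative sum walks the low-probability tail first.
    sorted_logits, sorted_idx = torch.sort(scores, descending=False)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Tokens whose cumulative mass stays at or below 1 - top_p are dropped.
    sorted_remove = cum_probs <= (1 - top_p)
    remove = sorted_remove.scatter(1, sorted_idx, sorted_remove)
    return scores.masked_fill(remove, -float("inf"))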
@@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+    [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.

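
Top-k keeps only the k highest-scoring tokens per row. A compact sketch of the operation (illustrative, not the repository code):

import torch

def top_k_filter(scores: torch.Tensor, top_k: int) -> torch.Tensor:
    top_k = min(top_k, scores.size(-1))  # guard against k > vocab size
    # The k-th largest logit in each row is the admission threshold.
    kth = torch.topk(scores, top_k).values[..., -1, None]
    return scores.masked_fill(scores < kth, -float("inf"))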
@@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
         return None


-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
+    [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
     Generation](https://arxiv.org/abs/2202.00666) for more information.
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.
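
Typical decoding keeps the tokens whose information content (-log p) lies closest to the entropy of the next-token distribution, accumulating them until the target mass is covered. A self-contained sketch of that procedure (illustrative, not the file's body):

import torch

def typical_filter(scores: torch.Tensor, mass: float = 0.9) -> torch.Tensor:
    log_probs = torch.log_softmax(scores, dim=-1)
    # Entropy of each row's next-token distribution.
    entropy = -(log_probs.exp() * log_probs).nansum(-1, keepdim=True)
    # How far each token's information content sits from the entropy.
    shifted = torch.abs(-log_probs - entropy)
    sorted_shifted, sorted_idx = torch.sort(shifted, descending=False)
    cum_probs = scores.gather(-1, sorted_idx).softmax(dim=-1).cumsum(dim=-1)
    # Keep the most typical tokens until the requested mass is reached.
    last = (cum_probs < mass).sum(dim=1, keepdim=True).clamp(max=scores.size(-1) - 1)
    sorted_remove = sorted_shifted > sorted_shifted.gather(1, last)
    remove = sorted_remove.scatter(1, sorted_idx, sorted_remove)
    return scores.masked_fill(remove, -float("inf"))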
@@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.
     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """

     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors

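The wrapper that closes the diff dispatches each batch row to its own processor, which is why its mapping type can narrow to Dict[int, LogitsProcessor] once LogitsWarper is gone. A sketch of that dispatch pattern (illustrative, not the class's full body):

import torch

class ProcessorWrapperSketch:
    def __init__(self, processors):
        # Maps a sample index to the processor responsible for that row.
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for i, processor in self.processors.items():
            # Slice with i:i+1 to keep the 2D shapes the processors expect.
            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
        return scores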