mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Hotfix gaudi2 with newer transformers.
This commit is contained in:
parent
ad765cd06b
commit
cedb5f07c0
@ -3,7 +3,7 @@ import torch
|
|||||||
import habana_frameworks.torch.core as htcore
|
import habana_frameworks.torch.core as htcore
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict, Union
|
from typing import Dict
|
||||||
from text_generation_server.pb.generate_pb2 import GrammarType
|
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||||
|
|
||||||
from outlines.fsm.fsm import RegexFSM
|
from outlines.fsm.fsm import RegexFSM
|
||||||
@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LogitsWarper,
|
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
class HeterogeneousTemperatureLogitsWarper:
|
class HeterogeneousTemperatureLogitsWarper:
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
[`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
It doesn't validate inputs.
|
It doesn't validate inputs.
|
||||||
|
|
||||||
@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
[`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
It doesn't validate inputs.
|
It doesn't validate inputs.
|
||||||
|
|
||||||
@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
[`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||||
This version allows for a separate value for each sample and runs inplace when possible.
|
This version allows for a separate value for each sample and runs inplace when possible.
|
||||||
It doesn't validate inputs.
|
It doesn't validate inputs.
|
||||||
@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
|
|||||||
r"""
|
r"""
|
||||||
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
A wrapper for logit warpers or processors without heterogeneous parameter support.
|
||||||
Args:
|
Args:
|
||||||
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
|
processors (`Dict[int, LogitsProcessor]`):
|
||||||
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
A mapping of sample indices to logit warpers or processors, to be run sequentially.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
|
processors: Dict[int, LogitsProcessor],
|
||||||
):
|
):
|
||||||
self.processors = processors
|
self.processors = processors
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user