diff --git a/backends/gaudi/server/text_generation_server/utils/logits_process.py b/backends/gaudi/server/text_generation_server/utils/logits_process.py index 472f2dcb0..c0fd6cbae 100644 --- a/backends/gaudi/server/text_generation_server/utils/logits_process.py +++ b/backends/gaudi/server/text_generation_server/utils/logits_process.py @@ -3,7 +3,7 @@ import torch import habana_frameworks.torch.core as htcore from loguru import logger -from typing import Dict, Union +from typing import Dict from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.fsm import RegexFSM @@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict import time from transformers import ( - LogitsWarper, LogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, @@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): class HeterogeneousTemperatureLogitsWarper: r""" - [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + [`LogitsProcessor`] for temperature (exponential scaling output probability distribution). This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. @@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper: return None -class HeterogeneousTopPLogitsWarper(LogitsWarper): +class HeterogeneousTopPLogitsWarper(LogitsProcessor): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. This version allows for a separate value for each sample and runs inplace when possible. @@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): return None -class HeterogeneousTopKLogitsWarper(LogitsWarper): +class HeterogeneousTopKLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. @@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper): return None -class HeterogeneousTypicalLogitsWarper(LogitsWarper): +class HeterogeneousTypicalLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. This version allows for a separate value for each sample and runs inplace when possible. It doesn't validate inputs. @@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): r""" A wrapper for logit warpers or processors without heterogeneous parameter support. Args: - processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): + processors (`Dict[int, LogitsProcessor]`): A mapping of sample indices to logit warpers or processors, to be run sequentially. """ def __init__( self, - processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], + processors: Dict[int, LogitsProcessor], ): self.processors = processors