From b163aef8edd4ef05e14e843d263d13c23c2de6b9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 12 Apr 2023 11:31:55 +0200
Subject: [PATCH] fmt

---
 server/text_generation_server/models/model.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 145e205a..08a48553 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -10,7 +10,12 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device, decode_buffer: int = 3):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
 
@@ -29,10 +34,10 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-            self,
-            all_input_ids: List[int],
-            offset: Optional[int] = None,
-            token_offset: Optional[int] = None,
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
     ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         if all_input_ids[-1] in self.all_special_ids:
@@ -57,7 +62,9 @@ class Model(ABC):
             sequence_text = raw_texts[1]
         else:
             # Only decode the last token without using a token buffer
-            sequence_text = self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
+            sequence_text = self.tokenizer.decode(
+                all_input_ids[-1], skip_special_tokens=False
+            )
             # no offset in this case
             offset = 0
         else:
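
For context, a minimal self-contained sketch of the buffered-decode trick that
decode_token relies on: decode the last decode_buffer ids both with and without
the newest one and emit the suffix difference, so that tokenizers whose
detokenization depends on neighboring tokens still stream correct text. The
"gpt2" tokenizer and the replay loop below are illustrative assumptions, not
part of this patch or of the repository's API.

# Sketch only: replays a fixed id sequence as if tokens were streaming in.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
DECODE_BUFFER = 3  # mirrors the decode_buffer parameter in the patch

all_input_ids = tokenizer("Hello world, this is a test").input_ids
streamed = ""
for step in range(1, len(all_input_ids) + 1):
    ids = all_input_ids[:step]
    if len(ids) <= DECODE_BUFFER:
        # Not enough history yet: decode everything seen so far and
        # emit whatever text is new relative to what was streamed.
        text = tokenizer.decode(ids, skip_special_tokens=False)
        new_text = text[len(streamed):]
    else:
        # Decode the last DECODE_BUFFER ids with and without the newest
        # one; the suffix difference is the text the new token added.
        prev = tokenizer.decode(ids[-DECODE_BUFFER:-1], skip_special_tokens=False)
        curr = tokenizer.decode(ids[-DECODE_BUFFER:], skip_special_tokens=False)
        new_text = curr[len(prev):]
    streamed += new_text

print(streamed)  # matches tokenizer.decode(all_input_ids) for this input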