commit b163aef8ed
parent 2aa5004482
Author: OlivierDehaene
Date:   2023-04-12 11:31:55 +02:00

@@ -10,7 +10,12 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device, decode_buffer: int = 3):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
@@ -29,10 +34,10 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-            self,
-            all_input_ids: List[int],
-            offset: Optional[int] = None,
-            token_offset: Optional[int] = None,
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
     ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         if all_input_ids[-1] in self.all_special_ids:
@@ -57,7 +62,9 @@
                 sequence_text = raw_texts[1]
             else:
                 # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
                 # no offset in this case
                 offset = 0
         else:
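
For background on why a decode buffer is needed at all: many tokenizers detokenize context-dependently, so decoding a single new token in isolation can produce different text than decoding it together with its neighbours. A streaming server therefore re-decodes a small trailing window of tokens and emits only the characters that have not been sent yet. Below is a minimal, self-contained sketch of that idea; the stream_decode helper, its prefix_offset/read_offset bookkeeping, and the GPT-2 tokenizer are illustrative assumptions for this note, not the exact decode_token implementation shown in this diff.

# Sketch of incremental detokenization with a sliding token window, in the
# spirit of Model.decode_token above. Hypothetical helper, not this repo's code.
from typing import List, Tuple

from transformers import AutoTokenizer


def stream_decode(
    tokenizer,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the already-emitted tail alone, then the tail plus the new
    # tokens; the difference between the two strings is the fresh text.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Complete characters were produced: emit the suffix and slide forward.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Incomplete UTF-8 sequence (or no new text yet): hold until more tokens.
    return "", prefix_offset, read_offset


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer("Streaming decode should match a full decode.")["input_ids"]
    prefix_offset, read_offset, streamed = 0, 0, ""
    for i in range(1, len(ids) + 1):  # feed tokens one at a time, as in generation
        piece, prefix_offset, read_offset = stream_decode(
            tokenizer, ids[:i], prefix_offset, read_offset
        )
        streamed += piece
    assert streamed == tokenizer.decode(ids, skip_special_tokens=False)

The \ufffd guard matters for byte-level BPE tokenizers, where a multi-byte character can be split across several token IDs; holding output back until the character completes is what keeps the streamed text identical to a one-shot decode.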