mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
commit b163aef8ed ("fmt")
parent 2aa5004482
@@ -10,7 +10,12 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device, decode_buffer: int = 3):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
         if decode_buffer < 1:
             raise ValueError("decode_buffer must be >= 1")
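For illustration only, a self-contained sketch of the constructor contract in the hunk above; the _StubModel class is a hypothetical stand-in that mirrors just the signature and validation, not the repository's actual Model class.

import torch


class _StubModel:
    # Hypothetical stand-in mirroring the keyword-per-line signature above.
    def __init__(
        self,
        tokenizer,
        device: torch.device,
        decode_buffer: int = 3,
    ):
        if decode_buffer < 1:
            raise ValueError("decode_buffer must be >= 1")
        self.tokenizer = tokenizer
        self.device = device
        self.decode_buffer = decode_buffer


_StubModel(tokenizer=None, device=torch.device("cpu"))  # ok, default buffer of 3
try:
    _StubModel(tokenizer=None, device=torch.device("cpu"), decode_buffer=0)
except ValueError as err:
    print(err)  # decode_buffer must be >= 1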
@@ -29,10 +34,10 @@ class Model(ABC):
         raise NotImplementedError
 
     def decode_token(
-        self,
-        all_input_ids: List[int],
-        offset: Optional[int] = None,
-        token_offset: Optional[int] = None,
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
     ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
         if all_input_ids[-1] in self.all_special_ids:
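As a rough sketch of how the three values returned by decode_token are meant to be threaded through a streaming loop; the _EchoModel stand-in and the token ids below are assumptions made for this example, not code from the commit.

from typing import List, Optional, Tuple


class _EchoModel:
    # Stand-in with the same signature as decode_token above; it renders each
    # new token id as "<id>" instead of doing real detokenization.
    def decode_token(
        self,
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        return f"<{all_input_ids[-1]}>", offset, token_offset


model = _EchoModel()
all_input_ids: List[int] = [101]  # assumed prompt token id
offset: Optional[int] = None
token_offset: Optional[int] = None

for next_id in (7, 8, 9):  # assumed generated token ids
    all_input_ids.append(next_id)
    text, offset, token_offset = model.decode_token(
        all_input_ids, offset=offset, token_offset=token_offset
    )
    print(text, end="")  # prints <7><8><9> over the three iterations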
@@ -57,7 +62,9 @@ class Model(ABC):
                 sequence_text = raw_texts[1]
             else:
                 # Only decode the last token without using a token buffer
-                sequence_text = self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
                 # no offset in this case
                 offset = 0
         else:
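The hunk above only reformats the single-token fallback; the comment it touches refers to the token-buffer path, where a short window of recent ids is decoded with and without the newest token and only the difference is emitted, so tokenizers that cannot cleanly decode a lone token (for example SentencePiece models that drop a leading space) still stream correctly. A hedged, self-contained sketch of that idea, using GPT-2 purely as an example tokenizer and a hypothetical incremental_decode helper:

from typing import List

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def incremental_decode(all_input_ids: List[int], decode_buffer: int = 3) -> str:
    # Hypothetical helper returning only the text the newest token contributed.
    if len(all_input_ids) <= decode_buffer:
        # Mirrors the fallback in the hunk above: decode just the last token.
        return tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
    previous = tokenizer.decode(
        all_input_ids[-decode_buffer:-1], skip_special_tokens=False
    )
    current = tokenizer.decode(
        all_input_ids[-decode_buffer:], skip_special_tokens=False
    )
    # Keep only the suffix that the newest token added.
    return current[len(previous):]


ids = tokenizer("Streaming decode example").input_ids
print("".join(incremental_decode(ids[: i + 1]) for i in range(len(ids))))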