mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
fmt
This commit is contained in:
parent
2aa5004482
commit
b163aef8ed
@ -10,7 +10,12 @@ B = TypeVar("B", bound=Batch)
|
|||||||
|
|
||||||
|
|
||||||
class Model(ABC):
|
class Model(ABC):
|
||||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device, decode_buffer: int = 3):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
device: torch.device,
|
||||||
|
decode_buffer: int = 3,
|
||||||
|
):
|
||||||
if decode_buffer < 1:
|
if decode_buffer < 1:
|
||||||
raise ValueError("decode_buffer must be >= 1")
|
raise ValueError("decode_buffer must be >= 1")
|
||||||
|
|
||||||
@ -57,7 +62,9 @@ class Model(ABC):
|
|||||||
sequence_text = raw_texts[1]
|
sequence_text = raw_texts[1]
|
||||||
else:
|
else:
|
||||||
# Only decode the last token without using a token buffer
|
# Only decode the last token without using a token buffer
|
||||||
sequence_text = self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False)
|
sequence_text = self.tokenizer.decode(
|
||||||
|
all_input_ids[-1], skip_special_tokens=False
|
||||||
|
)
|
||||||
# no offset in this case
|
# no offset in this case
|
||||||
offset = 0
|
offset = 0
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user