import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText

# B is the concrete Batch subclass that a given model implementation works with.
B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The Batch subclass this model consumes and produces."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one generation step on `batch`: return the texts that finished this
        step and the batch to keep decoding (None once every request is done)."""
        raise NotImplementedError

    def decode_token(self, previous_token_id: int, token_id: int) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # Decode previous token and previous token + token
        results = self.tokenizer.batch_decode(
            [[previous_token_id], [previous_token_id, token_id]],
            skip_special_tokens=False,
        )

        # slice to remove previous token
        return results[1][len(results[0]) :]
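

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It shows why
# `decode_token` decodes the previous token twice: many tokenizers
# (SentencePiece-style ones in particular) attach leading whitespace to a
# token, so decoding a token id in isolation can drop the space that should
# precede it in the full text. Decoding [previous] and [previous, current]
# and slicing off the shared prefix recovers exactly the text the new token
# contributes. `_stream_text_example` is a hypothetical helper, not part of
# the server; it only illustrates how a streaming caller could use the class.
def _stream_text_example(
    model: Model, first_token_id: int, next_token_ids: List[int]
) -> str:
    # Rebuild the generated text incrementally, the way a streaming client would.
    text = model.tokenizer.decode([first_token_id], skip_special_tokens=False)
    previous_token_id = first_token_id
    for token_id in next_token_ids:
        # Each call returns only the characters contributed by `token_id`.
        text += model.decode_token(previous_token_id, token_id)
        previous_token_id = token_id
    return text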
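

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a hypothetical concrete
# subclass showing how the abstract contract of `Model` is meant to be filled
# in. `_SketchBatch` and `_SketchModel` are assumed placeholder names, not real
# classes in this repository; the actual generation logic is elided.
class _SketchBatch(Batch):
    """Placeholder batch type, used purely for illustration."""


class _SketchModel(Model):
    @property
    def batch_type(self) -> Type[_SketchBatch]:
        # Tells the serving layer which Batch subclass this model consumes.
        return _SketchBatch

    def generate_token(
        self, batch: _SketchBatch
    ) -> Tuple[List[GeneratedText], Optional[_SketchBatch]]:
        # One decoding step: run the model forward, collect the generations that
        # finished, and return the batch still in flight (None when all are done).
        raise NotImplementedError("illustrative sketch only")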