text-generation-inference/server/text_generation/models/model.py
Nick Hill e6d3eb5d5d
fix(server): Minor refactorization using new_zeros (#24)
- Fix some type hints, in particular base tokenizer class
- Make use of `tensor.new_zero/empty` methods
- Simplify env var string parsing in launcher
2023-01-17 09:10:22 +01:00

25 lines
659 B
Python

import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.device = device
@property
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError