text-generation-inference/server/text_generation/models/model.py

import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from tokenizers import Tokenizer

from text_generation.models.types import Batch, GeneratedText

# Each concrete model is paired with a matching Batch subclass.
B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract base class shared by all text generation backends."""

    def __init__(self, tokenizer: Tokenizer, device: torch.device):
        self.tokenizer = tokenizer
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete Batch subclass this model consumes."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        """Run one decoding step: return the generations that finished on
        this step, plus the batch of still-running requests (or None once
        every request has completed)."""
        raise NotImplementedError
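
Since this module only defines the abstract interface, each backend supplies a concrete Batch/Model pair. Below is a minimal sketch of such a subclass; EchoBatch and EchoModel are hypothetical names used purely for illustration, and the bodies are stubbed because the real decoding logic lives in the concrete model modules.

from typing import List, Optional, Tuple, Type

from text_generation.models.model import Model
from text_generation.models.types import Batch, GeneratedText


class EchoBatch(Batch):
    """Hypothetical Batch subclass; a real one would hold input ids,
    attention masks, and per-request generation state."""


class EchoModel(Model):
    """Hypothetical Model subclass showing the two required overrides."""

    @property
    def batch_type(self) -> Type[EchoBatch]:
        # Tells generic server code which Batch subclass to build.
        return EchoBatch

    def generate_token(
        self, batch: EchoBatch
    ) -> Tuple[List[GeneratedText], Optional[EchoBatch]]:
        # One decoding step: run a forward pass, build a GeneratedText for
        # every request that finished on this step, and return the pruned
        # batch (or None once every request has completed).
        raise NotImplementedError

Tying the batch class to the model through the batch_type property lets the surrounding server code stay generic over backends: it can construct, concatenate, and prune batches without knowing which model implementation it is driving.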