Fix typing in Model.generate_token

This commit is contained in:
Jae-Won Chung 2023-07-28 17:23:41 -04:00
parent 3ef5ffbc64
commit 5a465fa40a

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@ -52,7 +52,7 @@ class Model(ABC):
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]: