diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 29bad321..90fbd635 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -1,7 +1,7 @@ import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from typing import Generic, List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, GeneratedText @@ -10,7 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse B = TypeVar("B", bound=Batch) -class Model(ABC): +class Model(ABC, Generic[B]): def __init__( self, model: torch.nn.Module,