Make B a generic type of Model

This PR fixes typing error at `def batch_type`
This commit is contained in:
Yang, Bo 2023-05-18 07:29:20 -07:00 committed by GitHub
parent 5a58226130
commit 27d30f685a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from typing import Generic, List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
@ -10,7 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
class Model(ABC):
class Model(ABC, Generic[B]):
def __init__(
self,
model: torch.nn.Module,