Make B a generic type of Model

This PR fixes typing error at `def batch_type`
This commit is contained in:
Yang, Bo 2023-05-18 07:29:20 -07:00 committed by GitHub
parent 5a58226130
commit 27d30f685a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import Generic, List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, GeneratedText
@ -10,7 +10,7 @@ from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC, Generic[B]):
def __init__( def __init__(
self, self,
model: torch.nn.Module, model: torch.nn.Module,