fix galactica

This commit is contained in:
OlivierDehaene 2023-04-11 18:59:13 +02:00
parent aafec48ff3
commit 5632fc5bad

View File

@@ -18,9 +18,10 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.opt import OPT
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
@@ -184,7 +185,7 @@ class Galactica(OPT):
return outputs.logits, outputs.past_key_values
class GalacticaSharded(OPTSharded):
class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):