diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 58daee0b..396cc4f6 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -18,9 +18,10 @@ from transformers.models.opt.parallel_layers import ( TensorParallelRowLinear, ) +from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.models.opt import OPT, OPTSharded +from text_generation_server.models.opt import OPT from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, @@ -184,7 +185,7 @@ class Galactica(OPT): return outputs.logits, outputs.past_key_values -class GalacticaSharded(OPTSharded): +class GalacticaSharded(Galactica): def __init__( self, model_id: str, revision: Optional[str] = None, quantize: bool = False ):