fix galactica

This commit is contained in:
OlivierDehaene 2023-04-11 18:59:13 +02:00
parent aafec48ff3
commit 5632fc5bad

View File

@ -18,9 +18,10 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.opt import OPT
from text_generation_server.utils import ( from text_generation_server.utils import (
NextTokenChooser, NextTokenChooser,
StoppingCriteria, StoppingCriteria,
@ -184,7 +185,7 @@ class Galactica(OPT):
return outputs.logits, outputs.past_key_values return outputs.logits, outputs.past_key_values
class GalacticaSharded(OPTSharded): class GalacticaSharded(Galactica):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):