mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix galactica
This commit is contained in:
parent
aafec48ff3
commit
5632fc5bad
@ -18,9 +18,10 @@ from transformers.models.opt.parallel_layers import (
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.opt import OPT, OPTSharded
|
||||
from text_generation_server.models.opt import OPT
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
@ -184,7 +185,7 @@ class Galactica(OPT):
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
|
||||
class GalacticaSharded(OPTSharded):
|
||||
class GalacticaSharded(Galactica):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user