mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix galactica
This commit is contained in:
parent
aafec48ff3
commit
5632fc5bad
@ -18,9 +18,10 @@ from transformers.models.opt.parallel_layers import (
|
|||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from text_generation_server.models import CausalLM
|
||||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models.opt import OPT, OPTSharded
|
from text_generation_server.models.opt import OPT
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
NextTokenChooser,
|
NextTokenChooser,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
@ -184,7 +185,7 @@ class Galactica(OPT):
|
|||||||
return outputs.logits, outputs.past_key_values
|
return outputs.logits, outputs.past_key_values
|
||||||
|
|
||||||
|
|
||||||
class GalacticaSharded(OPTSharded):
|
class GalacticaSharded(Galactica):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user