diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index aebff738..dc2cb21c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -174,12 +174,12 @@ def get_model( model_type = config_dict["model_type"] if model_type == "gpt_bigcode": - return StarCoder(model_id, revision, dtype) + return StarCoder(model_id=model_id, revision=revision, dtype=dtype) if model_type == "bloom": return BLOOM( - model_id, - revision, + model_id=model_id, + revision=revision, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 4a04abed..999c0bb6 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -623,7 +623,7 @@ class CausalLM(Model): def __init__( self, model_id: str, - model_class, + model_class: Optional[Type[torch.nn.Module]] = None, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None,