diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 908b144c..386b7dc9 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -41,7 +41,7 @@ torch.set_grad_enabled(False) def get_model( model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: - if model_id.startswith("facebook/galactica"): + if "facebook/galactica" in model_id: if sharded: return GalacticaSharded(model_id, revision, quantize=quantize) else: diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 08c3ac94..83a0d63e 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM): def __init__( self, model_id: str, revision: Optional[str] = None, quantize: bool = False ): - if not model_id.startswith("bigscience/bloom"): - raise ValueError(f"Model {model_id} is not supported") - self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 if torch.cuda.is_available(): diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index f3a76459..9a71c5d3 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -164,9 +164,6 @@ class GalacticaSharded(Galactica): def __init__( self, model_id: str, revision: Optional[str] = None, quantize: bool = False ): - if not model_id.startswith("facebook/galactica"): - raise ValueError(f"Model {model_id} is not supported") - self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 if torch.cuda.is_available():