remove unused name checks

This commit is contained in:
OlivierDehaene 2023-03-06 13:48:07 +01:00
parent 02df3dea9d
commit 590fc3794c
3 changed files with 1 additions and 7 deletions

View File

@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if model_id.startswith("facebook/galactica"):
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:

View File

@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
if not model_id.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():

View File

@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
if not model_id.startswith("facebook/galactica"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():