mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
remove unused name checks
This commit is contained in:
parent
02df3dea9d
commit
590fc3794c
@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
|
||||
def get_model(
|
||||
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
|
||||
) -> Model:
|
||||
if model_id.startswith("facebook/galactica"):
|
||||
if "facebook/galactica" in model_id:
|
||||
if sharded:
|
||||
return GalacticaSharded(model_id, revision, quantize=quantize)
|
||||
else:
|
||||
|
@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
if not model_id.startswith("bigscience/bloom"):
|
||||
raise ValueError(f"Model {model_id} is not supported")
|
||||
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
if not model_id.startswith("facebook/galactica"):
|
||||
raise ValueError(f"Model {model_id} is not supported")
|
||||
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
Loading…
Reference in New Issue
Block a user