diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 44915ff5..4fc4c389 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -37,7 +37,7 @@ class FlashRW(FlashCausalLM):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             raise NotImplementedError("RW is only available on GPU")
 
@@ -124,7 +124,7 @@ class FlashRWSharded(FlashRW):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index dd389027..2b1e4959 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -16,7 +16,7 @@ class RW(CausalLM):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 030c8289..134ac7cd 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -23,7 +23,11 @@ def weight_hub_files(
     """Get the weights filenames on the hub"""
     api = HfApi()
     info = api.model_info(model_id, revision=revision)
-    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
+    filenames = [
+        s.rfilename
+        for s in info.siblings
+        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+    ]
 
     if not filenames:
         raise EntryNotFoundError(
@@ -130,24 +134,31 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename):
+    def download_file(filename, tries=5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
             return Path(local_file)
 
-        logger.info(f"Download file: {filename}")
-        start_time = time.time()
-        local_file = hf_hub_download(
-            filename=filename,
-            repo_id=model_id,
-            revision=revision,
-            local_files_only=False,
-        )
-        logger.info(
-            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
-        )
-        return Path(local_file)
+        for i in range(tries):
+            try:
+                logger.info(f"Download file: {filename}")
+                start_time = time.time()
+                local_file = hf_hub_download(
+                    filename=filename,
+                    repo_id=model_id,
+                    revision=revision,
+                    local_files_only=False,
+                )
+                logger.info(
+                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+                )
+                return Path(local_file)
+            except Exception as e:
+                if i + 1 == tries:
+                    raise e
+                logger.error(e)
+                logger.info(f"Retry {i + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher
     start_time = time.time()
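
A note on the dtype change: the RW model constructors now default to `torch.float16` on CUDA rather than `torch.bfloat16`. If a deployment wanted to keep bfloat16 only where the hardware supports it, a guard along these lines could work; this is an illustrative sketch, not part of the patch, and `pick_dtype` is a hypothetical helper:

```python
import torch

def pick_dtype():
    # Hypothetical helper, not in this patch: prefer bfloat16 only on
    # hardware that reports support for it, otherwise fall back to
    # float16, which is what this patch now uses unconditionally.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
```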
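The new list comprehension in `weight_hub_files` keeps only files at the repo root, since `len(s.rfilename.split("/")) == 1` is false for anything inside a subdirectory. A small self-contained sketch of the same predicate; the sample filenames are made up for illustration:

```python
# Illustrative only: the top-level filter applied to hypothetical repo filenames.
filenames = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "nested/checkpoint.safetensors",  # lives in a subdirectory -> excluded
    "config.json",                    # wrong extension -> excluded
]

extension = ".safetensors"
kept = [
    f
    for f in filenames
    if f.endswith(extension) and len(f.split("/")) == 1
]
print(kept)  # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```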
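`download_file` now attempts the download up to `tries` times, logging each failure and re-raising the last exception once the attempts are exhausted. A generic, self-contained sketch of the same retry shape; `with_retries` is a hypothetical name, not something the patch introduces:

```python
import logging

logger = logging.getLogger(__name__)

def with_retries(operation, tries=5):
    # Same control flow as the patched download_file: attempt up to
    # `tries` times, log intermediate failures, re-raise the last one.
    for i in range(tries):
        try:
            return operation()
        except Exception as e:
            if i + 1 == tries:
                raise
            logger.error(e)
            logger.info(f"Retry {i + 1}/{tries - 1}")
```

Note that there is no backoff between attempts, so transient network errors are retried immediately; with `tries=5`, a persistent failure surfaces after five attempts (four retries).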