This commit is contained in:
OlivierDehaene 2023-06-05 15:34:54 +02:00
parent 89c5621ecf
commit e09314a72f
2 changed files with 7 additions and 3 deletions

View File

@ -116,7 +116,9 @@ class FlashCausalLMBatch(Batch):
max_length = 0 max_length = 0
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)): for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i

View File

@ -134,7 +134,7 @@ def download_weights(
) -> List[Path]: ) -> List[Path]:
"""Download the safetensors files from the hub""" """Download the safetensors files from the hub"""
def download_file(filename, tries=5): def download_file(filename, tries=5, backoff: int = 5):
local_file = try_to_load_from_cache(model_id, revision, filename) local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None: if local_file is not None:
logger.info(f"File {filename} already present in cache.") logger.info(f"File {filename} already present in cache.")
@ -158,6 +158,8 @@ def download_weights(
if i + 1 == tries: if i + 1 == tries:
raise e raise e
logger.error(e) logger.error(e)
logger.info(f"Retrying in {backoff} seconds")
time.sleep(backoff)
logger.info(f"Retry {i + 1}/{tries - 1}") logger.info(f"Retry {i + 1}/{tries - 1}")
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher