diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 225133f5..a2ad2d5e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -116,11 +116,13 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
+        for i, (r, tokenized_input) in enumerate(
+            zip(pb.requests, batch_tokenized_inputs)
+        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate:]
+            tokenized_input = tokenized_input[-r.truncate :]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 134ac7cd..2ed7673c 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -134,7 +134,7 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename, tries=5):
+    def download_file(filename, tries=5, backoff: int = 5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
@@ -158,6 +158,8 @@ def download_weights(
             if i + 1 == tries:
                 raise e
             logger.error(e)
+            logger.info(f"Retrying in {backoff} seconds")
+            time.sleep(backoff)
             logger.info(f"Retry {i + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher