black
commit e09314a72f
parent 89c5621ecf
@@ -116,11 +116,13 @@ class FlashCausalLMBatch(Batch):
         max_length = 0

         # Parse batch
-        for i, (r, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
+        for i, (r, tokenized_input) in enumerate(
+            zip(pb.requests, batch_tokenized_inputs)
+        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i

-            tokenized_input = tokenized_input[-r.truncate:]
+            tokenized_input = tokenized_input[-r.truncate :]

             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
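Note: this first hunk is pure black formatting. The long enumerate(zip(...)) call is wrapped to fit the line-length limit, and the slice gains a space before the colon because black (roughly following PEP 8's slice-spacing rule) pads the colon when a bound is a non-trivial expression such as the unary minus in -r.truncate. A minimal illustration with placeholder names, not code from the repo:

    tail = tokens[n:]               # simple name bound: black leaves the slice untouched
    tail = tokens[-req.truncate :]  # expression bound: black adds a space before the colon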
@@ -134,7 +134,7 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""

-    def download_file(filename, tries=5):
+    def download_file(filename, tries=5, backoff: int = 5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
@@ -158,6 +158,8 @@ def download_weights(
                 if i + 1 == tries:
                     raise e
                 logger.error(e)
+                logger.info(f"Retrying in {backoff} seconds")
+                time.sleep(backoff)
                 logger.info(f"Retry {i + 1}/{tries - 1}")

     # We do this instead of using tqdm because we want to parse the logs with the launcher
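Taken together, the two download_weights hunks give download_file a fixed-delay retry: each failed attempt is logged, the thread sleeps for backoff seconds, and the download is retried until tries attempts are exhausted, at which point the last exception is re-raised. A rough, self-contained sketch of that pattern (the actual hub download call is replaced here by a hypothetical do_download stub, and the logger setup is assumed, not taken from the repo):

    import logging
    import time
    from pathlib import Path

    logger = logging.getLogger(__name__)


    def do_download(filename: str) -> str:
        # Hypothetical stand-in for the real hub download call.
        raise IOError(f"transient failure while fetching {filename}")


    def download_file(filename, tries=5, backoff: int = 5) -> Path:
        # Retry loop with a fixed backoff between attempts.
        for i in range(tries):
            try:
                return Path(do_download(filename))
            except Exception as e:
                if i + 1 == tries:
                    # Out of attempts: surface the last error instead of sleeping again.
                    raise e
                logger.error(e)
                logger.info(f"Retrying in {backoff} seconds")
                time.sleep(backoff)
                logger.info(f"Retry {i + 1}/{tries - 1}")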