From 52c9ff9aca1dab15e1b319159469ffbdf0ce33fd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 17 May 2024 14:20:58 +0000 Subject: [PATCH] Optional base_name_or_model_path. --- server/text_generation_server/cli.py | 58 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index ba45916c..497c8f50 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -171,22 +171,23 @@ def download_weights( with open(config, "r") as f: config = json.load(f) - model_id = config["base_model_name_or_path"] - revision = "main" - try: - utils.weight_files(model_id, revision, extension) - logger.info( - f"Files for parent {model_id} are already present on the host. " - "Skipping download." - ) - return - # Local files not found - except ( - utils.LocalEntryNotFoundError, - FileNotFoundError, - utils.EntryNotFoundError, - ): - pass + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id: + revision = "main" + try: + utils.weight_files(base_model_id, revision, extension) + logger.info( + f"Files for parent {base_model_id} are already present on the host. " + "Skipping download." + ) + return + # Local files not found + except ( + utils.LocalEntryNotFoundError, + FileNotFoundError, + utils.EntryNotFoundError, + ): + pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass @@ -222,18 +223,19 @@ def download_weights( with open(config, "r") as f: config = json.load(f) - model_id = config["base_model_name_or_path"] - revision = "main" - try: - utils.weight_files(model_id, revision, extension) - logger.info( - f"Files for parent {model_id} are already present on the host. " - "Skipping download." - ) - return - # Local files not found - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass + base_model_id = config.get("base_model_name_or_path", None) + if base_model_id: + revision = "main" + try: + utils.weight_files(base_model_id, revision, extension) + logger.info( + f"Files for parent {base_model_id} are already present on the host. " + "Skipping download." + ) + return + # Local files not found + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass