diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index d5ba73f1..92482a94 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -91,14 +91,12 @@ def download_weights( except (utils.LocalEntryNotFoundError, FileNotFoundError): pass - is_local_model = ( - Path(model_id).exists() - and Path(model_id).is_dir() - or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - ) + is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( + "WEIGHTS_CACHE_OVERRIDE", None + ) is not None if not is_local_model: - # Download weights from the hub + # Try to download weights from the hub try: filenames = utils.weight_hub_files(model_id, revision, extension) utils.download_weights(filenames, model_id, revision) @@ -115,11 +113,13 @@ def download_weights( try: # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE local_pt_files = utils.weight_files(model_id, revision, ".bin") + + # No local pytorch weights except utils.LocalEntryNotFoundError: if extension == ".safetensors": logger.warning( f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Converting PyTorch weights instead." + f"Downloading PyTorch weights." ) # Try to see if there are pytorch weights on the hub @@ -128,6 +128,11 @@ def download_weights( local_pt_files = utils.download_weights(pt_filenames, model_id, revision) if auto_convert: + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Converting PyTorch weights to safetensors." + ) + # Safetensors final filenames local_st_files = [ p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"