diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 736aa5a7..41211d8a 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -512,7 +512,11 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(
+    args: &Args,
+    auto_convert: bool,
+    running: Arc<AtomicBool>,
+) -> Result<(), LauncherError> {
     let mut download_argv = vec![
         "text-generation-server".to_string(),
         "download-weights".to_string(),
@@ -524,6 +528,11 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
         "--json-output".to_string(),
     ];
 
+    // Auto convert weights to safetensors
+    if auto_convert {
+        download_argv.push("--auto-convert".to_string());
+    }
+
     // Model optional revision
     if let Some(revision) = &args.revision {
         download_argv.push("--revision".to_string());
@@ -855,14 +864,11 @@ fn main() -> Result<(), LauncherError> {
     })
     .expect("Error setting Ctrl-C handler");
 
-    // Check if model_id is a local model
-    let local_path = Path::new(&args.model_id);
-    let is_local_model = local_path.exists() && local_path.is_dir();
-
-    // Download weights for sharded models
-    if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
-        download_model(&args, running.clone())?;
-    }
+    // auto_convert is only needed for sharded models as we do not require safetensors in
+    // single shard mode
+    let auto_convert = num_shard > 1;
+    // Download and convert model weights
+    download_convert_model(&args, auto_convert, running.clone())?;
 
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
" "Skipping download.") return # Local files not found - except utils.LocalEntryNotFoundError: + except (utils.LocalEntryNotFoundError, FileNotFoundError): pass - # Download weights directly + is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( + "WEIGHTS_CACHE_OVERRIDE", None + ) is not None + + if not is_local_model: + # Try to download weights from the hub + try: + filenames = utils.weight_hub_files(model_id, revision, extension) + utils.download_weights(filenames, model_id, revision) + # Successfully downloaded weights + return + + # No weights found on the hub with this extension + except utils.EntryNotFoundError as e: + # Check if we want to automatically convert to safetensors or if we can use .bin weights instead + if not extension == ".safetensors" or not auto_convert: + raise e + + # Try to see if there are local pytorch weights try: - filenames = utils.weight_hub_files(model_id, revision, extension) - utils.download_weights(filenames, model_id, revision) - except utils.EntryNotFoundError as e: - if not extension == ".safetensors": - raise e + # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE + local_pt_files = utils.weight_files(model_id, revision, ".bin") - logger.warning( - f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Converting PyTorch weights instead." - ) + # No local pytorch weights + except utils.LocalEntryNotFoundError: + if extension == ".safetensors": + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Downloading PyTorch weights." + ) - # Try to see if there are pytorch weights + # Try to see if there are pytorch weights on the hub pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") # Download pytorch weights local_pt_files = utils.download_weights(pt_filenames, model_id, revision) + + if auto_convert: + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Converting PyTorch weights to safetensors." 
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index e43a4b79..f528a430 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -223,6 +223,15 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index e640113b..105ff519 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -139,6 +139,15 @@ class FlashLlama(FlashCausalLM):
 
             del value
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
@@ -300,5 +309,15 @@ class FlashLlamaSharded(FlashLlama):
             else:
                 module._buffers[param_name] = tensor
+
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index eae584ac..fc769583 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -149,4 +149,14 @@ class FlashNeoXSharded(FlashNeoX):
                 module._parameters[param_name] = tensor
             else:
                 module._buffers[param_name] = tensor
+
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 550be956..333180e8 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -372,5 +372,15 @@ class FlashSantacoderSharded(FlashSantacoder):
                 module._parameters[param_name] = tensor
             else:
                 module._buffers[param_name] = tensor
+
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 78e9bfe4..2577f1b1 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -355,6 +355,15 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 3b5fe2cc..e73a3c82 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -205,6 +205,15 @@ class GPTNeoxSharded(CausalLM):
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 1a21186f..50e5271e 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -210,6 +210,15 @@ class OPTSharded(OPT):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 487a5984..9e8c3c4c 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -211,6 +211,15 @@ class T5Sharded(Seq2SeqLM):
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self,
         input_ids,
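
The uninitialized-parameter check added to each sharded model above is identical in all eight files. If it were ever factored out, it could live in a shared helper along these lines (a sketch only, not part of this diff; the name `check_initialized` is hypothetical):

    import torch


    def check_initialized(model: torch.nn.Module):
        """Raise if any parameter was left on the meta device.

        A parameter still on "meta" was never materialized by the weight
        loading loop and would otherwise only surface as a confusing error
        at forward time.
        """
        uninitialized_parameters = []
        for n, p in model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )
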
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 4feec8a1..030c8289 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -77,7 +77,12 @@ def weight_files(
     """Get the local files"""
     # Local model
     if Path(model_id).exists() and Path(model_id).is_dir():
-        return list(Path(model_id).glob(f"*{extension}"))
+        local_files = list(Path(model_id).glob(f"*{extension}"))
+        if not local_files:
+            raise FileNotFoundError(
+                f"No local weights found in {model_id} with extension {extension}"
+            )
+        return local_files
 
     try:
         filenames = weight_hub_files(model_id, revision, extension)
@@ -98,7 +103,7 @@ def weight_files(
     for filename in filenames:
         p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
         if not p.exists():
-            raise LocalEntryNotFoundError(
+            raise FileNotFoundError(
                 f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
            )
        files.append(p)
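
With the hub.py change, `weight_files` pointed at a local directory that contains no files matching the extension now raises `FileNotFoundError` instead of returning an empty list, and a missing file under `WEIGHTS_CACHE_OVERRIDE` does the same. That is what the broadened `except (utils.LocalEntryNotFoundError, FileNotFoundError)` clause in cli.py above catches before falling through to the download-and-convert path. Roughly (the model path is hypothetical):

    from text_generation_server import utils

    try:
        files = utils.weight_files("/data/my-model", None, ".safetensors")
    except (utils.LocalEntryNotFoundError, FileNotFoundError):
        # No usable local safetensors weights; download_weights falls back to
        # .bin weights and, when auto_convert is set, converts them.
        files = None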