diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 3c8d9fcc..b5ef5592 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -12,7 +12,7 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -84,6 +84,120 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Download weights
+    if num_shard > 1 {
+        // Only download weights if in sharded mode
+        let mut download_argv = vec![
+            "text-generation-server".to_string(),
+            "download-weights".to_string(),
+            model_id.clone(),
+            "--logger-level".to_string(),
+            "INFO".to_string(),
+            "--json-output".to_string(),
+        ];
+        // Model optional revision
+        if let Some(revision) = revision.clone() {
+            download_argv.push("--revision".to_string());
+            download_argv.push(revision)
+        }
+
+        let mut env = Vec::new();
+
+        // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+        // Useful when running inside a docker container
+        if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
+            env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+        };
+
+        // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
+        // Useful when running inside a HuggingFace Inference Endpoint
+        if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
+            env.push((
+                "WEIGHTS_CACHE_OVERRIDE".into(),
+                weights_cache_override.into(),
+            ));
+        };
+
+        // Start process
+        tracing::info!("Starting download");
+        let mut download_process = match Popen::create(
+            &download_argv,
+            PopenConfig {
+                stdout: Redirection::Pipe,
+                stderr: Redirection::Pipe,
+                // Needed for the shutdown procedure
+                setpgid: true,
+                env: Some(env),
+                ..Default::default()
+            },
+        ) {
+            Ok(p) => p,
+            Err(err) => {
+                if let PopenError::IoError(ref err) = err {
+                    if err.kind() == io::ErrorKind::NotFound {
+                        tracing::error!("text-generation-server not found in PATH");
+                        tracing::error!("Please install it with `make install-server`")
+                    }
+                }
+                return ExitCode::FAILURE;
+            }
+        };
+
+        // Redirect STDOUT to the console
+        let download_stdout = download_process.stdout.take().unwrap();
+        thread::spawn(move || {
+            // Enter download tracing span
+            let stdout = BufReader::new(download_stdout);
+            let _span = tracing::span!(tracing::Level::INFO, "download").entered();
+            for line in stdout.lines() {
+                // Parse loguru logs
+                if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
+                    if let Some(text) = value.get("text") {
+                        // Format escaped newlines
+                        tracing::info!("{}", text.to_string().replace("\\n", ""));
+                    }
+                }
+            }
+        });
+
+        loop {
+            if let Some(status) = download_process.poll() {
+                match status {
+                    ExitStatus::Exited(exit_code) => {
+                        if exit_code == 0 {
+                            tracing::info!("Successfully downloaded weights.");
+                            break;
+                        } else {
+                            let mut err = String::new();
+                            download_process
+                                .stderr
+                                .take()
+                                .unwrap()
+                                .read_to_string(&mut err)
+                                .unwrap();
+                            tracing::error!("Download encountered an error: {err}");
+                            return ExitCode::FAILURE;
+                        }
+                    }
+                    _ => {
+                        tracing::error!("Download process exited with an unknown status.");
+                        return ExitCode::FAILURE;
+                    }
+                }
+            }
+            if !running.load(Ordering::SeqCst) {
+                download_process.terminate().unwrap();
+                tracing::info!("Waiting for download process to gracefully shutdown");
+                download_process
+                    .wait_timeout(Duration::from_secs(90))
+                    .unwrap();
tracing::info!("Download process terminated"); + return ExitCode::SUCCESS; + } + sleep(Duration::from_millis(100)); + } + } + // Shared shutdown bool let shutdown = Arc::new(Mutex::new(false)); // Shared shutdown channel diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py new file mode 100644 index 00000000..5f284be5 --- /dev/null +++ b/server/tests/utils/test_convert.py @@ -0,0 +1,17 @@ +from text_generation.utils.hub import download_weights, weight_hub_files, weight_files + +from text_generation.utils.convert import convert_files + + +def test_convert_files(): + model_id = "bigscience/bloom-560m" + pt_filenames = weight_hub_files(model_id, extension=".bin") + local_pt_files = download_weights(pt_filenames, model_id) + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files + ] + convert_files(local_pt_files, local_st_files) + + found_st_files = weight_files(model_id) + + assert all([p in found_st_files for p in local_st_files]) diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py new file mode 100644 index 00000000..b3120160 --- /dev/null +++ b/server/tests/utils/test_hub.py @@ -0,0 +1,40 @@ +import pytest + +from text_generation.utils.hub import ( + weight_hub_files, + download_weights, + weight_files, + EntryNotFoundError, + LocalEntryNotFoundError, + RevisionNotFoundError, +) + + +def test_weight_hub_files(): + filenames = weight_hub_files("bigscience/bloom-560m") + assert filenames == ["model.safetensors"] + + +def test_weight_hub_files_llm(): + filenames = weight_hub_files("bigscience/bloom") + assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] + + +def test_weight_hub_files_empty(): + with pytest.raises(EntryNotFoundError): + weight_hub_files("bigscience/bloom", extension=".errors") + + +def test_download_weights(): + model_id = "bigscience/bloom-560m" + filenames = weight_hub_files(model_id) + files = download_weights(filenames, model_id) + local_files = weight_files("bigscience/bloom-560m") + assert files == local_files + + +def test_weight_files_error(): + with pytest.raises(RevisionNotFoundError): + weight_files("bigscience/bloom-560m", revision="error") + with pytest.raises(LocalEntryNotFoundError): + weight_files("bert-base-uncased") diff --git a/server/tests/test_utils.py b/server/tests/utils/test_tokens.py similarity index 52% rename from server/tests/test_utils.py rename to server/tests/utils/test_tokens.py index ffe9be65..7eca482f 100644 --- a/server/tests/test_utils.py +++ b/server/tests/utils/test_tokens.py @@ -1,14 +1,6 @@ -import pytest - -from huggingface_hub.utils import RevisionNotFoundError - -from text_generation.utils import ( - weight_hub_files, - download_weights, - weight_files, +from text_generation.utils.tokens import ( StopSequenceCriteria, StoppingCriteria, - LocalEntryNotFoundError, FinishReason, ) @@ -41,31 +33,3 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) - - -def test_weight_hub_files(): - filenames = weight_hub_files("bigscience/bloom-560m") - assert filenames == ["model.safetensors"] - - -def test_weight_hub_files_llm(): - filenames = weight_hub_files("bigscience/bloom") - assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)] - - -def test_weight_hub_files_empty(): - filenames = weight_hub_files("bigscience/bloom", extension=".errors") - assert filenames 
-
-
-def test_download_weights():
-    files = download_weights("bigscience/bloom-560m")
-    local_files = weight_files("bigscience/bloom-560m")
-    assert files == local_files
-
-
-def test_weight_files_error():
-    with pytest.raises(RevisionNotFoundError):
-        weight_files("bigscience/bloom-560m", revision="error")
-    with pytest.raises(LocalEntryNotFoundError):
-        weight_files("bert-base-uncased")
diff --git a/server/text_generation/cli.py b/server/text_generation/cli.py
index 17f99d68..678dce16 100644
--- a/server/text_generation/cli.py
+++ b/server/text_generation/cli.py
@@ -60,18 +60,49 @@ def download_weights(
     model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
-    convert: bool = False,
+    logger_level: str = "INFO",
+    json_output: bool = False,
 ):
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+
+    # Test if files were already downloaded
+    try:
+        utils.weight_files(model_id, revision, extension)
+        logger.info(
+            "Files are already present in the local cache. " "Skipping download."
+        )
+        return
+    # Local files not found
+    except utils.LocalEntryNotFoundError:
+        pass
+
+    # Download weights directly
     try:
         filenames = utils.weight_hub_files(model_id, revision, extension)
-        utils.download_weights(model_id, revision, filenames)
+        utils.download_weights(filenames, model_id, revision)
     except utils.EntryNotFoundError as e:
-        if not convert or not extension == ".safetensors":
+        if not extension == ".safetensors":
             raise e
+
+        logger.warning(
+            f"No safetensors weights found for model {model_id} at revision {revision}. "
+            f"Converting PyTorch weights instead."
+        )
+
         # Try to see if there are pytorch weights
         pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
         # Download pytorch weights
-        local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
+        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
         local_st_files = [
             p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
             for p in local_pt_files
diff --git a/server/text_generation/utils/__init__.py b/server/text_generation/utils/__init__.py
index b7521b92..a390b710 100644
--- a/server/text_generation/utils/__init__.py
+++ b/server/text_generation/utils/__init__.py
@@ -6,12 +6,15 @@ from text_generation.utils.hub import (
     download_weights,
     EntryNotFoundError,
     LocalEntryNotFoundError,
+    RevisionNotFoundError,
 )
 from text_generation.utils.tokens import (
     Greedy,
     NextTokenChooser,
     Sampling,
     StoppingCriteria,
+    StopSequenceCriteria,
+    FinishReason,
 )
 
 __all__ = [
@@ -23,8 +26,11 @@ __all__ = [
     "download_weights",
     "EntryNotFoundError",
     "LocalEntryNotFoundError",
+    "RevisionNotFoundError",
     "Greedy",
     "NextTokenChooser",
     "Sampling",
     "StoppingCriteria",
+    "StopSequenceCriteria",
+    "FinishReason",
 ]
diff --git a/server/text_generation/utils/convert.py b/server/text_generation/utils/convert.py
index 3d429efa..e7f9660c 100644
--- a/server/text_generation/utils/convert.py
+++ b/server/text_generation/utils/convert.py
@@ -1,11 +1,13 @@
 import concurrent
+import time
 import torch
 
 from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
+from datetime import timedelta
+from loguru import logger
 from pathlib import Path
 from safetensors.torch import load_file, save_file
-from tqdm import tqdm
 from typing import Dict, List
 
 
@@ -79,7 +81,16 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
         executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
         for pt_file, st_file in zip(pt_files, st_files)
     ]
-    [
-        future.result()
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
-    ]
+
+    # We do this instead of using tqdm because we want to parse the logs with the launcher
+    logger.info("Converting weights...")
+    start_time = time.time()
+    for i, future in enumerate(concurrent.futures.as_completed(futures)):
+        elapsed = timedelta(seconds=int(time.time() - start_time))
+        remaining = len(futures) - (i + 1)
+        if remaining != 0:
+            eta = (elapsed / (i + 1)) * remaining
+        else:
+            eta = 0
+
+        logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py
index e7f1e518..60072a20 100644
--- a/server/text_generation/utils/hub.py
+++ b/server/text_generation/utils/hub.py
@@ -1,21 +1,26 @@
+import time
 import concurrent
 import os
 
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
+from datetime import timedelta
+from loguru import logger
 from pathlib import Path
-from tqdm import tqdm
 from typing import Optional, List
 
-from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
+from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
+from huggingface_hub.utils import (
+    LocalEntryNotFoundError,
+    EntryNotFoundError,
+    RevisionNotFoundError,  # Import here to ease try/except in other parts of the lib
+)
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
 
 
 def weight_hub_files(
-    model_id: str, revision: str = None, extension: str = ".safetensors"
+    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[str]:
     """Get the weights filenames on the hub"""
     api = HfApi()
@@ -32,7 +37,7 @@ def weight_hub_files(
 
 
 def try_to_load_from_cache(
-    model_id: str, revision: str, filename: str
+    model_id: str, revision: Optional[str], filename: str
 ) -> Optional[Path]:
     """Try to load a file from the Hugging Face cache"""
     if revision is None:
@@ -58,7 +63,7 @@ def try_to_load_from_cache(
 
     # Check if file is cached as "no_exist"
     if (no_exist_dir / revision / filename).is_file():
-        return _CACHED_NO_EXIST
+        return None
 
     # Check if revision folder exists
     if not snapshots_dir.exists():
@@ -73,7 +78,9 @@ def try_to_load_from_cache(
     return cached_file if cached_file.is_file() else None
 
 
-def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
+def weight_files(
+    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[Path]:
     """Get the local files"""
     try:
         filenames = weight_hub_files(model_id, revision, extension)
@@ -116,22 +123,47 @@ def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
     return files
 
 
-def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
+def download_weights(
+    filenames: List[str], model_id: str, revision: Optional[str] = None
+) -> List[Path]:
     """Download the safetensors files from the hub"""
-    download_function = partial(
-        hf_hub_download,
-        repo_id=model_id,
-        local_files_only=False,
-    )
+
+    def download_file(filename):
+        local_file = try_to_load_from_cache(model_id, revision, filename)
+        if local_file is not None:
+            logger.info(f"File {filename} already present in cache.")
+            return local_file
+
+        start_time = time.time()
+        local_file = hf_hub_download(
+            filename=filename,
+            repo_id=model_id,
+            revision=revision,
+            local_files_only=False,
+        )
+        logger.info(
+            f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+        )
+        return local_file
 
     executor = ThreadPoolExecutor(max_workers=5)
     futures = [
-        executor.submit(download_function, filename=filename, revision=revision)
-        for filename in filenames
-    ]
-    files = [
-        future.result()
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+        executor.submit(download_file, filename=filename) for filename in filenames
     ]
+
+    # We do this instead of using tqdm because we want to parse the logs with the launcher
+    logger.info("Downloading weights...")
+    start_time = time.time()
+    files = []
+    for i, future in enumerate(concurrent.futures.as_completed(futures)):
+        elapsed = timedelta(seconds=int(time.time() - start_time))
+        remaining = len(futures) - (i + 1)
+        if remaining != 0:
+            eta = (elapsed / (i + 1)) * remaining
+        else:
+            eta = 0
+
+        logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
+        files.append(Path(future.result()))
+
     return [Path(p) for p in files]
diff --git a/server/text_generation/utils/tokens.py b/server/text_generation/utils/tokens.py
index ef71ab81..cc0e6c35 100644
--- a/server/text_generation/utils/tokens.py
+++ b/server/text_generation/utils/tokens.py
@@ -12,6 +12,7 @@ from transformers import (
 from typing import List, Tuple, Optional
 
 from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason
 
 
 class Sampling:
@@ -115,15 +116,15 @@ class StoppingCriteria:
     def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
+            return True, FinishReason.FINISH_REASON_LENGTH
 
         if last_token == self.eos_token_id:
-            return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN
 
         self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(self.current_output):
-                return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
+                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
 
         return False, None
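
A note on the logging contract the launcher half of this diff relies on: `--json-output` makes the server serialize each loguru record as one JSON object per line, and the launcher's stdout thread re-logs the `text` field through `tracing`. A minimal sketch of the server side, assuming only the loguru configuration shown in `cli.py`:

```python
import sys

from loguru import logger

# Mirror the handler configured in cli.py: bare message format, JSON output
logger.remove()
logger.add(sys.stdout, format="{message}", serialize=True)

# Emits one JSON object per record, along the lines of:
#   {"text": "Downloading weights...\n", "record": {...}}
# The launcher parses each line and re-logs value["text"] via tracing,
# stripping the escaped newlines.
logger.info("Downloading weights...")
```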
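The progress loops added to `convert.py` and `hub.py` compute the same linear ETA: average time per completed future, multiplied by the futures still pending. A standalone sketch of that arithmetic (the `eta` helper is hypothetical; the diff inlines the computation):

```python
from datetime import timedelta


def eta(elapsed: timedelta, completed: int, total: int) -> timedelta:
    """Average time per completed item, scaled by the items left."""
    remaining = total - completed
    if remaining == 0:
        return timedelta(0)
    return (elapsed / completed) * remaining


# 3 of 8 shards done in 90s -> 30s per shard -> 150s for the last 5
assert eta(timedelta(seconds=90), 3, 8) == timedelta(seconds=150)
```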
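One caveat worth flagging in the `pytorch_*` to safetensors renaming used by `cli.py` and `test_convert.py`: `str.lstrip` strips a character set, not a prefix. It is safe here only because the character that follows the prefix, the 'm' of "model", is not in that set:

```python
from pathlib import Path

pt_file = Path("pytorch_model_00001-of-00072.bin")
# lstrip("pytorch_") removes leading chars drawn from {p,y,t,o,r,c,h,_};
# it stops at 'm' because 'm' is not in that set, so exactly the prefix
# (and nothing more) is removed for names starting with "model".
st_file = pt_file.parent / f"{pt_file.stem.lstrip('pytorch_')}.safetensors"
assert st_file.name == "model_00001-of-00072.safetensors"
```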
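Finally, `download_weights` changes its argument order in this diff: filenames first, then `model_id`, with `revision` optional and a cache lookup before any network call. A usage sketch mirroring the updated tests:

```python
from text_generation.utils.hub import (
    download_weights,
    weight_files,
    weight_hub_files,
)

model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)               # ["model.safetensors"]
local_files = download_weights(filenames, model_id)  # cache-aware, 5 threads
assert local_files == weight_files(model_id)
```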