let launcher download weights

This commit is contained in:
OlivierDehaene 2023-02-14 12:09:58 +01:00
parent 397a28080c
commit 97f9ae6a6d
9 changed files with 286 additions and 70 deletions

View File

@ -12,7 +12,7 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
@ -84,6 +84,120 @@ fn main() -> ExitCode {
})
.expect("Error setting Ctrl-C handler");
// Download weights
if num_shard > 1 {
// Only download weights if in sharded mode
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Model optional revision
if let Some(revision) = revision.clone() {
download_argv.push("--revision".to_string());
download_argv.push(revision)
}
let mut env = Vec::new();
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// Start process
tracing::info!("Starting download");
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unkown status.");
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel

View File

@ -0,0 +1,17 @@
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,40 @@
import pytest
from text_generation.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,14 +1,6 @@
import pytest
from huggingface_hub.utils import RevisionNotFoundError
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
from text_generation.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
FinishReason,
)
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -60,18 +60,49 @@ def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
convert: bool = False,
logger_level: str = "INFO",
json_output: bool = False,
):
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Test if files were already download
try:
utils.weight_files(model_id, revision, extension)
logger.info(
"Files are already present in the local cache. " "Skipping download."
)
return
# Local files not found
except utils.LocalEntryNotFoundError:
pass
# Download weights directly
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(model_id, revision, filenames)
utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError as e:
if not convert or not extension == ".safetensors":
if not extension == ".safetensors":
raise e
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
)
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files

View File

@ -6,12 +6,15 @@ from text_generation.utils.hub import (
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
__all__ = [
@ -23,8 +26,11 @@ __all__ = [
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
]

View File

@ -1,11 +1,13 @@
import concurrent
import time
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from typing import Dict, List
@ -79,7 +81,16 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
[
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Converting weights...")
start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -1,21 +1,26 @@
import time
import concurrent
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from datetime import timedelta
from loguru import logger
from pathlib import Path
from tqdm import tqdm
from typing import Optional, List
from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: str = None, extension: str = ".safetensors"
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
@ -32,7 +37,7 @@ def weight_hub_files(
def try_to_load_from_cache(
model_id: str, revision: str, filename: str
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
@ -58,7 +63,7 @@ def try_to_load_from_cache(
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
return None
# Check if revision folder exists
if not snapshots_dir.exists():
@ -73,7 +78,9 @@ def try_to_load_from_cache(
return cached_file if cached_file.is_file() else None
def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
try:
filenames = weight_hub_files(model_id, revision, extension)
@ -116,22 +123,47 @@ def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
return files
def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
def download_weights(
filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
"""Download the safetensors files from the hub"""
download_function = partial(
hf_hub_download,
def download_file(filename):
local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return local_file
start_time = time.time()
local_file = hf_hub_download(
filename=filename,
repo_id=model_id,
revision=revision,
local_files_only=False,
)
logger.info(
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return local_file
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Downloading weights...")
start_time = time.time()
files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
files.append(Path(future.result()))
return [Path(p) for p in files]

View File

@ -12,6 +12,7 @@ from transformers import (
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
class Sampling:
@ -115,15 +116,15 @@ class StoppingCriteria:
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None