mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
let launcher download weights
This commit is contained in:
parent
397a28080c
commit
97f9ae6a6d
@ -12,7 +12,7 @@ use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
|
||||
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
@ -84,6 +84,120 @@ fn main() -> ExitCode {
|
||||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Download weights
|
||||
if num_shard > 1 {
|
||||
// Only download weights if in sharded mode
|
||||
let mut download_argv = vec![
|
||||
"text-generation-server".to_string(),
|
||||
"download-weights".to_string(),
|
||||
model_id.clone(),
|
||||
"--logger-level".to_string(),
|
||||
"INFO".to_string(),
|
||||
"--json-output".to_string(),
|
||||
];
|
||||
// Model optional revision
|
||||
if let Some(revision) = revision.clone() {
|
||||
download_argv.push("--revision".to_string());
|
||||
download_argv.push(revision)
|
||||
}
|
||||
|
||||
let mut env = Vec::new();
|
||||
|
||||
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
|
||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
};
|
||||
|
||||
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
|
||||
// Useful when running inside a HuggingFace Inference Endpoint
|
||||
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
||||
env.push((
|
||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||
weights_cache_override.into(),
|
||||
));
|
||||
};
|
||||
|
||||
// Start process
|
||||
tracing::info!("Starting download");
|
||||
let mut download_process = match Popen::create(
|
||||
&download_argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
env: Some(env),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
if let PopenError::IoError(ref err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
}
|
||||
}
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
};
|
||||
|
||||
// Redirect STDOUT to the console
|
||||
let download_stdout = download_process.stdout.take().unwrap();
|
||||
thread::spawn(move || {
|
||||
// Enter download tracing span
|
||||
let stdout = BufReader::new(download_stdout);
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
for line in stdout.lines() {
|
||||
// Parse loguru logs
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
|
||||
if let Some(text) = value.get("text") {
|
||||
// Format escaped newlines
|
||||
tracing::info!("{}", text.to_string().replace("\\n", ""));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
if let Some(status) = download_process.poll() {
|
||||
match status {
|
||||
ExitStatus::Exited(exit_code) => {
|
||||
if exit_code == 0 {
|
||||
tracing::info!("Successfully downloaded weights.");
|
||||
break;
|
||||
} else {
|
||||
let mut err = String::new();
|
||||
download_process
|
||||
.stderr
|
||||
.take()
|
||||
.unwrap()
|
||||
.read_to_string(&mut err)
|
||||
.unwrap();
|
||||
tracing::error!("Download encountered an error: {err}");
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::error!("Download process exited with an unkown status.");
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
}
|
||||
}
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
download_process.terminate().unwrap();
|
||||
tracing::info!("Waiting for download process to gracefully shutdown");
|
||||
download_process
|
||||
.wait_timeout(Duration::from_secs(90))
|
||||
.unwrap();
|
||||
tracing::info!("Download process terminated");
|
||||
return ExitCode::SUCCESS;
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
}
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(Mutex::new(false));
|
||||
// Shared shutdown channel
|
||||
|
17
server/tests/utils/test_convert.py
Normal file
17
server/tests/utils/test_convert.py
Normal file
@ -0,0 +1,17 @@
|
||||
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
|
||||
|
||||
from text_generation.utils.convert import convert_files
|
||||
|
||||
|
||||
def test_convert_files():
|
||||
model_id = "bigscience/bloom-560m"
|
||||
pt_filenames = weight_hub_files(model_id, extension=".bin")
|
||||
local_pt_files = download_weights(pt_filenames, model_id)
|
||||
local_st_files = [
|
||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
|
||||
]
|
||||
convert_files(local_pt_files, local_st_files)
|
||||
|
||||
found_st_files = weight_files(model_id)
|
||||
|
||||
assert all([p in found_st_files for p in local_st_files])
|
40
server/tests/utils/test_hub.py
Normal file
40
server/tests/utils/test_hub.py
Normal file
@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
|
||||
from text_generation.utils.hub import (
|
||||
weight_hub_files,
|
||||
download_weights,
|
||||
weight_files,
|
||||
EntryNotFoundError,
|
||||
LocalEntryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||
assert filenames == ["model.safetensors"]
|
||||
|
||||
|
||||
def test_weight_hub_files_llm():
|
||||
filenames = weight_hub_files("bigscience/bloom")
|
||||
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
|
||||
|
||||
|
||||
def test_weight_hub_files_empty():
|
||||
with pytest.raises(EntryNotFoundError):
|
||||
weight_hub_files("bigscience/bloom", extension=".errors")
|
||||
|
||||
|
||||
def test_download_weights():
|
||||
model_id = "bigscience/bloom-560m"
|
||||
filenames = weight_hub_files(model_id)
|
||||
files = download_weights(filenames, model_id)
|
||||
local_files = weight_files("bigscience/bloom-560m")
|
||||
assert files == local_files
|
||||
|
||||
|
||||
def test_weight_files_error():
|
||||
with pytest.raises(RevisionNotFoundError):
|
||||
weight_files("bigscience/bloom-560m", revision="error")
|
||||
with pytest.raises(LocalEntryNotFoundError):
|
||||
weight_files("bert-base-uncased")
|
@ -1,14 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
|
||||
from text_generation.utils import (
|
||||
weight_hub_files,
|
||||
download_weights,
|
||||
weight_files,
|
||||
from text_generation.utils.tokens import (
|
||||
StopSequenceCriteria,
|
||||
StoppingCriteria,
|
||||
LocalEntryNotFoundError,
|
||||
FinishReason,
|
||||
)
|
||||
|
||||
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||
assert filenames == ["model.safetensors"]
|
||||
|
||||
|
||||
def test_weight_hub_files_llm():
|
||||
filenames = weight_hub_files("bigscience/bloom")
|
||||
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
|
||||
|
||||
|
||||
def test_weight_hub_files_empty():
|
||||
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
|
||||
assert filenames == []
|
||||
|
||||
|
||||
def test_download_weights():
|
||||
files = download_weights("bigscience/bloom-560m")
|
||||
local_files = weight_files("bigscience/bloom-560m")
|
||||
assert files == local_files
|
||||
|
||||
|
||||
def test_weight_files_error():
|
||||
with pytest.raises(RevisionNotFoundError):
|
||||
weight_files("bigscience/bloom-560m", revision="error")
|
||||
with pytest.raises(LocalEntryNotFoundError):
|
||||
weight_files("bert-base-uncased")
|
@ -60,18 +60,49 @@ def download_weights(
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
extension: str = ".safetensors",
|
||||
convert: bool = False,
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
format="{message}",
|
||||
filter="text_generation",
|
||||
level=logger_level,
|
||||
serialize=json_output,
|
||||
backtrace=True,
|
||||
diagnose=False,
|
||||
)
|
||||
|
||||
# Test if files were already download
|
||||
try:
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
logger.info(
|
||||
"Files are already present in the local cache. " "Skipping download."
|
||||
)
|
||||
return
|
||||
# Local files not found
|
||||
except utils.LocalEntryNotFoundError:
|
||||
pass
|
||||
|
||||
# Download weights directly
|
||||
try:
|
||||
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||
utils.download_weights(model_id, revision, filenames)
|
||||
utils.download_weights(filenames, model_id, revision)
|
||||
except utils.EntryNotFoundError as e:
|
||||
if not convert or not extension == ".safetensors":
|
||||
if not extension == ".safetensors":
|
||||
raise e
|
||||
|
||||
logger.warning(
|
||||
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||
f"Converting PyTorch weights instead."
|
||||
)
|
||||
|
||||
# Try to see if there are pytorch weights
|
||||
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
||||
# Download pytorch weights
|
||||
local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
|
||||
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||
local_st_files = [
|
||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||
for p in local_pt_files
|
||||
|
@ -6,12 +6,15 @@ from text_generation.utils.hub import (
|
||||
download_weights,
|
||||
EntryNotFoundError,
|
||||
LocalEntryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
from text_generation.utils.tokens import (
|
||||
Greedy,
|
||||
NextTokenChooser,
|
||||
Sampling,
|
||||
StoppingCriteria,
|
||||
StopSequenceCriteria,
|
||||
FinishReason,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -23,8 +26,11 @@ __all__ = [
|
||||
"download_weights",
|
||||
"EntryNotFoundError",
|
||||
"LocalEntryNotFoundError",
|
||||
"RevisionNotFoundError",
|
||||
"Greedy",
|
||||
"NextTokenChooser",
|
||||
"Sampling",
|
||||
"StoppingCriteria",
|
||||
"StopSequenceCriteria",
|
||||
"FinishReason",
|
||||
]
|
||||
|
@ -1,11 +1,13 @@
|
||||
import concurrent
|
||||
import time
|
||||
import torch
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@ -79,7 +81,16 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
|
||||
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
|
||||
for pt_file, st_file in zip(pt_files, st_files)
|
||||
]
|
||||
[
|
||||
future.result()
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
||||
]
|
||||
|
||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||
logger.info("Converting weights...")
|
||||
start_time = time.time()
|
||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||
remaining = len(futures) - (i + 1)
|
||||
if remaining != 0:
|
||||
eta = (elapsed / (i + 1)) * remaining
|
||||
else:
|
||||
eta = 0
|
||||
|
||||
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||
|
@ -1,21 +1,26 @@
|
||||
import time
|
||||
import concurrent
|
||||
import os
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, List
|
||||
|
||||
from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
|
||||
from huggingface_hub.utils import (
|
||||
LocalEntryNotFoundError,
|
||||
EntryNotFoundError,
|
||||
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
|
||||
)
|
||||
|
||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||
|
||||
|
||||
def weight_hub_files(
|
||||
model_id: str, revision: str = None, extension: str = ".safetensors"
|
||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||
) -> List[str]:
|
||||
"""Get the weights filenames on the hub"""
|
||||
api = HfApi()
|
||||
@ -32,7 +37,7 @@ def weight_hub_files(
|
||||
|
||||
|
||||
def try_to_load_from_cache(
|
||||
model_id: str, revision: str, filename: str
|
||||
model_id: str, revision: Optional[str], filename: str
|
||||
) -> Optional[Path]:
|
||||
"""Try to load a file from the Hugging Face cache"""
|
||||
if revision is None:
|
||||
@ -58,7 +63,7 @@ def try_to_load_from_cache(
|
||||
|
||||
# Check if file is cached as "no_exist"
|
||||
if (no_exist_dir / revision / filename).is_file():
|
||||
return _CACHED_NO_EXIST
|
||||
return None
|
||||
|
||||
# Check if revision folder exists
|
||||
if not snapshots_dir.exists():
|
||||
@ -73,7 +78,9 @@ def try_to_load_from_cache(
|
||||
return cached_file if cached_file.is_file() else None
|
||||
|
||||
|
||||
def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
|
||||
def weight_files(
|
||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||
) -> List[Path]:
|
||||
"""Get the local files"""
|
||||
try:
|
||||
filenames = weight_hub_files(model_id, revision, extension)
|
||||
@ -116,22 +123,47 @@ def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
|
||||
return files
|
||||
|
||||
|
||||
def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
|
||||
def download_weights(
|
||||
filenames: List[str], model_id: str, revision: Optional[str] = None
|
||||
) -> List[Path]:
|
||||
"""Download the safetensors files from the hub"""
|
||||
download_function = partial(
|
||||
hf_hub_download,
|
||||
repo_id=model_id,
|
||||
local_files_only=False,
|
||||
)
|
||||
|
||||
def download_file(filename):
|
||||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||
if local_file is not None:
|
||||
logger.info(f"File {filename} already present in cache.")
|
||||
return local_file
|
||||
|
||||
start_time = time.time()
|
||||
local_file = hf_hub_download(
|
||||
filename=filename,
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info(
|
||||
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||
)
|
||||
return local_file
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=5)
|
||||
futures = [
|
||||
executor.submit(download_function, filename=filename, revision=revision)
|
||||
for filename in filenames
|
||||
]
|
||||
files = [
|
||||
future.result()
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
||||
executor.submit(download_file, filename=filename) for filename in filenames
|
||||
]
|
||||
|
||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||
logger.info("Downloading weights...")
|
||||
start_time = time.time()
|
||||
files = []
|
||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||
remaining = len(futures) - (i + 1)
|
||||
if remaining != 0:
|
||||
eta = (elapsed / (i + 1)) * remaining
|
||||
else:
|
||||
eta = 0
|
||||
|
||||
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||
files.append(Path(future.result()))
|
||||
|
||||
return [Path(p) for p in files]
|
||||
|
@ -12,6 +12,7 @@ from transformers import (
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.pb.generate_pb2 import FinishReason
|
||||
|
||||
|
||||
class Sampling:
|
||||
@ -115,15 +116,15 @@ class StoppingCriteria:
|
||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
||||
self.current_tokens += 1
|
||||
if self.current_tokens >= self.max_new_tokens:
|
||||
return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
|
||||
return True, FinishReason.FINISH_REASON_LENGTH
|
||||
|
||||
if last_token == self.eos_token_id:
|
||||
return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
|
||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||
|
||||
self.current_output += last_output
|
||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
||||
if stop_sequence_criteria(self.current_output):
|
||||
return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
|
||||
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
|
||||
|
||||
return False, None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user