mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
let launcher download weights
This commit is contained in:
parent
397a28080c
commit
97f9ae6a6d
@ -12,7 +12,7 @@ use std::thread;
|
|||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{fs, io};
|
use std::{fs, io};
|
||||||
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
|
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -84,6 +84,120 @@ fn main() -> ExitCode {
|
|||||||
})
|
})
|
||||||
.expect("Error setting Ctrl-C handler");
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
|
// Download weights
|
||||||
|
if num_shard > 1 {
|
||||||
|
// Only download weights if in sharded mode
|
||||||
|
let mut download_argv = vec![
|
||||||
|
"text-generation-server".to_string(),
|
||||||
|
"download-weights".to_string(),
|
||||||
|
model_id.clone(),
|
||||||
|
"--logger-level".to_string(),
|
||||||
|
"INFO".to_string(),
|
||||||
|
"--json-output".to_string(),
|
||||||
|
];
|
||||||
|
// Model optional revision
|
||||||
|
if let Some(revision) = revision.clone() {
|
||||||
|
download_argv.push("--revision".to_string());
|
||||||
|
download_argv.push(revision)
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut env = Vec::new();
|
||||||
|
|
||||||
|
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the download process
|
||||||
|
// Useful when running inside a docker container
|
||||||
|
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
|
};
|
||||||
|
|
||||||
|
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the download process
|
||||||
|
// Useful when running inside a HuggingFace Inference Endpoint
|
||||||
|
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
|
||||||
|
env.push((
|
||||||
|
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||||
|
weights_cache_override.into(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start process
|
||||||
|
tracing::info!("Starting download");
|
||||||
|
let mut download_process = match Popen::create(
|
||||||
|
&download_argv,
|
||||||
|
PopenConfig {
|
||||||
|
stdout: Redirection::Pipe,
|
||||||
|
stderr: Redirection::Pipe,
|
||||||
|
// Needed for the shutdown procedure
|
||||||
|
setpgid: true,
|
||||||
|
env: Some(env),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(err) => {
|
||||||
|
if let PopenError::IoError(ref err) = err {
|
||||||
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
|
tracing::error!("text-generation-server not found in PATH");
|
||||||
|
tracing::error!("Please install it with `make install-server`")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Redirect STDOUT to the console
|
||||||
|
let download_stdout = download_process.stdout.take().unwrap();
|
||||||
|
thread::spawn(move || {
|
||||||
|
// Enter download tracing span
|
||||||
|
let stdout = BufReader::new(download_stdout);
|
||||||
|
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||||
|
for line in stdout.lines() {
|
||||||
|
// Parse loguru logs
|
||||||
|
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
|
||||||
|
if let Some(text) = value.get("text") {
|
||||||
|
// Format escaped newlines
|
||||||
|
tracing::info!("{}", text.to_string().replace("\\n", ""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if let Some(status) = download_process.poll() {
|
||||||
|
match status {
|
||||||
|
ExitStatus::Exited(exit_code) => {
|
||||||
|
if exit_code == 0 {
|
||||||
|
tracing::info!("Successfully downloaded weights.");
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let mut err = String::new();
|
||||||
|
download_process
|
||||||
|
.stderr
|
||||||
|
.take()
|
||||||
|
.unwrap()
|
||||||
|
.read_to_string(&mut err)
|
||||||
|
.unwrap();
|
||||||
|
tracing::error!("Download encountered an error: {err}");
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
tracing::error!("Download process exited with an unkown status.");
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !running.load(Ordering::SeqCst) {
|
||||||
|
download_process.terminate().unwrap();
|
||||||
|
tracing::info!("Waiting for download process to gracefully shutdown");
|
||||||
|
download_process
|
||||||
|
.wait_timeout(Duration::from_secs(90))
|
||||||
|
.unwrap();
|
||||||
|
tracing::info!("Download process terminated");
|
||||||
|
return ExitCode::SUCCESS;
|
||||||
|
}
|
||||||
|
sleep(Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Shared shutdown bool
|
// Shared shutdown bool
|
||||||
let shutdown = Arc::new(Mutex::new(false));
|
let shutdown = Arc::new(Mutex::new(false));
|
||||||
// Shared shutdown channel
|
// Shared shutdown channel
|
||||||
|
17
server/tests/utils/test_convert.py
Normal file
17
server/tests/utils/test_convert.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
|
||||||
|
|
||||||
|
from text_generation.utils.convert import convert_files
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_files():
    """Convert the PyTorch weights of a small model and check they are found locally.

    Fetches the ``.bin`` weight files of ``bigscience/bloom-560m`` through
    ``download_weights``, converts each one to a ``.safetensors`` file placed in
    the same directory, and asserts that every converted path is among the
    paths returned by ``weight_files``.
    """
    model_id = "bigscience/bloom-560m"
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)

    # Remove the literal "pytorch_" prefix only. `str.lstrip("pytorch_")` strips
    # any leading run of those *characters*, so it would also eat the start of
    # a stem such as "pytorch_top".
    prefix = "pytorch_"
    local_st_files = [
        p.parent
        / f"{p.stem[len(prefix):] if p.stem.startswith(prefix) else p.stem}.safetensors"
        for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files)

    found_st_files = weight_files(model_id)

    # A generator is enough here; no need to build an intermediate list.
    assert all(p in found_st_files for p in local_st_files)
|
40
server/tests/utils/test_hub.py
Normal file
40
server/tests/utils/test_hub.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from text_generation.utils.hub import (
|
||||||
|
weight_hub_files,
|
||||||
|
download_weights,
|
||||||
|
weight_files,
|
||||||
|
EntryNotFoundError,
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_hub_files():
    """`weight_hub_files` lists exactly one safetensors file for bloom-560m."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_hub_files_llm():
    """`weight_hub_files` lists the 72 numbered safetensors shards of bloom."""
    shard_count = 72
    expected = [
        f"model_{index:05d}-of-00072.safetensors"
        for index in range(1, shard_count + 1)
    ]
    assert weight_hub_files("bigscience/bloom") == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_hub_files_empty():
    """Requesting the ".errors" extension raises `EntryNotFoundError`."""
    unknown_extension = ".errors"
    with pytest.raises(EntryNotFoundError):
        weight_hub_files("bigscience/bloom", extension=unknown_extension)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_weights():
    """Paths returned by `download_weights` equal those reported by `weight_files`."""
    repo = "bigscience/bloom-560m"
    downloaded = download_weights(weight_hub_files(repo), repo)
    assert downloaded == weight_files(repo)
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_files_error():
    """`weight_files` raises the expected error for each failing lookup.

    An unknown revision of bloom-560m must raise `RevisionNotFoundError`, and
    looking up "bert-base-uncased" must raise `LocalEntryNotFoundError`.
    """
    failing_lookups = [
        (RevisionNotFoundError, ("bigscience/bloom-560m",), {"revision": "error"}),
        (LocalEntryNotFoundError, ("bert-base-uncased",), {}),
    ]
    for expected_error, args, kwargs in failing_lookups:
        with pytest.raises(expected_error):
            weight_files(*args, **kwargs)
|
@ -1,14 +1,6 @@
|
|||||||
import pytest
|
from text_generation.utils.tokens import (
|
||||||
|
|
||||||
from huggingface_hub.utils import RevisionNotFoundError
|
|
||||||
|
|
||||||
from text_generation.utils import (
|
|
||||||
weight_hub_files,
|
|
||||||
download_weights,
|
|
||||||
weight_files,
|
|
||||||
StopSequenceCriteria,
|
StopSequenceCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
LocalEntryNotFoundError,
|
|
||||||
FinishReason,
|
FinishReason,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
|
|||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(1, "") == (False, None)
|
assert criteria(1, "") == (False, None)
|
||||||
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files():
|
|
||||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
|
||||||
assert filenames == ["model.safetensors"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files_llm():
|
|
||||||
filenames = weight_hub_files("bigscience/bloom")
|
|
||||||
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
|
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files_empty():
|
|
||||||
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
|
|
||||||
assert filenames == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_download_weights():
|
|
||||||
files = download_weights("bigscience/bloom-560m")
|
|
||||||
local_files = weight_files("bigscience/bloom-560m")
|
|
||||||
assert files == local_files
|
|
||||||
|
|
||||||
|
|
||||||
def test_weight_files_error():
|
|
||||||
with pytest.raises(RevisionNotFoundError):
|
|
||||||
weight_files("bigscience/bloom-560m", revision="error")
|
|
||||||
with pytest.raises(LocalEntryNotFoundError):
|
|
||||||
weight_files("bert-base-uncased")
|
|
@ -60,18 +60,49 @@ def download_weights(
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
extension: str = ".safetensors",
|
extension: str = ".safetensors",
|
||||||
convert: bool = False,
|
logger_level: str = "INFO",
|
||||||
|
json_output: bool = False,
|
||||||
):
|
):
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="{message}",
|
||||||
|
filter="text_generation",
|
||||||
|
level=logger_level,
|
||||||
|
serialize=json_output,
|
||||||
|
backtrace=True,
|
||||||
|
diagnose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test if files were already downloaded
|
||||||
|
try:
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
logger.info(
|
||||||
|
"Files are already present in the local cache. " "Skipping download."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
# Local files not found
|
||||||
|
except utils.LocalEntryNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Download weights directly
|
||||||
try:
|
try:
|
||||||
filenames = utils.weight_hub_files(model_id, revision, extension)
|
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||||
utils.download_weights(model_id, revision, filenames)
|
utils.download_weights(filenames, model_id, revision)
|
||||||
except utils.EntryNotFoundError as e:
|
except utils.EntryNotFoundError as e:
|
||||||
if not convert or not extension == ".safetensors":
|
if not extension == ".safetensors":
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"No safetensors weights found for model {model_id} at revision {revision}. "
|
||||||
|
f"Converting PyTorch weights instead."
|
||||||
|
)
|
||||||
|
|
||||||
# Try to see if there are pytorch weights
|
# Try to see if there are pytorch weights
|
||||||
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
||||||
# Download pytorch weights
|
# Download pytorch weights
|
||||||
local_pt_files = utils.download_weights(model_id, revision, pt_filenames)
|
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||||
local_st_files = [
|
local_st_files = [
|
||||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||||
for p in local_pt_files
|
for p in local_pt_files
|
||||||
|
@ -6,12 +6,15 @@ from text_generation.utils.hub import (
|
|||||||
download_weights,
|
download_weights,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
LocalEntryNotFoundError,
|
LocalEntryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
)
|
)
|
||||||
from text_generation.utils.tokens import (
|
from text_generation.utils.tokens import (
|
||||||
Greedy,
|
Greedy,
|
||||||
NextTokenChooser,
|
NextTokenChooser,
|
||||||
Sampling,
|
Sampling,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
|
StopSequenceCriteria,
|
||||||
|
FinishReason,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -23,8 +26,11 @@ __all__ = [
|
|||||||
"download_weights",
|
"download_weights",
|
||||||
"EntryNotFoundError",
|
"EntryNotFoundError",
|
||||||
"LocalEntryNotFoundError",
|
"LocalEntryNotFoundError",
|
||||||
|
"RevisionNotFoundError",
|
||||||
"Greedy",
|
"Greedy",
|
||||||
"NextTokenChooser",
|
"NextTokenChooser",
|
||||||
"Sampling",
|
"Sampling",
|
||||||
"StoppingCriteria",
|
"StoppingCriteria",
|
||||||
|
"StopSequenceCriteria",
|
||||||
|
"FinishReason",
|
||||||
]
|
]
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import concurrent
|
import concurrent
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from datetime import timedelta
|
||||||
|
from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +81,16 @@ def convert_files(pt_files: List[Path], st_files: List[Path]):
|
|||||||
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
|
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
|
||||||
for pt_file, st_file in zip(pt_files, st_files)
|
for pt_file, st_file in zip(pt_files, st_files)
|
||||||
]
|
]
|
||||||
[
|
|
||||||
future.result()
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
logger.info("Converting weights...")
|
||||||
]
|
start_time = time.time()
|
||||||
|
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||||
|
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||||
|
remaining = len(futures) - (i + 1)
|
||||||
|
if remaining != 0:
|
||||||
|
eta = (elapsed / (i + 1)) * remaining
|
||||||
|
else:
|
||||||
|
eta = 0
|
||||||
|
|
||||||
|
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||||
|
@ -1,21 +1,26 @@
|
|||||||
|
import time
|
||||||
import concurrent
|
import concurrent
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from datetime import timedelta
|
||||||
|
from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from huggingface_hub import HfApi, _CACHED_NO_EXIST, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError
|
from huggingface_hub.utils import (
|
||||||
|
LocalEntryNotFoundError,
|
||||||
|
EntryNotFoundError,
|
||||||
|
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
|
||||||
|
)
|
||||||
|
|
||||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
|
|
||||||
|
|
||||||
def weight_hub_files(
|
def weight_hub_files(
|
||||||
model_id: str, revision: str = None, extension: str = ".safetensors"
|
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Get the weights filenames on the hub"""
|
"""Get the weights filenames on the hub"""
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
@ -32,7 +37,7 @@ def weight_hub_files(
|
|||||||
|
|
||||||
|
|
||||||
def try_to_load_from_cache(
|
def try_to_load_from_cache(
|
||||||
model_id: str, revision: str, filename: str
|
model_id: str, revision: Optional[str], filename: str
|
||||||
) -> Optional[Path]:
|
) -> Optional[Path]:
|
||||||
"""Try to load a file from the Hugging Face cache"""
|
"""Try to load a file from the Hugging Face cache"""
|
||||||
if revision is None:
|
if revision is None:
|
||||||
@ -58,7 +63,7 @@ def try_to_load_from_cache(
|
|||||||
|
|
||||||
# Check if file is cached as "no_exist"
|
# Check if file is cached as "no_exist"
|
||||||
if (no_exist_dir / revision / filename).is_file():
|
if (no_exist_dir / revision / filename).is_file():
|
||||||
return _CACHED_NO_EXIST
|
return None
|
||||||
|
|
||||||
# Check if revision folder exists
|
# Check if revision folder exists
|
||||||
if not snapshots_dir.exists():
|
if not snapshots_dir.exists():
|
||||||
@ -73,7 +78,9 @@ def try_to_load_from_cache(
|
|||||||
return cached_file if cached_file.is_file() else None
|
return cached_file if cached_file.is_file() else None
|
||||||
|
|
||||||
|
|
||||||
def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
|
def weight_files(
|
||||||
|
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||||
|
) -> List[Path]:
|
||||||
"""Get the local files"""
|
"""Get the local files"""
|
||||||
try:
|
try:
|
||||||
filenames = weight_hub_files(model_id, revision, extension)
|
filenames = weight_hub_files(model_id, revision, extension)
|
||||||
@ -116,22 +123,47 @@ def weight_files(model_id: str, revision: str, extension: str) -> List[Path]:
|
|||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
def download_weights(model_id: str, revision: str, filenames: List[str]) -> List[Path]:
|
def download_weights(
|
||||||
|
filenames: List[str], model_id: str, revision: Optional[str] = None
|
||||||
|
) -> List[Path]:
|
||||||
"""Download the safetensors files from the hub"""
|
"""Download the safetensors files from the hub"""
|
||||||
download_function = partial(
|
|
||||||
hf_hub_download,
|
def download_file(filename):
|
||||||
repo_id=model_id,
|
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||||
local_files_only=False,
|
if local_file is not None:
|
||||||
)
|
logger.info(f"File {filename} already present in cache.")
|
||||||
|
return local_file
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
local_file = hf_hub_download(
|
||||||
|
filename=filename,
|
||||||
|
repo_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=False,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||||
|
)
|
||||||
|
return local_file
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=5)
|
executor = ThreadPoolExecutor(max_workers=5)
|
||||||
futures = [
|
futures = [
|
||||||
executor.submit(download_function, filename=filename, revision=revision)
|
executor.submit(download_file, filename=filename) for filename in filenames
|
||||||
for filename in filenames
|
|
||||||
]
|
|
||||||
files = [
|
|
||||||
future.result()
|
|
||||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
|
logger.info("Downloading weights...")
|
||||||
|
start_time = time.time()
|
||||||
|
files = []
|
||||||
|
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||||
|
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||||
|
remaining = len(futures) - (i + 1)
|
||||||
|
if remaining != 0:
|
||||||
|
eta = (elapsed / (i + 1)) * remaining
|
||||||
|
else:
|
||||||
|
eta = 0
|
||||||
|
|
||||||
|
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||||
|
files.append(Path(future.result()))
|
||||||
|
|
||||||
return [Path(p) for p in files]
|
return [Path(p) for p in files]
|
||||||
|
@ -12,6 +12,7 @@ from transformers import (
|
|||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
from text_generation.pb import generate_pb2
|
from text_generation.pb import generate_pb2
|
||||||
|
from text_generation.pb.generate_pb2 import FinishReason
|
||||||
|
|
||||||
|
|
||||||
class Sampling:
|
class Sampling:
|
||||||
@ -115,15 +116,15 @@ class StoppingCriteria:
|
|||||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
||||||
self.current_tokens += 1
|
self.current_tokens += 1
|
||||||
if self.current_tokens >= self.max_new_tokens:
|
if self.current_tokens >= self.max_new_tokens:
|
||||||
return True, generate_pb2.FinishReason.FINISH_REASON_LENGTH
|
return True, FinishReason.FINISH_REASON_LENGTH
|
||||||
|
|
||||||
if last_token == self.eos_token_id:
|
if last_token == self.eos_token_id:
|
||||||
return True, generate_pb2.FinishReason.FINISH_REASON_EOS_TOKEN
|
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||||
|
|
||||||
self.current_output += last_output
|
self.current_output += last_output
|
||||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
for stop_sequence_criteria in self.stop_sequence_criterias:
|
||||||
if stop_sequence_criteria(self.current_output):
|
if stop_sequence_criteria(self.current_output):
|
||||||
return True, generate_pb2.FinishReason.FINISH_REASON_STOP_SEQUENCE
|
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
|
||||||
|
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user