feat: add safetensors conversion (#63)

OlivierDehaene 2023-02-14 13:02:16 +01:00 committed by GitHub
parent 9af454142a
commit 0fbc691946
18 changed files with 749 additions and 380 deletions

View File

@ -49,17 +49,17 @@ to power LLMs api-inference widgets.
- Log probabilities
- Distributed tracing with Open Telemetry
-## Officially supported models
+## Officially supported architectures
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
-- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-Other models are supported on a best effort basis using:
+Other architectures are supported on a best effort basis using:
`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
@ -191,7 +191,7 @@ Be aware that the official Docker image has them enabled by default.
### Download
-First you need to download the weights:
+It is advised to download the weights ahead of time with the following command:
```shell
make download-bloom
```

View File

@ -12,7 +12,7 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
-use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
@ -43,6 +43,10 @@ struct Args {
#[clap(default_value = "29500", long, env)]
master_port: usize,
#[clap(long, env)]
+huggingface_hub_cache: Option<String>,
+#[clap(long, env)]
+weights_cache_override: Option<String>,
+#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
@ -63,6 +67,8 @@ fn main() -> ExitCode {
shard_uds_path,
master_addr,
master_port,
+huggingface_hub_cache,
+weights_cache_override,
json_output,
otlp_endpoint,
} = Args::parse();
@ -84,6 +90,124 @@ fn main() -> ExitCode {
})
.expect("Error setting Ctrl-C handler");
// Download weights
if weights_cache_override.is_none() {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
let mut env = Vec::new();
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Start process
tracing::info!("Starting download");
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unkown status.");
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
@ -99,6 +223,8 @@ fn main() -> ExitCode {
let revision = revision.clone();
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
+let huggingface_hub_cache = huggingface_hub_cache.clone();
+let weights_cache_override = weights_cache_override.clone();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
@ -113,6 +239,8 @@ fn main() -> ExitCode {
num_shard,
master_addr,
master_port,
+huggingface_hub_cache,
+weights_cache_override,
otlp_endpoint,
status_sender,
shutdown,
@ -232,7 +360,7 @@ fn main() -> ExitCode {
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-tracing::error!("Shard {} failed:\n{}", rank, err);
+tracing::error!("Shard {rank} failed:\n{err}");
exit_code = ExitCode::FAILURE;
break;
};
@ -275,6 +403,8 @@ fn shard_manager(
world_size: usize,
master_addr: String,
master_port: usize,
+huggingface_hub_cache: Option<String>,
+weights_cache_override: Option<String>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
@ -328,15 +458,15 @@ fn shard_manager(
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()), ("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
]; ];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard // If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint // Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { if let Some(weights_cache_override) = weights_cache_override {
env.push(( env.push((
"WEIGHTS_CACHE_OVERRIDE".into(), "WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(), weights_cache_override.into(),
@ -355,7 +485,7 @@ fn shard_manager(
}; };
// Start process // Start process
tracing::info!("Starting shard {}", rank); tracing::info!("Starting shard {rank}");
let mut p = match Popen::create( let mut p = match Popen::create(
&shard_argv, &shard_argv,
PopenConfig { PopenConfig {
@ -419,17 +549,17 @@ fn shard_manager(
if *shutdown.lock().unwrap() { if *shutdown.lock().unwrap() {
p.terminate().unwrap(); p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90)); let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank); tracing::info!("Shard {rank} terminated");
return; return;
} }
// Shard is ready // Shard is ready
if uds.exists() && !ready { if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap(); status_sender.send(ShardStatus::Ready).unwrap();
ready = true; ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) { } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank); tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now(); wait_time = Instant::now();
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
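
For reference, the stdout reader above consumes loguru's serialized JSON records emitted by `text-generation-server download-weights` and forwards their `text` field through `tracing`. A minimal sketch of the producing side, assuming loguru's `serialize=True` record layout (the top-level `text` key is the field the Rust code reads):

```python
# Sketch of the log lines the launcher parses; assumes loguru's serialize=True
# JSON layout, where each record carries a top-level "text" field.
import sys

from loguru import logger

logger.remove()
logger.add(sys.stdout, format="{message}", serialize=True, level="INFO")
logger.info("Downloading weights...")
# Emits one JSON object per line, roughly:
#   {"text": "Downloading weights...\n", "record": {...}}
# The launcher reads value.get("text") from each line and strips the escaped newline.
```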

View File

@ -0,0 +1,17 @@
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,40 @@
import pytest
from text_generation.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,14 +1,6 @@
-import pytest
-from huggingface_hub.utils import RevisionNotFoundError
-from text_generation.utils import (
-weight_hub_files,
-download_weights,
-weight_files,
+from text_generation.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
-LocalEntryNotFoundError,
FinishReason,
)
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -60,8 +60,55 @@ def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
logger_level: str = "INFO",
json_output: bool = False,
):
-utils.download_weights(model_id, revision, extension)
+# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Test if files were already downloaded
try:
utils.weight_files(model_id, revision, extension)
logger.info(
"Files are already present in the local cache. " "Skipping download."
)
return
# Local files not found
except utils.LocalEntryNotFoundError:
pass
# Download weights directly
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError as e:
if not extension == ".safetensors":
raise e
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
)
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":
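
A small worked example of the `pytorch_*.bin` to `.safetensors` filename mapping used above; the path is hypothetical, and note that `str.lstrip` removes a set of characters rather than a prefix, which happens to behave like prefix removal for the standard `pytorch_model*` names:

```python
from pathlib import Path

# Hypothetical local checkpoint path, for illustration only.
p = Path("/data/pytorch_model-00001-of-00072.bin")
st = p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
# lstrip('pytorch_') strips leading characters from the set {p,y,t,o,r,c,h,_},
# turning "pytorch_model-00001-of-00072" into "model-00001-of-00072".
print(st)  # /data/model-00001-of-00072.safetensors
```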

View File

@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom":
@ -48,27 +57,22 @@ def get_model(
return BLOOMSharded(model_id, revision, quantize=quantize)
else:
return BLOOM(model_id, revision, quantize=quantize)
-elif config.model_type == "gpt_neox":
+if config.model_type == "gpt_neox":
if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize)
else:
return GPTNeox(model_id, revision, quantize=quantize)
-elif config.model_type == "t5":
+if config.model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
-elif model_id.startswith("facebook/galactica"):
if sharded:
-return GalacticaSharded(model_id, revision, quantize=quantize)
-else:
-return Galactica(model_id, revision, quantize=quantize)
-elif "santacoder" in model_id:
-return SantaCoder(model_id, revision, quantize)
-else:
-if sharded:
-raise ValueError("sharded is not supported for AutoModel")
-try:
-return CausalLM(model_id, revision, quantize=quantize)
-except Exception:
-return Seq2SeqLM(model_id, revision, quantize=quantize)
+raise ValueError("sharded is not supported for AutoModel")
+try:
+return CausalLM(model_id, revision, quantize=quantize)
+except Exception:
+return Seq2SeqLM(model_id, revision, quantize=quantize)

View File

@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
-download_weights,
)
HAS_BITS_AND_BYTES = True
@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
)
config.pad_token_id = 3
-# Only download weights for small models
-if self.master and model_id == "bigscience/bloom-560m":
-download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-if not filenames:
-raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -26,7 +26,6 @@ from text_generation.utils import (
StoppingCriteria,
initialize_torch_distributed,
weight_files,
-download_weights,
)
HAS_BITS_AND_BYTES = True
@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
)
tokenizer.pad_token_id = config.pad_token_id
-# Only download weights for small models
-if self.master and model_id == "facebook/galactica-125m":
-download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-if not filenames:
-raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -20,7 +20,6 @@ from text_generation.models import CausalLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
-download_weights,
)
HAS_BITS_AND_BYTES = True
@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
model_id, revision=revision, tp_parallel=True
)
-# Only master download weights
-if self.master:
-download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-if not filenames:
-raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
-from typing import Optional, List, Tuple
+from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM

View File

@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
-download_weights,
)
HAS_BITS_AND_BYTES = True
@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
)
tokenizer.bos_token_id = config.decoder_start_token_id
-# Only master download weights
-if self.master:
-download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-if not filenames:
-raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)

View File

@ -1,283 +0,0 @@
import concurrent
import os
import re
import torch
import torch.distributed
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return str(cached_file) if cached_file.is_file() else None
def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files

View File

@ -0,0 +1,36 @@
from text_generation.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
__all__ = [
"convert_file",
"convert_files",
"initialize_torch_distributed",
"weight_files",
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
]

View File

@ -0,0 +1,96 @@
import concurrent
import time
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List
def check_file_size(source_file: Path, target_file: Path):
"""
Check that two files are close in size
"""
source_file_size = source_file.stat().st_size
target_file_size = target_file.stat().st_size
if (source_file_size - target_file_size) / source_file_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {source_file}: {source_file_size}
- {target_file}: {target_file_size}
"""
)
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
"""
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
remove them
"""
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
# Iterate over all found memory addresses
for ptr, names in ptrs.items():
if len(names) > 1:
# Multiple tensors point to the same memory
# Only keep the first tensor
for name in names[1:]:
tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path):
"""
Convert a pytorch file to a safetensors file
"""
pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state:
pt_state = pt_state["state_dict"]
remove_shared_pointers(pt_state)
# Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"})
# Check that both files are close in size
check_file_size(pt_file, st_file)
# Load safetensors state
st_state = load_file(str(st_file))
for k in st_state:
pt_tensor = pt_state[k]
st_tensor = st_state[k]
if not torch.equal(pt_tensor, st_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]):
assert len(pt_files) == len(st_files)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Converting weights...")
start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -0,0 +1,35 @@
import os
import torch
from datetime import timedelta
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
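
A single-process sketch of `initialize_torch_distributed` on the CPU/gloo path; `MASTER_ADDR` and `MASTER_PORT` are required by `torch.distributed`'s default `env://` rendezvous and are set here purely for illustration:

```python
import os

from text_generation.utils.dist import initialize_torch_distributed

# RANK/WORLD_SIZE default to 0/1 inside the helper; the master address and port
# are needed by the env:// rendezvous that init_process_group uses by default.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)  # 0 1 on a single CPU process
```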

View File

@ -0,0 +1,169 @@
import time
import concurrent
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return None
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return cached_file if cached_file.is_file() else None
def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
try:
filenames = weight_hub_files(model_id, revision, extension)
except EntryNotFoundError as e:
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them
filenames = [
f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
]
if WEIGHTS_CACHE_OVERRIDE is not None:
files = []
for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists():
raise LocalEntryNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
)
files.append(p)
return files
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(
filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
"""Download the safetensors files from the hub"""
def download_file(filename):
local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return local_file
start_time = time.time()
local_file = hf_hub_download(
filename=filename,
repo_id=model_id,
revision=revision,
local_files_only=False,
)
logger.info(
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return local_file
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
logger.info("Downloading weights...")
start_time = time.time()
files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
if remaining != 0:
eta = (elapsed / (i + 1)) * remaining
else:
eta = 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
files.append(Path(future.result()))
return [Path(p) for p in files]
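
A sketch of how these hub helpers compose, mirroring the `download-weights` CLI flow above (network access assumed; `bigscience/bloom-560m` ships a single safetensors file, so result ordering is not a concern):

```python
from text_generation.utils.hub import weight_hub_files, download_weights, weight_files

model_id = "bigscience/bloom-560m"

filenames = weight_hub_files(model_id, extension=".safetensors")  # names on the Hub
local_files = download_weights(filenames, model_id)               # download, or hit the cache
assert local_files == weight_files(model_id)                      # resolve again from the local cache
```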

View File

@ -0,0 +1,142 @@
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
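
Finally, a small sketch of the token utilities above, exercising the greedy path on CPU tensors; the stop sequence and scores are illustrative only:

```python
import torch

from text_generation.utils.tokens import (
    FinishReason,
    NextTokenChooser,
    StoppingCriteria,
    StopSequenceCriteria,
)

# Greedy choice: no warpers are active, so the argmax of the last scores row wins.
chooser = NextTokenChooser(temperature=1.0, do_sample=False)
scores = torch.zeros(1, 8)
scores[0, 3] = 1.0
next_id, logprobs = chooser(input_ids=None, scores=scores)
assert next_id.item() == 3

# Stop on the "stop" suffix, the EOS token, or after 4 generated tokens.
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("stop")],
    max_new_tokens=4,
)
assert criteria(5, " keep going") == (False, None)
assert criteria(5, " please stop") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
```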