Mirror of https://github.com/huggingface/text-generation-inference.git

feat: add safetensors conversion (#63)

Parent: 9af454142a
Commit: 0fbc691946

README:

@@ -49,17 +49,17 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Distributed tracing with Open Telemetry
 
-## Officially supported models
+## Officially supported architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
-- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
 - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
 - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
 
-Other models are supported on a best effort basis using:
+Other architectures are supported on a best effort basis using:
 
 `AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
 
@@ -191,7 +191,7 @@ Be aware that the official Docker image has them enabled by default.
 
 ### Download
 
-First you need to download the weights:
+It is advised to download the weights ahead of time with the following command:
 
 ```shell
 make download-bloom
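
For the "best effort" path mentioned in the README hunk above, unsupported architectures go through the plain transformers auto classes. A minimal sketch of that call, assuming `transformers` and `accelerate` are installed; the model id is only an example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"  # example id, any causal LM works here

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets accelerate place the weights on the available devices.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello, safetensors", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```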

Launcher (Rust):

@@ -12,7 +12,7 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -43,6 +43,10 @@ struct Args {
     #[clap(default_value = "29500", long, env)]
     master_port: usize,
     #[clap(long, env)]
+    huggingface_hub_cache: Option<String>,
+    #[clap(long, env)]
+    weights_cache_override: Option<String>,
+    #[clap(long, env)]
     json_output: bool,
     #[clap(long, env)]
     otlp_endpoint: Option<String>,
@@ -63,6 +67,8 @@ fn main() -> ExitCode {
         shard_uds_path,
         master_addr,
         master_port,
+        huggingface_hub_cache,
+        weights_cache_override,
         json_output,
         otlp_endpoint,
     } = Args::parse();

@@ -84,6 +90,124 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Download weights
+    if weights_cache_override.is_none() {
+        let mut download_argv = vec![
+            "text-generation-server".to_string(),
+            "download-weights".to_string(),
+            model_id.clone(),
+            "--logger-level".to_string(),
+            "INFO".to_string(),
+            "--json-output".to_string(),
+        ];
+        if num_shard == 1 {
+            download_argv.push("--extension".to_string());
+            download_argv.push(".bin".to_string());
+        } else {
+            download_argv.push("--extension".to_string());
+            download_argv.push(".safetensors".to_string());
+        }
+
+        // Model optional revision
+        if let Some(ref revision) = revision {
+            download_argv.push("--revision".to_string());
+            download_argv.push(revision.to_string())
+        }
+
+        let mut env = Vec::new();
+
+        // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+        // Useful when running inside a docker container
+        if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
+            env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+        };
+
+        // Start process
+        tracing::info!("Starting download");
+        let mut download_process = match Popen::create(
+            &download_argv,
+            PopenConfig {
+                stdout: Redirection::Pipe,
+                stderr: Redirection::Pipe,
+                // Needed for the shutdown procedure
+                setpgid: true,
+                env: Some(env),
+                ..Default::default()
+            },
+        ) {
+            Ok(p) => p,
+            Err(err) => {
+                if let PopenError::IoError(ref err) = err {
+                    if err.kind() == io::ErrorKind::NotFound {
+                        tracing::error!("text-generation-server not found in PATH");
+                        tracing::error!("Please install it with `make install-server`")
+                    }
+                }
+                return ExitCode::FAILURE;
+            }
+        };
+
+        // Redirect STDOUT to the console
+        let download_stdout = download_process.stdout.take().unwrap();
+        thread::spawn(move || {
+            // Enter download tracing span
+            let stdout = BufReader::new(download_stdout);
+            let _span = tracing::span!(tracing::Level::INFO, "download").entered();
+            for line in stdout.lines() {
+                // Parse loguru logs
+                if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
+                    if let Some(text) = value.get("text") {
+                        // Format escaped newlines
+                        tracing::info!("{}", text.to_string().replace("\\n", ""));
+                    }
+                }
+            }
+        });
+
+        loop {
+            if let Some(status) = download_process.poll() {
+                match status {
+                    ExitStatus::Exited(exit_code) => {
+                        if exit_code == 0 {
+                            tracing::info!("Successfully downloaded weights.");
+                            break;
+                        } else {
+                            let mut err = String::new();
+                            download_process
+                                .stderr
+                                .take()
+                                .unwrap()
+                                .read_to_string(&mut err)
+                                .unwrap();
+                            tracing::error!("Download encountered an error: {err}");
+                            return ExitCode::FAILURE;
+                        }
+                    }
+                    _ => {
+                        tracing::error!("Download process exited with an unknown status.");
+                        return ExitCode::FAILURE;
+                    }
+                }
+            }
+            if !running.load(Ordering::SeqCst) {
+                download_process.terminate().unwrap();
+                tracing::info!("Waiting for download process to gracefully shutdown");
+                download_process
+                    .wait_timeout(Duration::from_secs(90))
+                    .unwrap();
+                tracing::info!("Download process terminated");
+                return ExitCode::SUCCESS;
+            }
+            sleep(Duration::from_millis(100));
+        }
+    } else {
+        tracing::info!(
+            "weights_cache_override is set to {:?}.",
+            weights_cache_override
+        );
+        tracing::info!("Skipping download.")
+    }
+
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
     // Shared shutdown channel
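
Concretely, the block above shells out to the server's `download-weights` command before any shard starts, asking for `.bin` files when running a single shard and `.safetensors` otherwise. A sketch of an equivalent invocation from Python, using only the flags visible in the argv above; the model id and shard count are example values:

```python
import subprocess

model_id = "bigscience/bloom-560m"  # example
num_shard = 2

argv = [
    "text-generation-server", "download-weights", model_id,
    "--logger-level", "INFO",
    "--json-output",
    # A single shard keeps .bin files; sharded setups request .safetensors.
    "--extension", ".bin" if num_shard == 1 else ".safetensors",
]

# The launcher additionally forwards HUGGINGFACE_HUB_CACHE to this child process.
subprocess.run(argv, check=True)
```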

@@ -99,6 +223,8 @@ fn main() -> ExitCode {
         let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
+        let huggingface_hub_cache = huggingface_hub_cache.clone();
+        let weights_cache_override = weights_cache_override.clone();
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
@@ -113,6 +239,8 @@ fn main() -> ExitCode {
                 num_shard,
                 master_addr,
                 master_port,
+                huggingface_hub_cache,
+                weights_cache_override,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -232,7 +360,7 @@ fn main() -> ExitCode {
 
     while running.load(Ordering::SeqCst) {
         if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {} failed:\n{}", rank, err);
+            tracing::error!("Shard {rank} failed:\n{err}");
             exit_code = ExitCode::FAILURE;
             break;
         };

@@ -275,6 +403,8 @@ fn shard_manager(
     world_size: usize,
     master_addr: String,
     master_port: usize,
+    huggingface_hub_cache: Option<String>,
+    weights_cache_override: Option<String>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -328,15 +458,15 @@ fn shard_manager(
         ("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
     ];
 
-    // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+    // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
-    if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
+    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
         env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
 
-    // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
+    // If weights_cache_override is some, pass it to the shard
     // Useful when running inside a HuggingFace Inference Endpoint
-    if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
+    if let Some(weights_cache_override) = weights_cache_override {
         env.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),
@@ -355,7 +485,7 @@ fn shard_manager(
     };
 
     // Start process
-    tracing::info!("Starting shard {}", rank);
+    tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
         &shard_argv,
         PopenConfig {
@@ -419,17 +549,17 @@ fn shard_manager(
         if *shutdown.lock().unwrap() {
             p.terminate().unwrap();
             let _ = p.wait_timeout(Duration::from_secs(90));
-            tracing::info!("Shard {} terminated", rank);
+            tracing::info!("Shard {rank} terminated");
             return;
         }
 
         // Shard is ready
         if uds.exists() && !ready {
-            tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
+            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
             ready = true;
         } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {} to be ready...", rank);
+            tracing::info!("Waiting for shard {rank} to be ready...");
             wait_time = Instant::now();
         }
         sleep(Duration::from_millis(100));
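
With the hunks above, `shard_manager` receives the cache locations as plain arguments and forwards them to the shard as the `HUGGINGFACE_HUB_CACHE` and `WEIGHTS_CACHE_OVERRIDE` environment variables, where the Python side picks them up (see `hub.py` below). A small sketch of the server-side half of that contract, assuming only `huggingface_hub` is installed:

```python
import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# hub.py resolves the override at import time with exactly this call.
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)

# HUGGINGFACE_HUB_CACHE is consumed by huggingface_hub itself and decides where
# hf_hub_download stores files; WEIGHTS_CACHE_OVERRIDE short-circuits weight_files.
print("hub cache:", HUGGINGFACE_HUB_CACHE)
print("weights override:", WEIGHTS_CACHE_OVERRIDE or "<not set, normal cache lookup>")
```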

server/tests/utils/test_convert.py (new file, 17 lines):

@@ -0,0 +1,17 @@
+from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+
+from text_generation.utils.convert import convert_files
+
+
+def test_convert_files():
+    model_id = "bigscience/bloom-560m"
+    pt_filenames = weight_hub_files(model_id, extension=".bin")
+    local_pt_files = download_weights(pt_filenames, model_id)
+    local_st_files = [
+        p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
+    ]
+    convert_files(local_pt_files, local_st_files)
+
+    found_st_files = weight_files(model_id)
+
+    assert all([p in found_st_files for p in local_st_files])

server/tests/utils/test_hub.py (new file, 40 lines):

@@ -0,0 +1,40 @@
+import pytest
+
+from text_generation.utils.hub import (
+    weight_hub_files,
+    download_weights,
+    weight_files,
+    EntryNotFoundError,
+    LocalEntryNotFoundError,
+    RevisionNotFoundError,
+)
+
+
+def test_weight_hub_files():
+    filenames = weight_hub_files("bigscience/bloom-560m")
+    assert filenames == ["model.safetensors"]
+
+
+def test_weight_hub_files_llm():
+    filenames = weight_hub_files("bigscience/bloom")
+    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
+
+
+def test_weight_hub_files_empty():
+    with pytest.raises(EntryNotFoundError):
+        weight_hub_files("bigscience/bloom", extension=".errors")
+
+
+def test_download_weights():
+    model_id = "bigscience/bloom-560m"
+    filenames = weight_hub_files(model_id)
+    files = download_weights(filenames, model_id)
+    local_files = weight_files("bigscience/bloom-560m")
+    assert files == local_files
+
+
+def test_weight_files_error():
+    with pytest.raises(RevisionNotFoundError):
+        weight_files("bigscience/bloom-560m", revision="error")
+    with pytest.raises(LocalEntryNotFoundError):
+        weight_files("bert-base-uncased")
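
The hub helpers re-export the huggingface_hub error types, so callers can tell "nothing on the hub" apart from "nothing in the local cache", which is exactly what the tests above exercise. A minimal sketch of that pattern; `safetensors_status` is a hypothetical helper and the model id is just an example:

```python
from text_generation.utils.hub import (
    weight_hub_files,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
)


def safetensors_status(model_id: str) -> str:
    """Report whether safetensors weights exist on the hub and/or locally."""
    try:
        hub_files = weight_hub_files(model_id)  # raises EntryNotFoundError if none on the hub
    except EntryNotFoundError:
        return "no .safetensors on the hub (a .bin conversion would be needed)"
    try:
        weight_files(model_id)  # raises LocalEntryNotFoundError if not cached yet
    except LocalEntryNotFoundError:
        return f"{len(hub_files)} file(s) on the hub, not downloaded yet"
    return "weights are already available locally"


print(safetensors_status("bigscience/bloom-560m"))
```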

Utils tests (now covering tokens only):

@@ -1,14 +1,6 @@
-import pytest
-
-from huggingface_hub.utils import RevisionNotFoundError
-
-from text_generation.utils import (
-    weight_hub_files,
-    download_weights,
-    weight_files,
+from text_generation.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
-    LocalEntryNotFoundError,
     FinishReason,
 )
 
@@ -41,31 +33,3 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
-
-
-def test_weight_hub_files():
-    filenames = weight_hub_files("bigscience/bloom-560m")
-    assert filenames == ["model.safetensors"]
-
-
-def test_weight_hub_files_llm():
-    filenames = weight_hub_files("bigscience/bloom")
-    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
-
-
-def test_weight_hub_files_empty():
-    filenames = weight_hub_files("bigscience/bloom", extension=".errors")
-    assert filenames == []
-
-
-def test_download_weights():
-    files = download_weights("bigscience/bloom-560m")
-    local_files = weight_files("bigscience/bloom-560m")
-    assert files == local_files
-
-
-def test_weight_files_error():
-    with pytest.raises(RevisionNotFoundError):
-        weight_files("bigscience/bloom-560m", revision="error")
-    with pytest.raises(LocalEntryNotFoundError):
-        weight_files("bert-base-uncased")

Server CLI (`download-weights` command):

@@ -60,8 +60,55 @@ def download_weights(
     model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
+    logger_level: str = "INFO",
+    json_output: bool = False,
 ):
-    utils.download_weights(model_id, revision, extension)
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+
+    # Test if files were already downloaded
+    try:
+        utils.weight_files(model_id, revision, extension)
+        logger.info(
+            "Files are already present in the local cache. " "Skipping download."
+        )
+        return
+    # Local files not found
+    except utils.LocalEntryNotFoundError:
+        pass
+
+    # Download weights directly
+    try:
+        filenames = utils.weight_hub_files(model_id, revision, extension)
+        utils.download_weights(filenames, model_id, revision)
+    except utils.EntryNotFoundError as e:
+        if not extension == ".safetensors":
+            raise e
+
+        logger.warning(
+            f"No safetensors weights found for model {model_id} at revision {revision}. "
+            f"Converting PyTorch weights instead."
+        )
+
+        # Try to see if there are pytorch weights
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+        # Download pytorch weights
+        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+        local_st_files = [
+            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
+            for p in local_pt_files
+        ]
+        # Convert pytorch weights to safetensors
+        utils.convert_files(local_pt_files, local_st_files)
 
 
 if __name__ == "__main__":
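
With `--json-output`, loguru's `serialize=True` wraps every record in a JSON object whose `"text"` field is what the Rust launcher extracts and re-emits through `tracing` (see the download loop above). A small sketch of that contract, assuming only `loguru` is installed; the sample record is hand-written for illustration:

```python
import json
import sys

from loguru import logger

# Mirror the CLI configuration: bare message format, serialized to JSON on stdout.
logger.remove()
logger.add(sys.stdout, format="{message}", serialize=True)

logger.info("Downloaded model.safetensors")

# A consumer such as the launcher reads one JSON object per line and pulls out "text".
record = json.loads('{"text": "Downloaded model.safetensors\\n", "record": {}}')
print(record["text"].replace("\n", ""))
```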

Model dispatch (`get_model`):

@@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
+    if model_id.startswith("facebook/galactica"):
+        if sharded:
+            return GalacticaSharded(model_id, revision, quantize=quantize)
+        else:
+            return Galactica(model_id, revision, quantize=quantize)
+
+    if "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
+
     config = AutoConfig.from_pretrained(model_id, revision=revision)
 
     if config.model_type == "bloom":
@@ -48,27 +57,22 @@ def get_model(
             return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
             return BLOOM(model_id, revision, quantize=quantize)
-    elif config.model_type == "gpt_neox":
+
+    if config.model_type == "gpt_neox":
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
             return GPTNeox(model_id, revision, quantize=quantize)
-    elif config.model_type == "t5":
+
+    if config.model_type == "t5":
         if sharded:
             return T5Sharded(model_id, revision, quantize=quantize)
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
-    elif model_id.startswith("facebook/galactica"):
-        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
-        else:
-            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_id:
-        return SantaCoder(model_id, revision, quantize)
-    else:
-        if sharded:
-            raise ValueError("sharded is not supported for AutoModel")
-        try:
-            return CausalLM(model_id, revision, quantize=quantize)
-        except Exception:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+
+    if sharded:
+        raise ValueError("sharded is not supported for AutoModel")
+    try:
+        return CausalLM(model_id, revision, quantize=quantize)
+    except Exception:
+        return Seq2SeqLM(model_id, revision, quantize=quantize)
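
The dispatch now checks Galactica and SantaCoder by model id before loading the config, and falls back to `CausalLM` / `Seq2SeqLM` for everything else. A rough, simplified mirror of that ordering, not the real function; the `model_type` strings are placeholders except for the documented ones:

```python
def expected_backend(model_id: str, model_type: str, sharded: bool) -> str:
    """Simplified mirror of get_model's dispatch order (illustration only)."""
    if model_id.startswith("facebook/galactica"):
        return "GalacticaSharded" if sharded else "Galactica"
    if "santacoder" in model_id:
        return "SantaCoder"
    if model_type == "bloom":
        return "BLOOMSharded" if sharded else "BLOOM"
    if model_type == "gpt_neox":
        return "GPTNeoxSharded" if sharded else "GPTNeox"
    if model_type == "t5":
        return "T5Sharded" if sharded else "Seq2SeqLM"
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    return "CausalLM (falling back to Seq2SeqLM on failure)"


# Galactica is matched on the model id before the config type is even consulted.
assert expected_backend("facebook/galactica-120b", "anything", sharded=True) == "GalacticaSharded"
assert expected_backend("bigscience/bloom-560m", "bloom", sharded=False) == "BLOOM"
```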

BLOOM model (sharded):

@@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3
 
-        # Only download weights for small models
-        if self.master and model_id == "bigscience/bloom-560m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)

Galactica model (sharded):

@@ -26,7 +26,6 @@ from text_generation.utils import (
     StoppingCriteria,
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
         )
         tokenizer.pad_token_id = config.pad_token_id
 
-        # Only download weights for small models
-        if self.master and model_id == "facebook/galactica-125m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)

GPT-NeoX model (sharded):

@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
             model_id, revision=revision, tp_parallel=True
         )
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)

A model module (unused typing import trimmed):

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import Optional, List, Tuple
+from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation.models import CausalLM

T5 model (sharded):

@@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
         )
         tokenizer.bos_token_id = config.decoder_start_token_id
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForSeq2SeqLM.from_config(config)

Former `text_generation.utils` module (deleted):

@@ -1,283 +0,0 @@
The monolithic `text_generation.utils` module is removed in full (283 lines). Its contents are redistributed into the new `server/text_generation/utils/` package shown below: the `Sampling`, `Greedy`, `NextTokenChooser`, `StopSequenceCriteria` and `StoppingCriteria` classes move unchanged to `tokens.py`; `initialize_torch_distributed` moves unchanged to `dist.py`; and the hub helpers (`weight_hub_files`, `try_to_load_from_cache`, `weight_files`, `download_weights`, together with the module-level `WEIGHTS_CACHE_OVERRIDE` handling) move to `hub.py` in reworked form, while the new safetensors conversion helpers live in `convert.py`.

server/text_generation/utils/__init__.py (new file, 36 lines):

@@ -0,0 +1,36 @@
+from text_generation.utils.convert import convert_file, convert_files
+from text_generation.utils.dist import initialize_torch_distributed
+from text_generation.utils.hub import (
+    weight_files,
+    weight_hub_files,
+    download_weights,
+    EntryNotFoundError,
+    LocalEntryNotFoundError,
+    RevisionNotFoundError,
+)
+from text_generation.utils.tokens import (
+    Greedy,
+    NextTokenChooser,
+    Sampling,
+    StoppingCriteria,
+    StopSequenceCriteria,
+    FinishReason,
+)
+
+__all__ = [
+    "convert_file",
+    "convert_files",
+    "initialize_torch_distributed",
+    "weight_files",
+    "weight_hub_files",
+    "download_weights",
+    "EntryNotFoundError",
+    "LocalEntryNotFoundError",
+    "RevisionNotFoundError",
+    "Greedy",
+    "NextTokenChooser",
+    "Sampling",
+    "StoppingCriteria",
+    "StopSequenceCriteria",
+    "FinishReason",
+]

server/text_generation/utils/convert.py (new file, 96 lines):

@@ -0,0 +1,96 @@
+import concurrent
+import time
+import torch
+
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+from datetime import timedelta
+from loguru import logger
+from pathlib import Path
+from safetensors.torch import load_file, save_file
+from typing import Dict, List
+
+
+def check_file_size(source_file: Path, target_file: Path):
+    """
+    Check that two files are close in size
+    """
+    source_file_size = source_file.stat().st_size
+    target_file_size = target_file.stat().st_size
+
+    if (source_file_size - target_file_size) / source_file_size > 0.01:
+        raise RuntimeError(
+            f"""The file size difference is more than 1%:
+         - {source_file}: {source_file_size}
+         - {target_file}: {target_file_size}
+         """
+        )
+
+
+def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
+    """
+    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
+    remove them
+    """
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+
+    # Iterate over all found memory addresses
+    for ptr, names in ptrs.items():
+        if len(names) > 1:
+            # Multiple tensors point to the same memory
+            # Only keep the first tensor
+            for name in names[1:]:
+                tensors.pop(name)
+
+
+def convert_file(pt_file: Path, st_file: Path):
+    """
+    Convert a pytorch file to a safetensors file
+    """
+    pt_state = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in pt_state:
+        pt_state = pt_state["state_dict"]
+
+    remove_shared_pointers(pt_state)
+
+    # Tensors need to be contiguous
+    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
+
+    st_file.parent.mkdir(parents=True, exist_ok=True)
+    save_file(pt_state, str(st_file), metadata={"format": "pt"})
+
+    # Check that both files are close in size
+    check_file_size(pt_file, st_file)
+
+    # Load safetensors state
+    st_state = load_file(str(st_file))
+    for k in st_state:
+        pt_tensor = pt_state[k]
+        st_tensor = st_state[k]
+        if not torch.equal(pt_tensor, st_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_files(pt_files: List[Path], st_files: List[Path]):
+    assert len(pt_files) == len(st_files)
+
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
+        for pt_file, st_file in zip(pt_files, st_files)
+    ]
+
+    # We do this instead of using tqdm because we want to parse the logs with the launcher
+    logger.info("Converting weights...")
+    start_time = time.time()
+    for i, future in enumerate(concurrent.futures.as_completed(futures)):
+        elapsed = timedelta(seconds=int(time.time() - start_time))
+        remaining = len(futures) - (i + 1)
+        if remaining != 0:
+            eta = (elapsed / (i + 1)) * remaining
+        else:
+            eta = 0
+
+        logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
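
The conversion above boils down to loading a `.bin` checkpoint with `torch.load`, re-serializing it with safetensors, and verifying the round trip tensor by tensor. A self-contained sketch of that core step on a toy state dict; the file names are arbitrary:

```python
import torch
from safetensors.torch import load_file, save_file

# Toy "checkpoint": contiguous tensors keyed by parameter name.
state = {
    "embed.weight": torch.randn(4, 8),
    "lm_head.weight": torch.randn(8, 4),
}
torch.save(state, "toy.bin")

# Convert: load the pickle-based file, write the safetensors file.
pt_state = torch.load("toy.bin", map_location="cpu")
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
save_file(pt_state, "toy.safetensors", metadata={"format": "pt"})

# Verify the round trip, mirroring convert_file's final check.
st_state = load_file("toy.safetensors")
assert all(torch.equal(pt_state[k], st_state[k]) for k in st_state)
```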

server/text_generation/utils/dist.py (new file, 35 lines):

@@ -0,0 +1,35 @@
+import os
+import torch
+
+from datetime import timedelta
+
+
+def initialize_torch_distributed():
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+    if torch.cuda.is_available():
+        from torch.distributed import ProcessGroupNCCL
+
+        # Set the device id.
+        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
+        device = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device)
+        backend = "nccl"
+        options = ProcessGroupNCCL.Options()
+        options.is_high_priority_stream = True
+        options._timeout = timedelta(seconds=60)
+    else:
+        backend = "gloo"
+        options = None
+
+    # Call the init process.
+    torch.distributed.init_process_group(
+        backend=backend,
+        world_size=world_size,
+        rank=rank,
+        timeout=timedelta(seconds=60),
+        pg_options=options,
+    )
+
+    return torch.distributed.group.WORLD, rank, world_size
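
Because the rank, world size and rendezvous address all come from environment variables, the helper can be exercised in a single CPU process by faking a world of one. A sketch under that assumption; the port is arbitrary and matches the launcher's default `master_port`:

```python
import os

from text_generation.utils.dist import initialize_torch_distributed

# Pretend to be rank 0 of a single-process "cluster"; gloo is chosen when CUDA is absent.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)  # 0 1
```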

server/text_generation/utils/hub.py (new file, 169 lines):

@@ -0,0 +1,169 @@
+import time
+import concurrent
+import os
+
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+from loguru import logger
+from pathlib import Path
+from typing import Optional, List
+
+from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import (
+    LocalEntryNotFoundError,
+    EntryNotFoundError,
+    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
+)
+
+WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
+
+
+def weight_hub_files(
+    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[str]:
+    """Get the weights filenames on the hub"""
+    api = HfApi()
+    info = api.model_info(model_id, revision=revision)
+    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
+
+    if not filenames:
+        raise EntryNotFoundError(
+            f"No {extension} weights found for model {model_id} and revision {revision}.",
+            None,
+        )
+
+    return filenames
+
+
+def try_to_load_from_cache(
+    model_id: str, revision: Optional[str], filename: str
+) -> Optional[Path]:
+    """Try to load a file from the Hugging Face cache"""
+    if revision is None:
+        revision = "main"
+
+    object_id = model_id.replace("/", "--")
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+
+    if not repo_cache.is_dir():
+        # No cache for this model
+        return None
+
+    refs_dir = repo_cache / "refs"
+    snapshots_dir = repo_cache / "snapshots"
+    no_exist_dir = repo_cache / ".no_exist"
+
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    if refs_dir.is_dir():
+        revision_file = refs_dir / revision
+        if revision_file.exists():
+            with revision_file.open() as f:
+                revision = f.read()
+
+    # Check if file is cached as "no_exist"
+    if (no_exist_dir / revision / filename).is_file():
+        return None
+
+    # Check if revision folder exists
+    if not snapshots_dir.exists():
+        return None
+    cached_shas = os.listdir(snapshots_dir)
+    if revision not in cached_shas:
+        # No cache for this revision and we won't try to return a random revision
+        return None
+
+    # Check if file exists in cache
+    cached_file = snapshots_dir / revision / filename
+    return cached_file if cached_file.is_file() else None
+
+
+def weight_files(
+    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[Path]:
+    """Get the local files"""
+    try:
+        filenames = weight_hub_files(model_id, revision, extension)
+    except EntryNotFoundError as e:
+        if extension != ".safetensors":
+            raise e
+        # Try to see if there are pytorch weights
+        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
+        # Change pytorch extension to safetensors extension
+        # It is possible that we have safetensors weights locally even though they are not on the
+        # hub if we converted weights locally without pushing them
+        filenames = [
+            f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
+        ]
+
+    if WEIGHTS_CACHE_OVERRIDE is not None:
+        files = []
+        for filename in filenames:
+            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
+            if not p.exists():
+                raise LocalEntryNotFoundError(
+                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
+                )
+            files.append(p)
+        return files
+
+    files = []
+    for filename in filenames:
+        cache_file = try_to_load_from_cache(
+            model_id, revision=revision, filename=filename
+        )
+        if cache_file is None:
+            raise LocalEntryNotFoundError(
+                f"File {filename} of model {model_id} not found in "
+                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+                f"Please run `text-generation-server download-weights {model_id}` first."
+            )
+        files.append(cache_file)
+
+    return files
+
+
+def download_weights(
+    filenames: List[str], model_id: str, revision: Optional[str] = None
+) -> List[Path]:
+    """Download the safetensors files from the hub"""
+
+    def download_file(filename):
+        local_file = try_to_load_from_cache(model_id, revision, filename)
+        if local_file is not None:
+            logger.info(f"File {filename} already present in cache.")
+            return local_file
+
+        start_time = time.time()
+        local_file = hf_hub_download(
+            filename=filename,
+            repo_id=model_id,
+            revision=revision,
+            local_files_only=False,
+        )
+        logger.info(
+            f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+        )
+        return local_file
+
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(download_file, filename=filename) for filename in filenames
+    ]
+
+    # We do this instead of using tqdm because we want to parse the logs with the launcher
+    logger.info("Downloading weights...")
+    start_time = time.time()
+    files = []
+    for i, future in enumerate(concurrent.futures.as_completed(futures)):
+        elapsed = timedelta(seconds=int(time.time() - start_time))
+        remaining = len(futures) - (i + 1)
+        if remaining != 0:
+            eta = (elapsed / (i + 1)) * remaining
+        else:
+            eta = 0
+
+        logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
+        files.append(Path(future.result()))
+
+    return [Path(p) for p in files]
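
`try_to_load_from_cache` above walks the huggingface_hub cache layout: a repo lives under `models--<org>--<name>`, `refs/<revision>` maps a branch name to a commit sha, and the files sit in `snapshots/<sha>/`. A tiny sketch that builds such a layout in a temporary directory to show what gets resolved; the paths and sha are made up and the lookup logic is mirrored by hand rather than calling the real function:

```python
import tempfile
from pathlib import Path

# Fake the cache layout that try_to_load_from_cache walks.
cache = Path(tempfile.mkdtemp())
repo = cache / "models--bigscience--bloom-560m"
sha = "0123456789abcdef"  # stand-in for a real commit sha
(repo / "refs").mkdir(parents=True)
(repo / "refs" / "main").write_text(sha)
(repo / "snapshots" / sha).mkdir(parents=True)
(repo / "snapshots" / sha / "model.safetensors").write_bytes(b"")

# Resolution: "main" -> sha via refs/, then snapshots/<sha>/<filename>.
resolved_sha = (repo / "refs" / "main").read_text()
cached_file = repo / "snapshots" / resolved_sha / "model.safetensors"
assert cached_file.is_file()
```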

server/text_generation/utils/tokens.py (new file, 142 lines):

@@ -0,0 +1,142 @@
+import re
+import torch
+
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    RepetitionPenaltyLogitsProcessor,
+    PreTrainedTokenizerBase,
+)
+from typing import List, Tuple, Optional
+
+from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason
+
+
+class Sampling:
+    def __init__(self, seed: int, device: str = "cpu"):
+        self.generator = torch.Generator(device)
+        self.generator.manual_seed(seed)
+        self.seed = seed
+
+    def __call__(self, logits):
+        probs = torch.nn.functional.softmax(logits)
+        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
+        return next_tokens
+
+
+class Greedy:
+    def __call__(self, logits):
+        return logits.argmax()
+
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        temperature=1.0,
+        repetition_penalty=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=0,
+        device="cpu",
+    ):
+        warpers = LogitsProcessorList()
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        sampling = do_sample
+        if temperature is not None and temperature != 1.0:
+            temperature = float(temperature)
+            warpers.append(TemperatureLogitsWarper(temperature))
+            sampling = True
+        if top_k is not None and top_k != 0:
+            warpers.append(TopKLogitsWarper(top_k=top_k))
+            sampling = True
+        if top_p is not None and top_p < 1.0:
+            warpers.append(TopPLogitsWarper(top_p=top_p))
+            sampling = True
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+
+        self.warpers = warpers
+        self.choice = Sampling(seed, device) if sampling else Greedy()
+
+    def __call__(self, input_ids, scores):
+        # Warp logits
+        scores = self.warpers(input_ids, scores)
+
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, -1)
+
+        # Choose tokens
+        next_id = self.choice(scores[-1])
+
+        return next_id.view(1, 1), logprobs
+
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
+        return NextTokenChooser(
+            temperature=pb.temperature,
+            repetition_penalty=pb.repetition_penalty,
+            top_k=pb.top_k,
+            top_p=pb.top_p,
+            do_sample=pb.do_sample,
+            seed=pb.seed,
+            device=device,
+        )
+
+
+class StopSequenceCriteria:
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
+
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
+            return True
+        return False
+
+
+class StoppingCriteria:
+    def __init__(
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
+    ):
+        self.eos_token_id = eos_token_id
+        self.stop_sequence_criterias = stop_sequence_criterias
+        self.max_new_tokens = max_new_tokens
+        self.current_tokens = 0
+        self.current_output = ""
+
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
+        self.current_tokens += 1
+        if self.current_tokens >= self.max_new_tokens:
+            return True, FinishReason.FINISH_REASON_LENGTH
+
+        if last_token == self.eos_token_id:
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN
+
+        self.current_output += last_output
+        for stop_sequence_criteria in self.stop_sequence_criterias:
+            if stop_sequence_criteria(self.current_output):
+                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
+
+        return False, None
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.StoppingCriteriaParameters,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> "StoppingCriteria":
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
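
A quick illustration of how the two pieces cooperate during decoding: the chooser warps logits and picks a token, the stopping criteria decide when generation ends. The sketch below only mirrors the greedy path and the length criterion on dummy logits, so it does not need the generated protobuf module that the real classes import:

```python
import torch

# Greedy choice over dummy logits, mirroring tokens.Greedy / NextTokenChooser.__call__.
scores = torch.tensor([[0.1, 2.5, -1.0, 0.3]])  # one decoding step, vocabulary of 4
logprobs = torch.log_softmax(scores, -1)
next_id = scores[-1].argmax()                   # Greedy picks token 1
print(next_id.item(), logprobs[0, next_id].item())

# Length-based stopping, mirroring StoppingCriteria with max_new_tokens=3 and no EOS/stop sequences.
max_new_tokens, current_tokens = 3, 0
for step in range(5):
    current_tokens += 1
    if current_tokens >= max_new_tokens:
        print(f"stop after {current_tokens} tokens (FINISH_REASON_LENGTH)")
        break
```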