mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: Support loading models from GCS
This commit is contained in:
parent
c9f4c1af31
commit
064c110123
@ -218,6 +218,9 @@ COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install GCS library
|
||||
RUN pip install --upgrade google-cloud-storage
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
|
@ -782,7 +782,7 @@ enum LauncherError {
|
||||
WebserverCannotStart,
|
||||
}
|
||||
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
fn download_convert_model(args: &mut Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
// Enter download tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
|
||||
@ -907,6 +907,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
if args.model_id.starts_with("gs://") {
|
||||
args.model_id = "/tmp/gcs_model/".to_string();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1192,7 +1195,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
||||
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
// Pattern match configuration
|
||||
let args: Args = Args::parse();
|
||||
let mut args: Args = Args::parse();
|
||||
|
||||
// Filter events with LOG_LEVEL
|
||||
let env_filter =
|
||||
@ -1285,7 +1288,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Download and convert model weights
|
||||
download_convert_model(&args, running.clone())?;
|
||||
download_convert_model(&mut args, running.clone())?;
|
||||
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
// Launcher was asked to stop
|
||||
|
@ -123,6 +123,31 @@ def download_weights(
|
||||
# Import here after the logger is added to log potential import exceptions
|
||||
from text_generation_server import utils
|
||||
|
||||
GCS_PREFIX = "gs://"
|
||||
if model_id.startswith(GCS_PREFIX):
|
||||
local_dir = "/tmp/gcs_model"
|
||||
from google.cloud import storage
|
||||
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
    """Mirror a GCS "directory" (object prefix) into a local directory.

    Args:
        gcs_dir: source location of the form ``gs://bucket_name/dir``.
        local_dir: destination directory on the local filesystem.

    Raises:
        ValueError: if no objects exist under the given prefix.

    NOTE(review): if ``local_dir`` already exists the whole download is
    skipped, so a previously interrupted (partial) download is never
    resumed or repaired — confirm this is acceptable.
    """
    if os.path.isdir(local_dir):
        return
    # gs://bucket_name/dir -> bucket name is the third '/'-separated field.
    bucket_name = gcs_dir.split('/')[2]
    prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
    client = storage.Client()
    # Materialize the listing before the emptiness check: list_blobs()
    # returns a lazy page iterator, which is always truthy, so the
    # original `if not blobs:` guard could never fire. A list makes the
    # check meaningful (listings here are model-sized, not unbounded).
    blobs = list(client.list_blobs(bucket_name, prefix=prefix))
    if not blobs:
        raise ValueError(f"No blobs found in {gcs_dir}")
    for blob in blobs:
        # Skip "directory placeholder" objects (names ending in '/');
        # endswith() is also safe if a blob name were ever empty.
        if blob.name.endswith('/'):
            continue
        file_path = blob.name[len(prefix) :].strip('/')
        local_file_path = os.path.join(local_dir, file_path)
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
        blob.download_to_filename(local_file_path)
    logger.info("Download finished.")
|
||||
download_gcs_dir_to_local(model_id, local_dir)
|
||||
model_id = local_dir
|
||||
# Test if files were already downloaded
|
||||
try:
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
|
@ -192,6 +192,8 @@ def serve(
|
||||
local_url = unix_socket_template.format(uds_path, 0)
|
||||
server_urls = [local_url]
|
||||
|
||||
if model_id.startswith("gs://"):
|
||||
model_id = "/tmp/gcs_model"
|
||||
try:
|
||||
model = get_model(
|
||||
model_id,
|
||||
|
Loading…
Reference in New Issue
Block a user