diff --git a/Dockerfile b/Dockerfile index e79372a3..b6560d39 100644 --- a/Dockerfile +++ b/Dockerfile @@ -218,6 +218,9 @@ COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython- # Install flash-attention dependencies RUN pip install einops --no-cache-dir +# Install GCS library +RUN pip install --upgrade google-cloud-storage + # Install server COPY proto proto COPY server server diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d52e2669..64e8bfde 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -782,7 +782,7 @@ enum LauncherError { WebserverCannotStart, } -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { +fn download_convert_model(args: &mut Args, running: Arc) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); @@ -907,6 +907,9 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L } sleep(Duration::from_millis(100)); } + if args.model_id.starts_with("gs://") { + args.model_id = "/tmp/gcs_model/".to_string(); + } Ok(()) } @@ -1192,7 +1195,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R fn main() -> Result<(), LauncherError> { // Pattern match configuration - let args: Args = Args::parse(); + let mut args: Args = Args::parse(); // Filter events with LOG_LEVEL let env_filter = @@ -1285,7 +1288,7 @@ fn main() -> Result<(), LauncherError> { .expect("Error setting Ctrl-C handler"); // Download and convert model weights - download_convert_model(&args, running.clone())?; + download_convert_model(&mut args, running.clone())?; if !running.load(Ordering::SeqCst) { // Launcher was asked to stop diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b74fbe36..0f9e1792 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -123,6 +123,31 @@ def download_weights( # Import here after the logger is added to log potential import exceptions from text_generation_server import utils + GCS_PREFIX = "gs://" + if model_id.startswith(GCS_PREFIX): + local_dir = "/tmp/gcs_model" + from google.cloud import storage + def download_gcs_dir_to_local(gcs_dir: str, local_dir: str): + if os.path.isdir(local_dir): + return + # gs://bucket_name/dir + bucket_name = gcs_dir.split('/')[2] + prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/' + client = storage.Client() + blobs = client.list_blobs(bucket_name, prefix=prefix) + if not blobs: + raise ValueError(f"No blobs found in {gcs_dir}") + for blob in blobs: + if blob.name[-1] == '/': + continue + file_path = blob.name[len(prefix) :].strip('/') + local_file_path = os.path.join(local_dir, file_path) + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.") + blob.download_to_filename(local_file_path) + logger.info("Download finished.") + download_gcs_dir_to_local(model_id, local_dir) + model_id = local_dir # Test if files were already download try: utils.weight_files(model_id, revision, extension) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index d5adbd32..c5cd40d8 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -192,6 +192,8 @@ def serve( local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] + if model_id.startswith("gs://"): + model_id = "/tmp/gcs_model" try: model = get_model( model_id,