From 5c189c92ac8d1e81dfef23f160bd23ba4c367c5d Mon Sep 17 00:00:00 2001
From: dstnluong-google <129889805+dstnluong-google@users.noreply.github.com>
Date: Tue, 20 Feb 2024 23:25:57 +0000
Subject: [PATCH] Move GCS functions into a file in the utils folder.

---
 server/text_generation_server/cli.py          | 29 ++++-------------------------
 .../text_generation_server/utils/__init__.py  |  4 ++++
 server/text_generation_server/utils/gcs.py    | 27 +++++++++++++++++++++++++++
 3 files changed, 35 insertions(+), 25 deletions(-)
 create mode 100644 server/text_generation_server/utils/gcs.py

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 0f9e1792..c62f21d6 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -123,31 +123,10 @@ def download_weights(
     # Import here after the logger is added to log potential import exceptions
     from text_generation_server import utils
 
-    GCS_PREFIX = "gs://"
-    if model_id.startswith(GCS_PREFIX):
-        local_dir = "/tmp/gcs_model"
-        from google.cloud import storage
-        def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
-            if os.path.isdir(local_dir):
-                return
-            # gs://bucket_name/dir
-            bucket_name = gcs_dir.split('/')[2]
-            prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
-            client = storage.Client()
-            blobs = client.list_blobs(bucket_name, prefix=prefix)
-            if not blobs:
-                raise ValueError(f"No blobs found in {gcs_dir}")
-            for blob in blobs:
-                if blob.name[-1] == '/':
-                    continue
-                file_path = blob.name[len(prefix) :].strip('/')
-                local_file_path = os.path.join(local_dir, file_path)
-                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-                logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
-                blob.download_to_filename(local_file_path)
-            logger.info("Download finished.")
-        download_gcs_dir_to_local(model_id, local_dir)
-        model_id = local_dir
+    if model_id.startswith(utils.GCS_PREFIX):
+        utils.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
+        model_id = utils.GCS_LOCAL_DIR
+
     # Test if files were already download
     try:
         utils.weight_files(model_id, revision, extension)
diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py
index 08ba808d..a9f67c02 100644
--- a/server/text_generation_server/utils/__init__.py
+++ b/server/text_generation_server/utils/__init__.py
@@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
     Sampling,
     Greedy,
 )
+from text_generation_server.utils.gcs import GCS_PREFIX, GCS_LOCAL_DIR, download_gcs_dir_to_local
 
 __all__ = [
     "convert_file",
@@ -39,4 +40,7 @@ __all__ = [
     "StopSequenceCriteria",
     "FinishReason",
     "Weights",
+    "GCS_PREFIX",
+    "GCS_LOCAL_DIR",
+    "download_gcs_dir_to_local",
 ]
diff --git a/server/text_generation_server/utils/gcs.py b/server/text_generation_server/utils/gcs.py
new file mode 100644
index 00000000..46099971
--- /dev/null
+++ b/server/text_generation_server/utils/gcs.py
@@ -0,0 +1,27 @@
+import os
+
+from google.cloud import storage
+from loguru import logger
+
+GCS_PREFIX = "gs://"
+GCS_LOCAL_DIR = "/tmp/gcs_model"
+
+def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
+    if os.path.isdir(local_dir):
+        return
+    # gs://bucket_name/dir
+    bucket_name = gcs_dir.split('/')[2]
+    prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
+    client = storage.Client()
+    blobs = client.list_blobs(bucket_name, prefix=prefix)
+    if not blobs:
+        raise ValueError(f"No blobs found in {gcs_dir}")
+    for blob in blobs:
+        if blob.name[-1] == '/':
+            continue
+        file_path = blob.name[len(prefix) :].strip('/')
+        local_file_path = os.path.join(local_dir, file_path)
+        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+        logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
+        blob.download_to_filename(local_file_path)
+    logger.info("Download finished.")
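
Note: the snippet below is not part of the patch; it is a minimal usage sketch of the
relocated helper, mirroring the cli.py hunk above. The gs:// path is an illustrative
placeholder.

    from text_generation_server import utils

    model_id = "gs://example-bucket/my-model"  # placeholder path, not from the patch
    if model_id.startswith(utils.GCS_PREFIX):
        # Download the GCS directory once into /tmp/gcs_model, then point the
        # rest of the weight-loading code at the local copy.
        utils.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
        model_id = utils.GCS_LOCAL_DIR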