Move gcs functions file in util folder.

This commit is contained in:
dstnluong-google 2024-02-20 23:25:57 +00:00
parent 064c110123
commit 5c189c92ac
3 changed files with 32 additions and 25 deletions

View File

@ -123,31 +123,10 @@ def download_weights(
# Import here after the logger is added to log potential import exceptions # Import here after the logger is added to log potential import exceptions
from text_generation_server import utils from text_generation_server import utils
GCS_PREFIX = "gs://" if model_id.startswith(utils.GCS_PREFIX):
if model_id.startswith(GCS_PREFIX): utisls.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
local_dir = "/tmp/gcs_model" model_id = utils.GCS_LOCAL_DIR
from google.cloud import storage
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
if os.path.isdir(local_dir):
return
# gs://bucket_name/dir
bucket_name = gcs_dir.split('/')[2]
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
client = storage.Client()
blobs = client.list_blobs(bucket_name, prefix=prefix)
if not blobs:
raise ValueError(f"No blobs found in {gcs_dir}")
for blob in blobs:
if blob.name[-1] == '/':
continue
file_path = blob.name[len(prefix) :].strip('/')
local_file_path = os.path.join(local_dir, file_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
blob.download_to_filename(local_file_path)
logger.info("Download finished.")
download_gcs_dir_to_local(model_id, local_dir)
model_id = local_dir
# Test if files were already download # Test if files were already download
try: try:
utils.weight_files(model_id, revision, extension) utils.weight_files(model_id, revision, extension)

View File

@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
Sampling, Sampling,
Greedy, Greedy,
) )
from text_generation_server.utils.gcs import GCS_PREFIX, GCS_LOCAL_DIR, download_gcs_dir_to_local
__all__ = [ __all__ = [
"convert_file", "convert_file",
@ -39,4 +40,7 @@ __all__ = [
"StopSequenceCriteria", "StopSequenceCriteria",
"FinishReason", "FinishReason",
"Weights", "Weights",
"GCS_PREFIX",
"GCS_LOCAL_DIR",
"download_gcs_dir_to_local",
] ]

View File

@ -0,0 +1,24 @@
from google.cloud import storage
GCS_PREFIX = "gs://"
GCS_LOCAL_DIR = "/tmp/gcs_model"
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
if os.path.isdir(local_dir):
return
# gs://bucket_name/dir
bucket_name = gcs_dir.split('/')[2]
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
client = storage.Client()
blobs = client.list_blobs(bucket_name, prefix=prefix)
if not blobs:
raise ValueError(f"No blobs found in {gcs_dir}")
for blob in blobs:
if blob.name[-1] == '/':
continue
file_path = blob.name[len(prefix) :].strip('/')
local_file_path = os.path.join(local_dir, file_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
blob.download_to_filename(local_file_path)
logger.info("Download finished.")