Move the GCS helper functions into a file in the util folder.

This commit is contained in:
dstnluong-google 2024-02-20 23:25:57 +00:00
parent 064c110123
commit 5c189c92ac
3 changed files with 32 additions and 25 deletions

View File

@ -123,31 +123,10 @@ def download_weights(
# Import here after the logger is added to log potential import exceptions
from text_generation_server import utils
# Legacy inline GCS handling (this commit moves it into
# text_generation_server.utils.gcs): if `model_id` is a gs:// URI, mirror that
# bucket directory to a local path, then point `model_id` at the local copy so
# the rest of the weight-download flow sees an ordinary local directory.
GCS_PREFIX = "gs://"
if model_id.startswith(GCS_PREFIX):
local_dir = "/tmp/gcs_model"
from google.cloud import storage
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
# An already-existing local_dir is treated as a completed previous download.
if os.path.isdir(local_dir):
return
# gs://bucket_name/dir
bucket_name = gcs_dir.split('/')[2]
# Object prefix inside the bucket, normalized to end with exactly one '/'.
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
client = storage.Client()
blobs = client.list_blobs(bucket_name, prefix=prefix)
# NOTE(review): list_blobs returns an iterator, which is always truthy, so
# this emptiness guard can never fire — the ValueError is dead code here.
if not blobs:
raise ValueError(f"No blobs found in {gcs_dir}")
for blob in blobs:
# Names ending in '/' are directory placeholder objects, not files.
if blob.name[-1] == '/':
continue
file_path = blob.name[len(prefix) :].strip('/')
local_file_path = os.path.join(local_dir, file_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
blob.download_to_filename(local_file_path)
logger.info("Download finished.")
download_gcs_dir_to_local(model_id, local_dir)
model_id = local_dir
# If `model_id` is a gs:// URI, mirror the bucket directory to a local path
# and retarget `model_id` at the local copy before weight resolution proceeds.
if model_id.startswith(utils.GCS_PREFIX):
    # Fix: the original read `utisls.download_gcs_dir_to_local(...)` — a typo
    # that raised NameError on every GCS model path.
    utils.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
    model_id = utils.GCS_LOCAL_DIR
# Test if files were already downloaded
try:
utils.weight_files(model_id, revision, extension)

View File

@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
Sampling,
Greedy,
)
from text_generation_server.utils.gcs import GCS_PREFIX, GCS_LOCAL_DIR, download_gcs_dir_to_local
__all__ = [
"convert_file",
@ -39,4 +40,7 @@ __all__ = [
"StopSequenceCriteria",
"FinishReason",
"Weights",
"GCS_PREFIX",
"GCS_LOCAL_DIR",
"download_gcs_dir_to_local",
]

View File

@ -0,0 +1,24 @@
import logging
import os

from google.cloud import storage
GCS_PREFIX = "gs://"
GCS_LOCAL_DIR = "/tmp/gcs_model"
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str) -> None:
    """Mirror a GCS directory (``gs://bucket/dir``) into ``local_dir``.

    Args:
        gcs_dir: Source directory URI; must start with ``gs://``.
        local_dir: Destination directory. If it already exists, the download
            is skipped entirely (it is assumed to be a completed previous run).

    Raises:
        ValueError: if no objects exist under ``gcs_dir``.
    """
    # Fix: the module as committed never defined `logger` (nor imported `os`),
    # so every call raised NameError. Stdlib logging keeps the same messages;
    # NOTE(review): the surrounding project may prefer its own logger — confirm.
    logger = logging.getLogger(__name__)
    # An already-existing local_dir is treated as a completed previous download.
    if os.path.isdir(local_dir):
        return
    # gs://bucket_name/dir -> bucket "bucket_name", object prefix "dir/"
    bucket_name = gcs_dir.split('/')[2]
    prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
    client = storage.Client()
    blobs = client.list_blobs(bucket_name, prefix=prefix)
    # Fix: `list_blobs` returns an iterator, which is always truthy, so the
    # original `if not blobs:` emptiness check could never fire. Count the
    # entries while iterating and raise afterwards if none were found.
    num_blobs = 0
    for blob in blobs:
        num_blobs += 1
        # Names ending in '/' are directory placeholder objects, not files.
        if blob.name[-1] == '/':
            continue
        file_path = blob.name[len(prefix) :].strip('/')
        local_file_path = os.path.join(local_dir, file_path)
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
        blob.download_to_filename(local_file_path)
    if num_blobs == 0:
        raise ValueError(f"No blobs found in {gcs_dir}")
    logger.info("Download finished.")