mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Move gcs functions file in util folder.
This commit is contained in:
parent
064c110123
commit
5c189c92ac
@ -123,31 +123,10 @@ def download_weights(
|
||||
# Import here after the logger is added to log potential import exceptions
|
||||
from text_generation_server import utils
|
||||
|
||||
GCS_PREFIX = "gs://"
|
||||
if model_id.startswith(GCS_PREFIX):
|
||||
local_dir = "/tmp/gcs_model"
|
||||
from google.cloud import storage
|
||||
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
|
||||
if os.path.isdir(local_dir):
|
||||
return
|
||||
# gs://bucket_name/dir
|
||||
bucket_name = gcs_dir.split('/')[2]
|
||||
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
|
||||
client = storage.Client()
|
||||
blobs = client.list_blobs(bucket_name, prefix=prefix)
|
||||
if not blobs:
|
||||
raise ValueError(f"No blobs found in {gcs_dir}")
|
||||
for blob in blobs:
|
||||
if blob.name[-1] == '/':
|
||||
continue
|
||||
file_path = blob.name[len(prefix) :].strip('/')
|
||||
local_file_path = os.path.join(local_dir, file_path)
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
|
||||
blob.download_to_filename(local_file_path)
|
||||
logger.info("Download finished.")
|
||||
download_gcs_dir_to_local(model_id, local_dir)
|
||||
model_id = local_dir
|
||||
if model_id.startswith(utils.GCS_PREFIX):
|
||||
utils.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
|
||||
model_id = utils.GCS_LOCAL_DIR
|
||||
|
||||
# Test if files were already downloaded
|
||||
try:
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
|
||||
Sampling,
|
||||
Greedy,
|
||||
)
|
||||
from text_generation_server.utils.gcs import GCS_PREFIX, GCS_LOCAL_DIR, download_gcs_dir_to_local
|
||||
|
||||
__all__ = [
|
||||
"convert_file",
|
||||
@ -39,4 +40,7 @@ __all__ = [
|
||||
"StopSequenceCriteria",
|
||||
"FinishReason",
|
||||
"Weights",
|
||||
"GCS_PREFIX",
|
||||
"GCS_LOCAL_DIR",
|
||||
"download_gcs_dir_to_local",
|
||||
]
|
||||
|
24
server/text_generation_server/utils/gcs.py
Normal file
24
server/text_generation_server/utils/gcs.py
Normal file
@ -0,0 +1,24 @@
|
||||
from google.cloud import storage
|
||||
|
||||
GCS_PREFIX = "gs://"
|
||||
GCS_LOCAL_DIR = "/tmp/gcs_model"
|
||||
|
||||
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
|
||||
if os.path.isdir(local_dir):
|
||||
return
|
||||
# gs://bucket_name/dir
|
||||
bucket_name = gcs_dir.split('/')[2]
|
||||
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
|
||||
client = storage.Client()
|
||||
blobs = client.list_blobs(bucket_name, prefix=prefix)
|
||||
if not blobs:
|
||||
raise ValueError(f"No blobs found in {gcs_dir}")
|
||||
for blob in blobs:
|
||||
if blob.name[-1] == '/':
|
||||
continue
|
||||
file_path = blob.name[len(prefix) :].strip('/')
|
||||
local_file_path = os.path.join(local_dir, file_path)
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
|
||||
blob.download_to_filename(local_file_path)
|
||||
logger.info("Download finished.")
|
Loading…
Reference in New Issue
Block a user