mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Move gcs functions file in util folder.
This commit is contained in:
parent
064c110123
commit
5c189c92ac
@ -123,31 +123,10 @@ def download_weights(
|
|||||||
# Import here after the logger is added to log potential import exceptions
|
# Import here after the logger is added to log potential import exceptions
|
||||||
from text_generation_server import utils
|
from text_generation_server import utils
|
||||||
|
|
||||||
GCS_PREFIX = "gs://"
|
if model_id.startswith(utils.GCS_PREFIX):
|
||||||
if model_id.startswith(GCS_PREFIX):
|
utisls.download_gcs_dir_to_local(model_id, utils.GCS_LOCAL_DIR)
|
||||||
local_dir = "/tmp/gcs_model"
|
model_id = utils.GCS_LOCAL_DIR
|
||||||
from google.cloud import storage
|
|
||||||
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
|
|
||||||
if os.path.isdir(local_dir):
|
|
||||||
return
|
|
||||||
# gs://bucket_name/dir
|
|
||||||
bucket_name = gcs_dir.split('/')[2]
|
|
||||||
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
|
|
||||||
client = storage.Client()
|
|
||||||
blobs = client.list_blobs(bucket_name, prefix=prefix)
|
|
||||||
if not blobs:
|
|
||||||
raise ValueError(f"No blobs found in {gcs_dir}")
|
|
||||||
for blob in blobs:
|
|
||||||
if blob.name[-1] == '/':
|
|
||||||
continue
|
|
||||||
file_path = blob.name[len(prefix) :].strip('/')
|
|
||||||
local_file_path = os.path.join(local_dir, file_path)
|
|
||||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
|
||||||
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
|
|
||||||
blob.download_to_filename(local_file_path)
|
|
||||||
logger.info("Download finished.")
|
|
||||||
download_gcs_dir_to_local(model_id, local_dir)
|
|
||||||
model_id = local_dir
|
|
||||||
# Test if files were already download
|
# Test if files were already download
|
||||||
try:
|
try:
|
||||||
utils.weight_files(model_id, revision, extension)
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
@ -19,6 +19,7 @@ from text_generation_server.utils.tokens import (
|
|||||||
Sampling,
|
Sampling,
|
||||||
Greedy,
|
Greedy,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.gcs import GCS_PREFIX, GCS_LOCAL_DIR, download_gcs_dir_to_local
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_file",
|
"convert_file",
|
||||||
@ -39,4 +40,7 @@ __all__ = [
|
|||||||
"StopSequenceCriteria",
|
"StopSequenceCriteria",
|
||||||
"FinishReason",
|
"FinishReason",
|
||||||
"Weights",
|
"Weights",
|
||||||
|
"GCS_PREFIX",
|
||||||
|
"GCS_LOCAL_DIR",
|
||||||
|
"download_gcs_dir_to_local",
|
||||||
]
|
]
|
||||||
|
24
server/text_generation_server/utils/gcs.py
Normal file
24
server/text_generation_server/utils/gcs.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from google.cloud import storage
|
||||||
|
|
||||||
|
GCS_PREFIX = "gs://"
|
||||||
|
GCS_LOCAL_DIR = "/tmp/gcs_model"
|
||||||
|
|
||||||
|
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
|
||||||
|
if os.path.isdir(local_dir):
|
||||||
|
return
|
||||||
|
# gs://bucket_name/dir
|
||||||
|
bucket_name = gcs_dir.split('/')[2]
|
||||||
|
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
|
||||||
|
client = storage.Client()
|
||||||
|
blobs = client.list_blobs(bucket_name, prefix=prefix)
|
||||||
|
if not blobs:
|
||||||
|
raise ValueError(f"No blobs found in {gcs_dir}")
|
||||||
|
for blob in blobs:
|
||||||
|
if blob.name[-1] == '/':
|
||||||
|
continue
|
||||||
|
file_path = blob.name[len(prefix) :].strip('/')
|
||||||
|
local_file_path = os.path.join(local_dir, file_path)
|
||||||
|
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||||
|
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
|
||||||
|
blob.download_to_filename(local_file_path)
|
||||||
|
logger.info("Download finished.")
|
Loading…
Reference in New Issue
Block a user