feat: Support loading models from GCS

This commit is contained in:
dstnluong-google 2024-02-20 23:16:30 +00:00
parent c9f4c1af31
commit 064c110123
4 changed files with 36 additions and 3 deletions

View File

@ -218,6 +218,9 @@ COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-
# Install flash-attention dependencies # Install flash-attention dependencies
RUN pip install einops --no-cache-dir RUN pip install einops --no-cache-dir
# Install GCS library
RUN pip install --upgrade google-cloud-storage
# Install server # Install server
COPY proto proto COPY proto proto
COPY server server COPY server server

View File

@ -782,7 +782,7 @@ enum LauncherError {
WebserverCannotStart, WebserverCannotStart,
} }
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { fn download_convert_model(args: &mut Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// Enter download tracing span // Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let _span = tracing::span!(tracing::Level::INFO, "download").entered();
@ -907,6 +907,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
} }
if args.model_id.starts_with("gs://") {
args.model_id = "/tmp/gcs_model/".to_string();
}
Ok(()) Ok(())
} }
@ -1192,7 +1195,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
fn main() -> Result<(), LauncherError> { fn main() -> Result<(), LauncherError> {
// Pattern match configuration // Pattern match configuration
let args: Args = Args::parse(); let mut args: Args = Args::parse();
// Filter events with LOG_LEVEL // Filter events with LOG_LEVEL
let env_filter = let env_filter =
@ -1285,7 +1288,7 @@ fn main() -> Result<(), LauncherError> {
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Download and convert model weights // Download and convert model weights
download_convert_model(&args, running.clone())?; download_convert_model(&mut args, running.clone())?;
if !running.load(Ordering::SeqCst) { if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop // Launcher was asked to stop

View File

@ -123,6 +123,31 @@ def download_weights(
# Import here after the logger is added to log potential import exceptions # Import here after the logger is added to log potential import exceptions
from text_generation_server import utils from text_generation_server import utils
GCS_PREFIX = "gs://"
if model_id.startswith(GCS_PREFIX):
local_dir = "/tmp/gcs_model"
from google.cloud import storage
def download_gcs_dir_to_local(gcs_dir: str, local_dir: str):
if os.path.isdir(local_dir):
return
# gs://bucket_name/dir
bucket_name = gcs_dir.split('/')[2]
prefix = gcs_dir[len(GCS_PREFIX + bucket_name) :].strip('/') + '/'
client = storage.Client()
blobs = client.list_blobs(bucket_name, prefix=prefix)
if not blobs:
raise ValueError(f"No blobs found in {gcs_dir}")
for blob in blobs:
if blob.name[-1] == '/':
continue
file_path = blob.name[len(prefix) :].strip('/')
local_file_path = os.path.join(local_dir, file_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
logger.info(f"==> Download {gcs_dir}/{file_path} to {local_file_path}.")
blob.download_to_filename(local_file_path)
logger.info("Download finished.")
download_gcs_dir_to_local(model_id, local_dir)
model_id = local_dir
# Test if files were already download # Test if files were already download
try: try:
utils.weight_files(model_id, revision, extension) utils.weight_files(model_id, revision, extension)

View File

@ -192,6 +192,8 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
if model_id.startswith("gs://"):
model_id = "/tmp/gcs_model"
try: try:
model = get_model( model = get_model(
model_id, model_id,