sagemaker support

This commit is contained in:
OlivierDehaene 2023-02-28 10:34:29 +01:00
parent c9bdaa8b73
commit ce09fd32a1
3 changed files with 5 additions and 3 deletions

View File

@ -37,7 +37,7 @@ ENV LANG=C.UTF-8 \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
NUM_SHARD=1 \
PORT=80 \
PORT=8080 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \

View File

@ -19,13 +19,13 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
#[clap(default_value = "bigscience/bloom-560m", long, env = "HF_MODEL_ID")]
model_id: String,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
sharded: Option<bool>,
#[clap(long, env)]
#[clap(long, env = "SM_NUM_GPUS")]
num_shard: Option<usize>,
#[clap(long, env)]
quantize: bool,

View File

@ -530,10 +530,12 @@ pub async fn run(
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(compat_generate))
.route("/invocations", post(compat_generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
.route("/health", get(health))
.route("/ping", get(health))
.route("/metrics", get(metrics))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))