mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
sagemaker support
This commit is contained in:
parent
c9bdaa8b73
commit
ce09fd32a1
@ -37,7 +37,7 @@ ENV LANG=C.UTF-8 \
|
|||||||
MODEL_ID=bigscience/bloom-560m \
|
MODEL_ID=bigscience/bloom-560m \
|
||||||
QUANTIZE=false \
|
QUANTIZE=false \
|
||||||
NUM_SHARD=1 \
|
NUM_SHARD=1 \
|
||||||
PORT=80 \
|
PORT=8080 \
|
||||||
CUDA_HOME=/usr/local/cuda \
|
CUDA_HOME=/usr/local/cuda \
|
||||||
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
|
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
|
||||||
CONDA_DEFAULT_ENV=text-generation \
|
CONDA_DEFAULT_ENV=text-generation \
|
||||||
|
@ -19,13 +19,13 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
|||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[clap(author, version, about, long_about = None)]
|
#[clap(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
#[clap(default_value = "bigscience/bloom-560m", long, env = "HF_MODEL_ID")]
|
||||||
model_id: String,
|
model_id: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
sharded: Option<bool>,
|
sharded: Option<bool>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env = "SM_NUM_GPUS")]
|
||||||
num_shard: Option<usize>,
|
num_shard: Option<usize>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
quantize: bool,
|
quantize: bool,
|
||||||
|
@ -530,10 +530,12 @@ pub async fn run(
|
|||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
.route("/", post(compat_generate))
|
.route("/", post(compat_generate))
|
||||||
|
.route("/invocations", post(compat_generate))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
|
.route("/ping", get(health))
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
.layer(Extension(compat_return_full_text))
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
|
Loading…
Reference in New Issue
Block a user