mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
feat: align docs with demo code and rename env var
This commit is contained in:
parent e4fa84ba26
commit 635a5b42ea
@@ -87,37 +87,48 @@ TGI can be deployed on various cloud providers for scalable and robust text generation.
 
 ## Amazon SageMaker
 
-Amazon SageMaker allows two routes: `/invocations` and `/ping` (or `/health`) for health checks. By default, we map `/generate` to `/invocations`. However, SageMaker does not allow requests to any other routes.
-
-To provide the new Messages API feature, we have introduced an environment variable, `OAI_ENABLED`. If `OAI_ENABLED=true`, the `chat_completions` method is used when `/invocations` is called; otherwise it defaults to `generate`. This allows users to opt in to the OAI format.
-
-Here's an example of running the router with `OAI_ENABLED` set to `true`:
-
-```bash
-OAI_ENABLED=true text-generation-launcher --model-id <MODEL-ID>
-```
-
-And here's an example request:
-
-```bash
-curl <SAGEMAKER-ENDPOINT>/invocations \
-    -X POST \
-    -d '{
-  "model": "tgi",
-  "messages": [
-    {
-      "role": "system",
-      "content": "You are a helpful assistant."
-    },
-    {
-      "role": "user",
-      "content": "What is deep learning?"
-    }
-  ],
-  "stream": true,
-  "max_tokens": 20
-}' \
-    -H 'Content-Type: application/json' | jq
-```
-
-Please let us know if any naming changes are needed or if any other routes need similar functionality.
+To enable the Messages API in Amazon SageMaker, you need to set the environment variable `MESSAGES_API_ENABLED=true`.
+
+This will modify the `/invocations` route to accept Messages dictionaries consisting of a role and content. See the example below on how to deploy Llama with the new Messages API.
+
+```python
+import json
+import sagemaker
+import boto3
+from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
+
+try:
+    role = sagemaker.get_execution_role()
+except ValueError:
+    iam = boto3.client('iam')
+    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
+
+# Hub Model configuration. https://huggingface.co/models
+hub = {
+    'HF_MODEL_ID': 'HuggingFaceH4/zephyr-7b-beta',
+    'SM_NUM_GPUS': json.dumps(1),
+    'MESSAGES_API_ENABLED': True
+}
+
+# create Hugging Face Model Class
+huggingface_model = HuggingFaceModel(
+    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.4.0"),
+    env=hub,
+    role=role,
+)
+
+# deploy model to SageMaker Inference
+predictor = huggingface_model.deploy(
+    initial_instance_count=1,
+    instance_type="ml.g5.2xlarge",
+    container_startup_health_check_timeout=300,
+)
+
+# send request
+predictor.predict({
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is deep learning?"}
+    ]
+})
+```
@@ -72,7 +72,7 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 }
 
 #[tokio::main]
@@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        chat_enabled_api,
+        messages_api_enabled,
     } = args;
 
     // Launch Tokio runtime
@@ -348,7 +348,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
-        chat_enabled_api,
+        messages_api_enabled,
     )
     .await?;
     Ok(())
@@ -708,7 +708,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -872,7 +872,7 @@ pub async fn run(
         .route("/metrics", get(metrics));
 
     // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if chat_enabled_api {
+    let aws_sagemaker_route = if messages_api_enabled {
         Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
     } else {
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise