feat: align docs with demo code and rename env var

drbh 2024-01-24 10:52:40 -05:00
parent e4fa84ba26
commit 635a5b42ea
3 changed files with 44 additions and 33 deletions


@@ -87,37 +87,48 @@ TGI can be deployed on various cloud providers for scalable and robust text generation.

 ## Amazon SageMaker

-Amazon SageMaker allows two routes: `/invocations` and `/ping` (or `/health`) for health checks. By default, we map `/generate` to `/invocations`. However, SageMaker does not allow requests to any other routes.
-
-To provide the new Messages API feature, we have introduced an environment variable, `OAI_ENABLED`. If `OAI_ENABLED=true`, the `chat_completions` method is used when `/invocations` is called; otherwise it defaults to `generate`. This lets users opt in to the OAI format.
-
-Here's an example of running the router with `OAI_ENABLED` set to `true`:
-
-```bash
-OAI_ENABLED=true text-generation-launcher --model-id <MODEL-ID>
-```
-
-And here's an example request:
-
-```bash
-curl <SAGEMAKER-ENDPOINT>/invocations \
-    -X POST \
-    -d '{
-  "model": "tgi",
-  "messages": [
-    {
-      "role": "system",
-      "content": "You are a helpful assistant."
-    },
-    {
-      "role": "user",
-      "content": "What is deep learning?"
-    }
-  ],
-  "stream": true,
-  "max_tokens": 20
-}' \
-    -H 'Content-Type: application/json' | jq
-```
-
-Please let us know if any naming changes are needed or if any other routes need similar functionality.
+To enable the Messages API in Amazon SageMaker, you need to set the environment variable `MESSAGES_API_ENABLED=true`.
+
+This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below on how to deploy Llama with the new Messages API.
+
+```python
+import json
+import sagemaker
+import boto3
+from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
+
+try:
+    role = sagemaker.get_execution_role()
+except ValueError:
+    iam = boto3.client('iam')
+    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
+
+# Hub model configuration. https://huggingface.co/models
+hub = {
+    'HF_MODEL_ID': 'HuggingFaceH4/zephyr-7b-beta',
+    'SM_NUM_GPUS': json.dumps(1),
+    'MESSAGES_API_ENABLED': 'true',  # SageMaker env values must be strings
+}
+
+# create the Hugging Face Model class
+huggingface_model = HuggingFaceModel(
+    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.4.0"),
+    env=hub,
+    role=role,
+)
+
+# deploy the model to SageMaker Inference
+predictor = huggingface_model.deploy(
+    initial_instance_count=1,
+    instance_type="ml.g5.2xlarge",
+    container_startup_health_check_timeout=300,
+)
+
+# send a request
+predictor.predict({
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is deep learning?"}
+    ]
+})
+```
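As a follow-up usage note (a minimal sketch, not part of this commit): once the endpoint is up, the same Messages payload can also be sent through the low-level SageMaker runtime client instead of the `predictor` object. The endpoint name below is hypothetical; in practice it is available as `predictor.endpoint_name` after `deploy()` returns.

```python
import json
import boto3

# Low-level client for invoking deployed SageMaker endpoints.
runtime = boto3.client("sagemaker-runtime")

# "my-tgi-endpoint" is a hypothetical name; use predictor.endpoint_name.
response = runtime.invoke_endpoint(
    EndpointName="my-tgi-endpoint",
    ContentType="application/json",
    Body=json.dumps({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ]
    }),
)

# The response body is a streaming object; read and decode it.
print(response["Body"].read().decode("utf-8"))
```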


@@ -72,7 +72,7 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 }
 #[tokio::main]

@@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        chat_enabled_api,
+        messages_api_enabled,
     } = args;
     // Launch Tokio runtime

@@ -348,7 +348,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
-        chat_enabled_api,
+        messages_api_enabled,
     )
     .await?;
     Ok(())
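One note connecting the rename to the docs change above (an observation, not part of the diff): because the field uses clap's `#[clap(long, env)]` derive, renaming it changes both the CLI flag and the environment variable it reads, to `--messages-api-enabled` and `MESSAGES_API_ENABLED` respectively (clap upper-snake-cases the field name for the env var by default), which is exactly the variable the rewritten SageMaker docs tell users to set.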


@@ -708,7 +708,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -872,7 +872,7 @@ pub async fn run(
         .route("/metrics", get(metrics));

     // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if chat_enabled_api {
+    let aws_sagemaker_route = if messages_api_enabled {
         Router::new().route("/invocations", post(chat_completions)) // OpenAI-style Messages requests when the flag is set
     } else {
         Router::new().route("/invocations", post(compat_generate)) // default generate-compatible requests otherwise
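To make the conditional concrete, here is a hedged sketch (not from this diff) of the two request shapes `/invocations` ends up accepting: the Messages format handled by `chat_completions` when the flag is set, and the classic TGI format handled by `compat_generate` otherwise. The base URL is a hypothetical local router address; adjust host and port to your deployment.

```python
import requests

BASE = "http://localhost:3000"  # hypothetical local router address

# With MESSAGES_API_ENABLED=true, /invocations expects OpenAI-style Messages:
chat_payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 20,
}

# With the flag unset, /invocations expects the classic generate schema:
generate_payload = {
    "inputs": "What is deep learning?",
    "parameters": {"max_new_tokens": 20},
}

# Send whichever payload matches how the router was launched.
resp = requests.post(f"{BASE}/invocations", json=chat_payload)
print(resp.json())
```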