From 635a5b42ea553f5f340559532ec0a63b8b79b298 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 24 Jan 2024 10:52:40 -0500
Subject: [PATCH] feat: align docs with demo code and rename env var

---
 .../{openai_compatible.md => message_api.md}  | 67 +++++++++++--------
 router/src/main.rs                            |  6 +-
 router/src/server.rs                          |  4 +-
 3 files changed, 44 insertions(+), 33 deletions(-)
 rename docs/source/{openai_compatible.md => message_api.md} (62%)

diff --git a/docs/source/openai_compatible.md b/docs/source/message_api.md
similarity index 62%
rename from docs/source/openai_compatible.md
rename to docs/source/message_api.md
index 895ca6b8..899de865 100644
--- a/docs/source/openai_compatible.md
+++ b/docs/source/message_api.md
@@ -87,37 +87,48 @@ TGI can be deployed on various cloud providers for scalable and robust text gene

 ## Amazon SageMaker

-Amazon SageMaker allows two routes: `/invocations` and `/ping` (or `/health`) for health checks. By default, we map `/generate` to `/invocations`. However, SageMaker does not allow requests to any other routes.
+To enable the Messages API in Amazon SageMaker, set the environment variable `MESSAGES_API_ENABLED=true`.

-To provide the new feature of Messages API, we have introduced an environment variable `OAI_ENABLED`. If `OAI_ENABLED=true`, the `chat_completions` method is used when `/invocations` is called, otherwise it defaults to `generate`. This allows users to opt in for the OAI format.
+This modifies the `/invocations` route to accept Messages dictionaries consisting of `role` and `content` keys. See the example below for how to deploy a model with the new Messages API.

-Here's an example of running the router with `OAI_ENABLED` set to `true`:
+```python
+import json
+import sagemaker
+import boto3
+from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

-```bash
-OAI_ENABLED=true text-generation-launcher --model-id <model-id>
-```
+try:
+    role = sagemaker.get_execution_role()
+except ValueError:
+    iam = boto3.client('iam')
+    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

-And here's an example request:
+# Hub Model configuration. https://huggingface.co/models
+hub = {
+    'HF_MODEL_ID': 'HuggingFaceH4/zephyr-7b-beta',
+    'SM_NUM_GPUS': json.dumps(1),
+    'MESSAGES_API_ENABLED': json.dumps(True)
+}

-```bash
-curl <endpoint>/invocations \
-    -X POST \
-    -d '{
-  "model": "tgi",
-  "messages": [
-    {
-      "role": "system",
-      "content": "You are a helpful assistant."
-    },
-    {
-      "role": "user",
-      "content": "What is deep learning?"
-    }
-  ],
-  "stream": true,
-  "max_tokens": 20
-}' \
-    -H 'Content-Type: application/json' | jq
-```
+# create Hugging Face Model Class
+huggingface_model = HuggingFaceModel(
+    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.4.0"),
+    env=hub,
+    role=role,
+)

-Please let us know if any naming changes are needed or if any other routes need similar functionality.
+# deploy model to SageMaker Inference
+predictor = huggingface_model.deploy(
+    initial_instance_count=1,
+    instance_type="ml.g5.2xlarge",
+    container_startup_health_check_timeout=300,
+)
+
+# send request
+predictor.predict({
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is deep learning?"}
+    ]
+})
+```
\ No newline at end of file
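For reference, once the example endpoint above is up, the same Messages payload can be sent without the SageMaker Python SDK. Below is a minimal sketch using `boto3`'s SageMaker runtime client; the endpoint name and region are illustrative placeholders, not values produced by this patch:

```python
import json

import boto3

# Placeholder endpoint name: substitute the real one from `predictor.endpoint_name`.
runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

response = runtime.invoke_endpoint(
    EndpointName="huggingface-pytorch-tgi-inference-zephyr-7b",
    ContentType="application/json",
    Body=json.dumps({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ]
    }),
)
print(json.loads(response["Body"].read()))
```

`predictor.predict` in the docs example issues this same `InvokeEndpoint` call under the hood, so both paths exercise the modified `/invocations` route.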
diff --git a/router/src/main.rs b/router/src/main.rs
index bf987eb6..b6190908 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -72,7 +72,7 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 }

 #[tokio::main]
@@ -104,7 +104,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        chat_enabled_api,
+        messages_api_enabled,
     } = args;

     // Launch Tokio runtime
@@ -348,7 +348,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
-        chat_enabled_api,
+        messages_api_enabled,
     )
     .await?;
     Ok(())

diff --git a/router/src/server.rs b/router/src/server.rs
index aa1ad202..ff48b4f0 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -708,7 +708,7 @@ pub async fn run(
     ngrok_authtoken: Option<String>,
     ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
-    chat_enabled_api: bool,
+    messages_api_enabled: bool,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -872,7 +872,7 @@ pub async fn run(
         .route("/metrics", get(metrics));

     // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if chat_enabled_api {
-        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
+    let aws_sagemaker_route = if messages_api_enabled {
+        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' when the Messages API is enabled
     } else {
         Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
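As a sanity check of the rename, the conditional route can also be exercised against a local router. Below is a sketch assuming a router launched with `MESSAGES_API_ENABLED=true` on the launcher's default port 3000:

```python
import requests

# With MESSAGES_API_ENABLED=true, /invocations is handled by `chat_completions`
# and accepts a Messages payload; with the flag unset it falls back to
# `compat_generate`, which expects the legacy `inputs` string instead.
resp = requests.post(
    "http://localhost:3000/invocations",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ],
        "max_tokens": 20,
    },
)
print(resp.json())
```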