Merge branch 'main' into trtllm-executor-thread

Nicolas Patry 2024-10-25 07:06:35 +02:00 committed by GitHub
commit 01b82b58d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 2137 additions and 2009 deletions

View File

@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain OpenAI Chat Completions API-compatible responses.
```bash
curl localhost:3000/v1/chat/completions \
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",

View File

@ -54,14 +54,14 @@ struct Args {
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
async fn get_tokenizer(
@ -213,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token,
executor_worker,
usage_stats,
} = args;
// Launch Tokio runtime
@ -293,7 +293,6 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
UsageStatsLevel::Off,

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,

View File

@ -316,6 +316,98 @@
}
}
},
"/invocations": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens from Sagemaker request",
"operationId": "sagemaker_compatibility",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/SagemakerStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error",
"error_type": "validation"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation",
"error_type": "generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation",
"error_type": "incomplete_generation"
}
}
}
}
}
}
},
"/metrics": {
"get": {
"tags": [
@ -1865,6 +1957,45 @@
"type": "string"
}
},
"SagemakerRequest": {
"oneOf": [
{
"$ref": "#/components/schemas/CompatGenerateRequest"
},
{
"$ref": "#/components/schemas/ChatRequest"
},
{
"$ref": "#/components/schemas/CompletionRequest"
}
]
},
"SagemakerResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/GenerateResponse"
},
{
"$ref": "#/components/schemas/ChatCompletion"
},
{
"$ref": "#/components/schemas/CompletionFinal"
}
]
},
"SagemakerStreamResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/StreamResponse"
},
{
"$ref": "#/components/schemas/ChatCompletionChunk"
},
{
"$ref": "#/components/schemas/Chunk"
}
]
},
"SimpleToken": {
"type": "object",
"required": [

View File

@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below for how to deploy Llama with the new Messages API.
Amazon SageMaker natively supports the Messages API:
```python
import json
@ -161,12 +159,11 @@ except ValueError:
hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
env=hub,
role=role,
)
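The hunk stops at the model constructor; for context, the usual continuation from the SageMaker SDK docs deploys the model and sends a Messages payload to `/invocations`. Instance type, timeout, and prompt below are illustrative assumptions, not part of the diff:

```python
# Continuation sketch of the snippet above (reuses `huggingface_model`).
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

# With native Messages API support, the payload is plain role/content
# dictionaries; no MESSAGES_API_ENABLED toggle is needed anymore.
response = predictor.predict({
    "messages": [{"role": "user", "content": "What is deep learning?"}]
})
print(response)
```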

View File

@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)

View File

@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
"max_top_n_tokens": 5,
"max_total_tokens": 2048,
"max_waiting_tokens": 20,
"messages_api_enabled": false,
"model_config": {
"model_type": "Bloom"
},

View File

@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1728381423,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
"lastModified": 1729531056,
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"repo": "text-generation-inference-nix",
"type": "github"
}

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; };
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
};

View File

@ -11,27 +11,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -39,66 +39,66 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"logprob": -0.028808594,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"logprob": -0.013671875,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"logprob": -0.69921875,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"logprob": -0.0005874634,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"logprob": -0.026855469,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"logprob": -0.00020885468,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"logprob": -0.17773438,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}
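Snapshot files like this one record a natural-log probability for every prefill and generated token. A quick sketch for reading them (token/logprob pairs copied from the snapshot above):

```python
import math

# Exponentiate a token's logprob to recover the model's probability for it.
logprobs = {
    " Deep": -1.109375,
    " learning": -0.005432129,
    " involves": -0.703125,
}
for token, lp in logprobs.items():
    print(f"{token!r}: p = {math.exp(lp):.4f}")

# Summing per-token logprobs gives the log-probability of the whole
# continuation under the model.
total = sum(logprobs.values())
print(f"joint p of these three tokens = {math.exp(total):.6f}")
```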

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
@ -11,22 +11,22 @@
},
{
"id": 374,
"logprob": -22.96875,
"logprob": -18.0,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"logprob": -11.75,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"logprob": -2.0625,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"logprob": -6.0,
"text": "?"
}
],
@ -34,24 +34,66 @@
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"logprob": 0.0,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"id": 34564,
"logprob": -0.11279297,
"special": false,
"text": " "
"text": "Deep"
},
{
"id": 128001,
"id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
"special": false,
"text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
}
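This snapshot's `finish_reason` flipped from `eos_token` (the old run emitted `<|end_of_text|>` after 3 tokens) to `length` (the new run exhausts `max_new_tokens`). A tiny sketch of the distinction (token ids come from the snapshots; the logic is a simplification of TGI's stopping criteria, not its actual code):

```python
EOS_ID = 128001  # Llama 3 <|end_of_text|>, as in the old snapshot

def finish_reason(generated_ids: list[int], max_new_tokens: int) -> str | None:
    """Return why decoding stopped, or None to keep going."""
    if generated_ids and generated_ids[-1] == EOS_ID:
        return "eos_token"
    if len(generated_ids) >= max_new_tokens:
        return "length"
    return None

print(finish_reason([720, 220, 128001], max_new_tokens=10))  # eos_token
print(finish_reason(list(range(10)), max_new_tokens=10))     # length
```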

View File

@ -12,27 +12,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -40,68 +40,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.0047912598,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.025512695,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.012145996,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.72265625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0005760193,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02722168,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00023651123,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.17285156,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -116,27 +116,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -144,68 +144,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -220,27 +220,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -248,68 +248,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -324,27 +324,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -352,67 +352,67 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}
]

View File

@ -10,80 +10,95 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1503906,
"text": "is"
},
{
"id": 3534,
"logprob": -9.5859375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.3945312,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.4555664,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.4777832,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8808594,
"id": 5168,
"logprob": -0.023849487,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37280273,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.26098633,
"id": 264,
"logprob": -0.14489746,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017137527,
"id": 19804,
"logprob": -0.63183594,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2695312,
"id": 302,
"logprob": -0.010314941,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9238281,
"id": 5599,
"logprob": -0.0635376,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48828125,
"id": 5168,
"logprob": -0.0028572083,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
}

View File

@ -10,42 +10,28 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 349,
"logprob": -12.0546875,
"text": "is"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 3534,
"logprob": -10.53125,
"text": "deep"
},
{
"id": 5168,
"logprob": -2.71875,
"text": "learning"
},
{
"id": 28804,
"logprob": -5.0078125,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -0.34838867,
"special": false,
"text": "\n"
},
{
"id": 13940,
"logprob": -0.38916016,
"special": false,
"text": "``"
},
{
"id": 28832,
"logprob": 0.0,
"special": false,
"text": "`"
},
{
"id": 3371,
"logprob": -1.2529297,
"special": false,
"text": "json"
},
{
"id": 13,
"logprob": 0.0,
@ -53,37 +39,61 @@
"text": "\n"
},
{
"id": 28751,
"logprob": 0.0,
"id": 23229,
"logprob": -0.18237305,
"special": false,
"text": "{"
"text": "Deep"
},
{
"id": 13,
"id": 17504,
"logprob": 0.0,
"special": false,
"text": "\n"
"text": " Learning"
},
{
"id": 2287,
"id": 349,
"logprob": 0.0,
"special": false,
"text": " "
"text": " is"
},
{
"id": 345,
"id": 264,
"logprob": 0.0,
"special": false,
"text": " \""
"text": " a"
},
{
"id": 3134,
"logprob": -0.640625,
"id": 19804,
"logprob": 0.0,
"special": false,
"text": "request"
"text": " subset"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 13253,
"logprob": -0.6040039,
"special": false,
"text": " Machine"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 28725,
"logprob": -0.11621094,
"special": false,
"text": ","
}
],
"top_tokens": null
},
"generated_text": "Test request\n```json\n{\n \"request"
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}

View File

@ -11,82 +11,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1503906,
"text": "is"
},
{
"id": 3534,
"logprob": -9.5859375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.3945312,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.4555664,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.4777832,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13232422,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.023834229,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14416504,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63183594,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.064208984,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.0028266907,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -100,82 +115,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -189,82 +219,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -278,81 +323,96 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
}
]

View File

@ -11,32 +11,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -44,66 +44,66 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}

View File

@ -5,95 +5,95 @@
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"id": 338,
"logprob": null,
"text": "is"
},
{
"id": 16030,
"logprob": -13.328125,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"logprob": -0.24023438,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"logprob": -3.1386719,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"logprob": -3.0878906,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"id": 25584,
"logprob": 0.0,
"special": false,
"text": "!"
"text": "Grad"
},
{
"id": 739,
"id": 993,
"logprob": 0.0,
"special": false,
"text": " It"
"text": "ient"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"id": 2726,
"logprob": 0.0,
"special": false,
"text": " you"
"text": " Des"
},
{
"id": 29915,
"id": 1760,
"logprob": 0.0,
"special": false,
"text": "'"
"text": "cent"
},
{
"id": 276,
"logprob": -0.9838867,
"id": 313,
"logprob": -0.12322998,
"special": false,
"text": "re"
"text": " ("
},
{
"id": 3211,
"id": 29954,
"logprob": 0.0,
"special": false,
"text": " address"
"text": "G"
},
{
"id": 292,
"id": 29928,
"logprob": 0.0,
"special": false,
"text": "ing"
"text": "D"
},
{
"id": 263,
"logprob": -0.15124512,
"id": 29897,
"logprob": 0.0,
"special": false,
"text": " a"
"text": ")"
},
{
"id": 338,
"logprob": -0.6040039,
"special": false,
"text": " is"
},
{
"id": 385,
"logprob": -0.1796875,
"special": false,
"text": " an"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
"generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
}

View File

@ -12,32 +12,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -45,68 +45,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.23840332,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -121,32 +121,32 @@
},
{
"id": 338,
"logprob": -0.7128906,
"logprob": -0.6113281,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"logprob": -0.003929138,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"logprob": -2.625,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.484375,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -154,68 +154,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"logprob": -0.009017944,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -230,32 +230,32 @@
},
{
"id": 338,
"logprob": -0.71484375,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.671875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"logprob": -0.0040016174,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6230469,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"logprob": -6.6875,
"text": "\n"
}
],
@ -263,68 +263,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"logprob": -0.008956909,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"logprob": -0.0003721714,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092601776,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19177246,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -339,32 +339,32 @@
},
{
"id": 338,
"logprob": -0.7192383,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"logprob": -0.0038967133,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6347656,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"logprob": -6.6875,
"text": "\n"
}
],
@ -372,67 +372,67 @@
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.00038409233,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010414124,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}
]

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
num_shard=2,
kv_cache_dtype="fp8_e4m3fn",
) as handle:
yield handle
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
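The fixture switch from `fp8_e5m2` to `fp8_e4m3fn` trades exponent range for mantissa precision: e4m3fn keeps one more mantissa bit but tops out far lower. A back-of-the-envelope sketch of the two formats' largest finite values (standard float8 layouts; illustrative arithmetic, not TGI code):

```python
def max_finite(exp_bits: int, man_bits: int, bias: int, fn: bool) -> float:
    """Largest finite value of a float8 format.

    "fn" (finite) formats such as e4m3fn have no infinities and reserve
    only the all-ones mantissa at the top exponent for NaN, so the top
    exponent still encodes finite values.
    """
    if fn:
        exp = (2**exp_bits - 1) - bias
        frac = 1 + (2**man_bits - 2) / 2**man_bits
    else:
        # IEEE-style: the top exponent is reserved for inf/NaN.
        exp = (2**exp_bits - 2) - bias
        frac = 1 + (2**man_bits - 1) / 2**man_bits
    return frac * 2.0**exp

print("e5m2   max:", max_finite(5, 2, bias=15, fn=False))  # 57344.0
print("e4m3fn max:", max_finite(4, 3, bias=7, fn=True))    # 448.0
```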

View File

@ -3,7 +3,11 @@ import pytest
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
with launcher(
"TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
revision="gptq-4bit-128g-actorder_True",
num_shard=2,
) as handle:
yield handle
@ -16,7 +20,12 @@ async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text == "\n\nDeep learning is a subset of machine learning"
)
assert response == response_snapshot
@ -25,7 +34,7 @@ async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request",
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
@ -41,6 +50,10 @@ async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapsh
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
)
assert response == response_snapshot
@ -49,10 +62,14 @@ async def test_flash_mixtral_gptq_load(
flash_mixtral_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
flash_mixtral_gptq, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert (
responses[0].generated_text
== "\n\nDeep learning is a subset of machine learning"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
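Pinning `revision=` matters for GPTQ repos because TheBloke-style repositories publish several quantization variants as branches of the same repo. Outside the test harness, the same pin looks like this sketch with `huggingface_hub` (repo and branch names come from the fixture above):

```python
from huggingface_hub import snapshot_download

# Fetch the exact quantization branch instead of `main`, mirroring the
# `revision=` argument added to the launcher fixture.
path = snapshot_download(
    repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
    revision="gptq-4bit-128g-actorder_True",
)
print(path)
```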

View File

@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert response == response_snapshot
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
"What is gradient descent?\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
== "What is gradient descent?\nGradient Descent (GD) is an"
)
assert response == response_snapshot
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
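For reference, the `all_params` variant above pins the sampling configuration so the seeded snapshot stays stable. A sketch of the equivalent call through the `text_generation` Python client (only the parameters visible in the hunk plus the snapshot's seed are set; the endpoint URL is an assumption):

```python
from text_generation import Client

client = Client("http://localhost:8080")  # assumed running TGI endpoint
response = client.generate(
    "What is gradient descent?\n",
    max_new_tokens=10,
    repetition_penalty=1.2,
    return_full_text=True,
    seed=0,  # fixed seed, as recorded in the snapshot
)
print(response.generated_text)
```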

View File

@ -1104,6 +1104,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
}
}
}
} else {
break;
}
}
}
@ -1507,6 +1509,10 @@ fn spawn_webserver(
router_args.push(revision.to_string())
}
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output {
router_args.push("--json-output".to_string());
}
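Because TGI's clap arguments are declared with the `env` attribute (as in the router `Args` hunks earlier), the new flag can also be driven by an environment variable; the launcher then forwards it to the router as shown above. A hypothetical launch sketch (model id is illustrative):

```python
import os
import subprocess

# Setting TRUST_REMOTE_CODE=true is equivalent to passing
# --trust-remote-code on the command line.
env = dict(os.environ, TRUST_REMOTE_CODE="true")
subprocess.run(
    ["text-generation-launcher", "--model-id", "bigscience/bloom-560m"],
    env=env,
    check=True,
)
```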

View File

@ -1,7 +1,12 @@
{
lib,
mkShell,
black,
cmake,
isort,
ninja,
which,
cudaPackages,
openssl,
pkg-config,
protobuf,
@ -11,14 +16,17 @@
ruff,
rust-bin,
server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}:
mkShell {
buildInputs =
nativeBuildInputs =
[
black
isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly
ruff
]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [
venvShellHook
docker
@ -40,10 +61,29 @@ mkShell {
pytest
pytest-asyncio
syrupy
]);
])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv";
postVenvCreation = ''
@ -51,6 +91,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin

View File

@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2),
Ssm,
GptBigcode,
Granite,
Santacoder,
Bloom,
Mpt,

View File

@ -8,6 +8,7 @@ pub mod validation;
mod kserve;
pub mod logging;
mod sagemaker;
pub mod usage_stats;
mod vertex;

View File

@ -1,748 +0,0 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
command,
} = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let preprocessor_config =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Only send usage stats when TGI is run in a container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_class,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
// Run server
let result = server::run(
master_shard_uds_path,
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
tokenizer,
config,
validation_workers,
addr,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
preprocessor_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await;
match result {
Ok(_) => {
if let Some(ref ua) = user_agent {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
};
Ok(())
}
Err(e) => {
if let Some(ref ua) = user_agent {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknown_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknown_error".to_string()),
);
unknown_error_event.send().await;
}
};
Err(RouterError::WebServer(e))
}
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
/// - LOG_FORMAT may be TEXT or JSON (defaults to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ANSI-supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs being spammed with tokio-level information
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// Get the base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as a `serde_json::Value`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// Get the tokenizer config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
use super::*;
use text_generation_router::TokenizerConfigToken;
#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};
let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0")
.unwrap()
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();
assert_eq!(post_processor, expected);
}
}

82
router/src/sagemaker.rs Normal file
View File

@ -0,0 +1,82 @@
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from a SageMaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}
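Because `SagemakerRequest` is an untagged enum, the shape of the JSON body alone decides which handler serves `/invocations`. A minimal client sketch, assuming a local instance on port 8080 and the standard TGI payload fields (the URL and values here are illustrative, not part of this diff):

```python
import requests

BASE = "http://127.0.0.1:8080"  # assumed local TGI instance

# CompatGenerateRequest -> compat_generate
requests.post(f"{BASE}/invocations", json={"inputs": "What is Deep Learning?"})

# ChatRequest -> chat_completions
requests.post(
    f"{BASE}/invocations",
    json={"model": "tgi", "messages": [{"role": "user", "content": "Hello"}]},
)

# CompletionRequest -> completions
requests.post(f"{BASE}/invocations", json={"model": "tgi", "prompt": "Hello"})
```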

View File

@ -7,6 +7,10 @@ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::sagemaker::{
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
__path_sagemaker_compatibility,
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
pub(crate) async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
@ -678,7 +682,7 @@ time_per_token,
seed,
)
)]
async fn completions(
pub(crate) async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1202,7 +1206,7 @@ time_per_token,
seed,
)
)]
async fn chat_completions(
pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1513,11 +1517,13 @@ completions,
tokenize,
metrics,
openai_get_model_info,
sagemaker_compatibility,
),
components(
schemas(
Info,
CompatGenerateRequest,
SagemakerRequest,
GenerateRequest,
GrammarType,
ChatRequest,
@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
SagemakerResponse,
SagemakerStreamResponse,
Chunk,
Completion,
CompletionFinal,
@ -1601,13 +1609,13 @@ pub async fn run(
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
trust_remote_code: bool,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
@ -1761,10 +1769,13 @@ pub async fn run(
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
let kwargs = [
(
"revision",
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
@ -1836,7 +1847,6 @@ pub async fn run(
// max_batch_size,
revision.clone(),
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,
@ -1878,7 +1888,6 @@ pub async fn run(
ngrok,
_ngrok_authtoken,
_ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
model_info,
@ -1938,7 +1947,6 @@ async fn start(
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
model_info: HubModelInfo,
@ -2253,6 +2261,7 @@ async fn start(
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize));
if let Some(api_key) = api_key {
@ -2288,13 +2297,6 @@ async fn start(
.route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@ -2302,8 +2304,7 @@ async fn start(
let mut app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(info_routes)
.merge(aws_sagemaker_route);
.merge(info_routes);
#[cfg(feature = "google")]
{

View File

@ -93,7 +93,6 @@ pub struct Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -117,7 +116,6 @@ impl Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -138,7 +136,6 @@ impl Args {
// max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,

View File

@ -31,7 +31,7 @@ install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]"
pip install -e ".[bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

1379
server/poetry.lock generated

File diff suppressed because it is too large

View File

@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -28,10 +28,11 @@ else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
from .kv_cache import KVCache, get_kv_scales
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"SUPPORTS_WINDOWING",
"KVCache",

View File

@ -1,5 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@ -8,6 +8,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
@ -21,6 +22,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -46,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
can_scale = kv_cache.can_scale(kv_scales)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@ -55,10 +60,13 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
elif ATTENTION == "flashdecoding":
max_q = 1
@ -204,6 +212,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
@ -211,6 +220,8 @@ def attention(
causal: bool = True,
softcap: Optional[float] = None,
):
can_scale = kv_cache.can_scale(kv_scales)
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
@ -220,12 +231,15 @@ def attention(
softcap = 0.0
return prefill_with_paged_kv_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
# If we are using flashdecoding or paged, we always use flash-attn for

View File

@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype,
window_left: int,
):
@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
data_type=kv_cache_dtype,
q_data_type=dtype,
window_left=window_left,
)

View File

@ -1,6 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
@ -14,6 +14,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
@ -55,6 +56,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
if softcap is not None:

View File

@ -1,8 +1,38 @@
from typing import Tuple
from dataclasses import dataclass, field
from loguru import logger
import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a CPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
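A minimal usage sketch of the dual storage described in the docstring (the scale values are made up):

```python
import torch

# Scalar GPU tensors go in; __post_init__ caches plain Python floats so
# callers that need a CPU scalar (e.g. flashinfer) avoid a device sync
# on every decode step.
scales = KVScales(
    key_scale=torch.tensor(0.5, device="cuda"),
    value_scale=torch.tensor(0.25, device="cuda"),
)
assert scales.key_scale_cpu == 0.5  # plain float for CPU-scalar consumers
assert scales.value_scale.is_cuda   # tensor for scaling kernels
```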
class KVCache:
@ -76,6 +106,33 @@ class KVCache:
),
)
def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif (
self.dtype == torch.float8_e4m3fn
and ATTENTION == "flashinfer"
and SYSTEM == "cuda"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so log once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
)
return False
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
@ -94,17 +151,33 @@ class KVCache:
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
@ -151,5 +224,23 @@ def paged_reshape_and_cache(
)
else:
raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention"
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)
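A sketch of how these loader branches map onto checkpoint tensor names, assuming a hypothetical layer prefix:

```python
# For a prefix such as "model.layers.0.self_attn", get_kv_scales accepts:
#   fine-grained:  {prefix}.k_scale and {prefix}.v_scale -> separate scales
#   coarse/older:  {prefix}.kv_scale                     -> one shared scale
#   neither:       both scales default to 1.0, so can_scale() is False
scales = get_kv_scales(weights, "model.layers.0.self_attn")
if scales.key_scale_cpu == 1.0 and scales.value_scale_cpu == 1.0:
    print("checkpoint carries no FP8 KV cache scales")
```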

View File

@ -1,7 +1,7 @@
import os
from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
@ -36,6 +36,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -210,6 +212,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,

View File

@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
return False
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
@ -94,6 +100,17 @@ def fp8_quantize(
)
return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)

View File

@ -195,6 +195,11 @@ class ModelType(enum.Enum):
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
@ -862,7 +867,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@ -876,7 +886,9 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention,
Seqlen,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -20,6 +20,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module):
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -329,6 +337,7 @@ class DeepseekV2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
# Remove padding.

View File

@ -39,6 +39,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module):
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -37,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -36,6 +36,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size,
num_heads=self.num_heads,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -24,6 +24,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix,
weights=weights,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -27,7 +27,10 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import KVCache
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -156,7 +159,10 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@ -176,11 +182,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=getattr(config, "attention_bias", False),
)
self.o_proj = TensorParallelAdapterRowLinear.load(
@ -221,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -230,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -245,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(
@ -436,6 +451,11 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward(
self,
hidden_states,
@ -466,13 +486,16 @@ class FlashLlamaLayer(nn.Module):
max_s,
adapter_data,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.dense(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
return mlp_output, attn_res
@ -624,6 +647,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
@ -631,6 +659,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward(
self,
input_ids: torch.Tensor,
@ -664,4 +702,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits
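Taken together, the Granite additions are driven by optional config fields; a sketch of where each one lands (the field names come from the diff above, the values are invented):

```python
# Hypothetical Granite-style config and the code paths that consume it.
granite_config = {
    "attention_multiplier": 0.0078125,  # replaces head_size**-0.5 softmax scale
    "attention_bias": False,            # bias flag forwarded to o_proj loading
    "embedding_multiplier": 12.0,       # baked into embed_tokens at load time
    "residual_multiplier": 0.22,        # scales attn and MLP outputs per layer
    "logits_scaling": 16.0,             # divides lm_head weights, or logits at runtime
}

def scale_logits(logits, logits_scaling, logits_scaled):
    # Mirrors the forward pass above: divide at runtime only when the
    # lm_head weights could not be scaled in place at load time.
    if logits_scaling is not None and not logits_scaled:
        logits = logits / logits_scaling
    return logits
```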

View File

@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,6 +18,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -16,6 +16,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -12,6 +12,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
kv_cache.store(
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
key=kv[:, :, 0],
value=kv[:, :, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(

View File

@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
kv_cache.store(
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
key=key_value[:, 0],
value=key_value[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
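
The wiring above recurs across all of these architectures; the interesting consumer is the cache itself. Below is a hypothetical illustration, not the repository's `KVCache.store`, of why the store path needs the scales once the cache dtype is FP8.

```python
# Hypothetical illustration only: shows the scale/clamp/cast step an FP8
# KV cache needs at store time; not the repository's KVCache.store.
import torch


def quantize_for_fp8_cache(
    x: torch.Tensor, scale: torch.Tensor, cache_dtype: torch.dtype
) -> torch.Tensor:
    if cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        finfo = torch.finfo(cache_dtype)
        # Divide by the per-layer scale, clamp to the representable range,
        # then cast; attention kernels multiply the scale back in later.
        return (x / scale).clamp(min=finfo.min, max=finfo.max).to(cache_dtype)
    # Non-FP8 caches store tensors unchanged, matching the identity scale.
    return x
```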

View File

@ -2283,6 +2283,7 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype,
window_left=self.sliding_window,
)
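
This hunk passes the cache dtype into the attention planner alongside the model dtype, so kernels can be planned for an FP8 cache while activations stay in fp16/bf16. A sketch of the dtype resolution implied here, with option strings assumed rather than taken from the project's CLI:

```python
# Sketch only: option strings and the mapping are assumptions modeled on
# common `--kv-cache-dtype` style flags, not the project's exact choices.
from typing import Optional

import torch

_FP8_KV_DTYPES = {
    "fp8_e4m3fn": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}


def resolve_kv_cache_dtype(
    option: Optional[str], model_dtype: torch.dtype
) -> torch.dtype:
    # Default: cache in the model dtype (previous behaviour). FP8 roughly
    # halves KV-cache memory but requires the per-layer scales above.
    if option is None or option == "auto":
        return model_dtype
    return _FP8_KV_DTYPES[option]
```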

View File

@ -207,7 +207,9 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
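
The added annotations also document the knob that matters for scale loading: `to_dtype=False` keeps a tensor in its stored dtype instead of casting it to the serving dtype. A usage sketch with a hypothetical weight key:

```python
# Usage sketch; the layer/key naming is hypothetical.
def load_k_scale(weights, layer: int):
    key = f"model.layers.{layer}.self_attn.k_scale"
    # to_dtype=False preserves the stored float32 scale even when the
    # model itself is served in float16 or bfloat16.
    return weights.get_tensor(key, to_dtype=False)
```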

View File

@ -172,6 +172,8 @@ def check_openapi(check: bool):
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"lint",
"--skip-rule",
"security-defined",
filename,
],
capture_output=True,