Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

commit 01b82b58d2
Merge branch 'main' into trtllm-executor-thread
@@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
 You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
 
 ```bash
-curl localhost:3000/v1/chat/completions \
+curl localhost:8080/v1/chat/completions \
     -X POST \
     -d '{
   "model": "tgi",
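
Because the Messages API mirrors the OpenAI Chat Completions API, the same endpoint can also be queried with an OpenAI-compatible client. A minimal sketch, assuming the `openai` Python package and a TGI server listening on `localhost:8080` (both are assumptions, not part of the diff):

```python
# Query TGI's OpenAI-compatible Messages API with the `openai` client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/v1",  # TGI's OpenAI-compatible route
    api_key="-",                          # TGI does not require a key by default
)

response = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What is deep learning?"}],
    stream=False,
)
print(response.choices[0].message.content)
```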
@@ -54,14 +54,14 @@ struct Args {
     otlp_service_name: String,
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
-    #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
     auth_token: Option<String>,
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
 }
 
 async fn get_tokenizer(
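
Following clap's derive conventions, the new field should surface as a `--usage-stats` CLI flag and a `USAGE_STATS` environment variable defaulting to `on`; the `UsageStatsLevel::Off` value that appears further down in this diff is the level that disables collection.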
@@ -213,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         otlp_endpoint,
         otlp_service_name,
         cors_allow_origin,
-        messages_api_enabled,
         max_client_batch_size,
         auth_token,
         executor_worker,
+        usage_stats,
     } = args;
 
     // Launch Tokio runtime
@@ -293,7 +293,6 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         false,
         None,
         None,
-        messages_api_enabled,
         true,
         max_client_batch_size,
         UsageStatsLevel::Off,
@@ -44,6 +44,8 @@ struct Args {
     tokenizer_config_path: Option<String>,
     #[clap(long, env)]
     revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]
@@ -63,8 +65,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
+        trust_remote_code,
         validation_workers,
         api_key,
         json_output,
@@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
+        trust_remote_code,
         hostname,
         port,
         cors_allow_origin,
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -44,6 +44,8 @@ struct Args {
     tokenizer_config_path: Option<String>,
     #[clap(long, env)]
     revision: Option<String>,
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]
@@ -63,8 +65,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
+        trust_remote_code,
         validation_workers,
         api_key,
         json_output,
@@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
+        trust_remote_code,
         hostname,
         port,
         cors_allow_origin,
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -316,6 +316,98 @@
         }
       }
     },
+    "/invocations": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens from Sagemaker request",
+        "operationId": "sagemaker_compatibility",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/SagemakerRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Chat Completion",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerResponse"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerStreamResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation",
+                  "error_type": "generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Incomplete generation",
+                  "error_type": "incomplete_generation"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/metrics": {
       "get": {
         "tags": [
@@ -1865,6 +1957,45 @@
          "type": "string"
        }
      },
+    "SagemakerRequest": {
+      "oneOf": [
+        {
+          "$ref": "#/components/schemas/CompatGenerateRequest"
+        },
+        {
+          "$ref": "#/components/schemas/ChatRequest"
+        },
+        {
+          "$ref": "#/components/schemas/CompletionRequest"
+        }
+      ]
+    },
+    "SagemakerResponse": {
+      "oneOf": [
+        {
+          "$ref": "#/components/schemas/GenerateResponse"
+        },
+        {
+          "$ref": "#/components/schemas/ChatCompletion"
+        },
+        {
+          "$ref": "#/components/schemas/CompletionFinal"
+        }
+      ]
+    },
+    "SagemakerStreamResponse": {
+      "oneOf": [
+        {
+          "$ref": "#/components/schemas/StreamResponse"
+        },
+        {
+          "$ref": "#/components/schemas/ChatCompletionChunk"
+        },
+        {
+          "$ref": "#/components/schemas/Chunk"
+        }
+      ]
+    },
     "SimpleToken": {
       "type": "object",
       "required": [
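
The new `/invocations` route accepts any of the three request shapes listed under `SagemakerRequest`. A minimal sketch of calling it with a `ChatRequest`-style body, assuming a TGI server on `localhost:8080` and the `requests` package (neither is part of the diff):

```python
import requests

# Chat-style payload; per the SagemakerRequest oneOf schema above, /invocations
# also accepts plain generate and completion request bodies.
payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 64,
}

resp = requests.post("http://localhost:8080/invocations", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())
```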
@@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
 
 ## Amazon SageMaker
 
-To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
+Amazon Sagemaker natively supports the message API:
 
-This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
-
 ```python
 import json
@@ -161,12 +159,11 @@ except ValueError:
 hub = {
     'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
     'SM_NUM_GPUS': json.dumps(1),
-    'MESSAGES_API_ENABLED': True
 }
 
 # create Hugging Face Model Class
 huggingface_model = HuggingFaceModel(
-    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
+    image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
     env=hub,
     role=role,
 )
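
To round out the docs example above, a sketch of deploying the `HuggingFaceModel` and sending a Messages-style payload through the SageMaker predictor; the instance type and startup timeout are illustrative values, not taken from the diff:

```python
# Deploy the HuggingFaceModel defined above and call it with a Messages payload.
# Instance type and health-check timeout are illustrative assumptions.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

response = predictor.predict({
    "messages": [{"role": "user", "content": "What is deep learning?"}]
})
print(response)
```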
@@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
 - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
 - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
@@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
   "max_top_n_tokens": 5,
   "max_total_tokens": 2048,
   "max_waiting_tokens": 20,
-  "messages_api_enabled": false,
   "model_config": {
     "model_type": "Bloom"
   },
@@ -978,15 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1728381423,
-      "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
+      "lastModified": 1729531056,
+      "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
       "owner": "huggingface",
       "repo": "text-generation-inference-nix",
-      "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
+      "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
+      "ref": "marlin-kernels-0.3.0",
       "repo": "text-generation-inference-nix",
       "type": "github"
     }
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -137,6 +137,11 @@
 
           impure = callPackage ./nix/impure-shell.nix { inherit server; };
 
+          impureWithCuda = callPackage ./nix/impure-shell.nix {
+            inherit server;
+            withCuda = true;
+          };
+
           impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
             server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
           };
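
Assuming these attributes are exposed as flake dev shells like the existing `impure` one, the CUDA-enabled variant would presumably be entered with `nix develop .#impureWithCuda`.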
@ -11,27 +11,27 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.1875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.93359375,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.875,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1796875,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -39,66 +39,66 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.079956055,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.2763672,
|
"logprob": -0.028808594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37548828,
|
"logprob": -0.013671875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4628906,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02885437,
|
"logprob": -0.0005874634,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.2565918,
|
"logprob": -0.026855469,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0063438416,
|
"logprob": -0.00020885468,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3056641,
|
"logprob": -0.17773438,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.6035156,
|
"logprob": -0.703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 3,
|
"generated_tokens": 10,
|
||||||
"prefill": [
|
"prefill": [
|
||||||
{
|
{
|
||||||
"id": 128000,
|
"id": 128000,
|
||||||
@ -11,22 +11,22 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -22.96875,
|
"logprob": -18.0,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -10.71875,
|
"logprob": -11.75,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -2.6992188,
|
"logprob": -2.0625,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -4.8398438,
|
"logprob": -6.0,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -34,24 +34,66 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 720,
|
"id": 720,
|
||||||
"logprob": -0.4411621,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " \n"
|
"text": " \n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 220,
|
"id": 34564,
|
||||||
"logprob": -0.35864258,
|
"logprob": -0.11279297,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " "
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 128001,
|
"id": 6975,
|
||||||
|
"logprob": -0.16015625,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 320,
|
||||||
|
"logprob": -0.25195312,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16931,
|
||||||
|
"logprob": -1.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "DL"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": true,
|
"special": false,
|
||||||
"text": "<|end_of_text|>"
|
"text": ")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1207,
|
||||||
|
"logprob": -1.3125,
|
||||||
|
"special": false,
|
||||||
|
"text": " sub"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2630,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "field"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "What is deep learning? \n "
|
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
|
||||||
}
|
}
|
||||||
|
@ -12,27 +12,27 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.1875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.93359375,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.875,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1796875,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -40,68 +40,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.0047912598,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.025512695,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.012145996,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.72265625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0005760193,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02722168,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00023651123,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.17285156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -116,27 +116,27 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -144,68 +144,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -220,27 +220,27 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -248,68 +248,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -324,27 +324,27 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -352,67 +352,67 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -10,80 +10,95 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1503906,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.5859375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.3945312,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.4555664,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.4777832,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8808594,
|
"logprob": -0.023849487,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37280273,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.26098633,
|
"logprob": -0.14489746,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017137527,
|
"logprob": -0.63183594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2695312,
|
"logprob": -0.010314941,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9238281,
|
"logprob": -0.0635376,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48828125,
|
"logprob": -0.0028572083,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
}
|
}
|
||||||
|
@ -10,42 +10,28 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 349,
|
||||||
"logprob": -11.0078125,
|
"logprob": -12.0546875,
|
||||||
"text": "Test"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 3534,
|
||||||
"logprob": -13.59375,
|
"logprob": -10.53125,
|
||||||
"text": "request"
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -2.71875,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -5.0078125,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 13,
|
|
||||||
"logprob": -0.34838867,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13940,
|
|
||||||
"logprob": -0.38916016,
|
|
||||||
"special": false,
|
|
||||||
"text": "``"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 28832,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "`"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3371,
|
|
||||||
"logprob": -1.2529297,
|
|
||||||
"special": false,
|
|
||||||
"text": "json"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
@ -53,37 +39,61 @@
|
|||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28751,
|
"id": 23229,
|
||||||
"logprob": 0.0,
|
"logprob": -0.18237305,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "{"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 17504,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " Learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2287,
|
"id": 349,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " "
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 345,
|
"id": 264,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " \""
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3134,
|
"id": 19804,
|
||||||
"logprob": -0.640625,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "request"
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 302,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13253,
|
||||||
|
"logprob": -0.6040039,
|
||||||
|
"special": false,
|
||||||
|
"text": " Machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17504,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28725,
|
||||||
|
"logprob": -0.11621094,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request\n```json\n{\n \"request"
|
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
|
||||||
}
|
}
|
||||||
|
@ -11,82 +11,97 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1503906,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.5859375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.3945312,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.4555664,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.4777832,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13232422,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.023834229,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14416504,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63183594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.064208984,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.0028266907,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -100,82 +115,97 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -189,82 +219,97 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -278,81 +323,96 @@
|
|||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -11,32 +11,32 @@
(flash_phi35_moe snapshot: the prefill logprobs for "is gradient descent ?\n\n" shift slightly — e.g. "gradient" -13.9296875 → -13.6484375, and the two trailing "\n" tokens drop to about -6.5 and -6.7.)
@ -44,66 +44,66 @@
(The generated tokens change from "Grad", "ient", " descent", " is", " a", " first", "-", "order", " optimization", " algorithm" to "Grad", "ient", " descent", " is", " an", " optimization", " algorithm", " commonly", " used", " in".)
-    "generated_text": "Gradient descent is a first-order optimization algorithm"
+    "generated_text": "Gradient descent is an optimization algorithm commonly used in"
@ -5,95 +5,95 @@
(flash_phi35_moe_all_params snapshot, seed 0: the prefill now starts at "is" with a null logprob, keeps "gradient", "descent", "?" and a single "\n" instead of two, and the ten sampled tokens change from "Hello", "!", " It", " seems", " you", "'", "re", " address", "ing", " a" to "Grad", "ient", " Des", "cent", " (", "G", "D", ")", " is", " an".)
-    "generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
+    "generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
(flash_phi35_moe_load snapshot: the same update is applied to each of the four load-test responses — hunk pairs @ -12,32 / @ -45,68, @ -121,32 / @ -154,68, @ -230,32 / @ -263,68 and @ -339,32 / @ -372,67 — with slightly shifted prefill logprobs and the generated tokens changing from "Grad", "ient", " descent", " is", " a", " first", "-", "order", " optimization", " algorithm" to "Grad", "ient", " descent", " is", " an", " optimization", " algorithm", " commonly", " used", " in"; every entry's expectation becomes:)
-    "generated_text": "Gradient descent is a first-order optimization algorithm"
+    "generated_text": "Gradient descent is an optimization algorithm commonly used in"
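The snapshot changes above are mechanical: once the prompts and expected generations in the tests change, the stored JSON expectations are regenerated rather than edited by hand. A minimal sketch of how that regeneration might be driven, assuming the suite lives under an `integration-tests/` directory and uses syrupy's standard snapshot-update flag (both assumptions, not shown in this diff):

```bash
# Hypothetical invocation: regenerate the affected snapshots against a running
# TGI instance, then review the rewritten JSON files in the resulting diff.
pytest integration-tests -k "phi35_moe or mixtral_gptq or fp8_kv_cache" --snapshot-update -sv
```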
@ -4,7 +4,9 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_llama_fp8_kv_cache_handle(launcher):
     with launcher(
-        "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
+        "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
+        num_shard=2,
+        kv_cache_dtype="fp8_e4m3fn",
     ) as handle:
         yield handle
 
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
 
     assert (
         response.generated_text
-        == " Deep learning is a subset of machine learning that is"
+        == " Deep learning is a subset of machine learning that involves"
     )
     assert response.details.generated_tokens == 10
     assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
     assert len(responses) == 4
     assert (
         responses[0].generated_text
-        == " Deep learning is a subset of machine learning that is"
+        == " Deep learning is a subset of machine learning that involves"
     )
     assert all(
         [r.generated_text == responses[0].generated_text for r in responses]
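The fixture change above swaps in a checkpoint that ships FP8 KV-cache scales and switches the cache dtype from `fp8_e5m2` to `fp8_e4m3fn`. A rough equivalent outside the test harness, assuming the fixture's keyword arguments map one-to-one onto launcher flags of the same names (an assumption, not part of this diff):

```bash
# Hypothetical launcher invocation mirroring the updated fixture arguments.
text-generation-launcher \
    --model-id neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV \
    --num-shard 2 \
    --kv-cache-dtype fp8_e4m3fn
```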
@ -3,7 +3,11 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_mixtral_gptq_handle(launcher):
-    with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
+    with launcher(
+        "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
+        revision="gptq-4bit-128g-actorder_True",
+        num_shard=2,
+    ) as handle:
         yield handle
 
 
@ -16,7 +20,12 @@ async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
 @pytest.mark.asyncio
 async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
     response = await flash_mixtral_gptq.generate(
-        "Test request", max_new_tokens=10, decoder_input_details=True
+        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert (
+        response.generated_text == "\n\nDeep learning is a subset of machine learning"
     )
 
     assert response == response_snapshot
@ -25,7 +34,7 @@ async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
 @pytest.mark.asyncio
 async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
     response = await flash_mixtral_gptq.generate(
-        "Test request",
+        "What is deep learning?",
         max_new_tokens=10,
         repetition_penalty=1.2,
         return_full_text=True,
@ -41,6 +50,10 @@ async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapsh
     )
 
     assert response.details.generated_tokens == 10
+    assert (
+        response.generated_text
+        == "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
+    )
     assert response == response_snapshot
 
 
@ -49,10 +62,14 @@ async def test_flash_mixtral_gptq_load(
     flash_mixtral_gptq, generate_load, response_snapshot
 ):
     responses = await generate_load(
-        flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
+        flash_mixtral_gptq, "What is deep learning?", max_new_tokens=10, n=4
     )
 
     assert len(responses) == 4
+    assert (
+        responses[0].generated_text
+        == "\n\nDeep learning is a subset of machine learning"
+    )
     assert all(
         [r.generated_text == responses[0].generated_text for r in responses]
     ), f"{[r.generated_text for r in responses]}"
@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
     assert response.details.generated_tokens == 10
     assert (
         response.generated_text
-        == "Gradient descent is a first-order optimization algorithm"
+        == "Gradient descent is an optimization algorithm commonly used in"
     )
     assert response == response_snapshot
 
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
 @pytest.mark.asyncio
 async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
     response = await flash_phi35_moe.generate(
-        "What is gradient descent?\n\n",
+        "What is gradient descent?\n",
         max_new_tokens=10,
         repetition_penalty=1.2,
         return_full_text=True,
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
     assert response.details.generated_tokens == 10
     assert (
         response.generated_text
-        == "What is gradient descent?\n\nHello! It seems you're addressing a"
+        == "What is gradient descent?\nGradient Descent (GD) is an"
     )
     assert response == response_snapshot
 
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
     assert responses[0].details.generated_tokens == 10
     assert (
         responses[0].generated_text
-        == "Gradient descent is a first-order optimization algorithm"
+        == "Gradient descent is an optimization algorithm commonly used in"
     )
     assert all(
         [r.generated_text == responses[0].generated_text for r in responses]
@ -1104,6 +1104,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
                 }
             }
         }
+        } else {
+            break;
         }
     }
 }
@ -1507,6 +1509,10 @@ fn spawn_webserver(
         router_args.push(revision.to_string())
     }
 
+    if args.trust_remote_code {
+        router_args.push("--trust-remote-code".to_string());
+    }
+
     if args.json_output {
         router_args.push("--json-output".to_string());
     }
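With the `spawn_webserver` change above, the launcher now forwards its `trust_remote_code` setting to the router it spawns, so a single flag covers both processes. A minimal sketch of the user-facing invocation; the model id is a placeholder:

```bash
# Hypothetical example: opt into custom modeling code for a Hub model; after this
# change the launcher also passes --trust-remote-code to the spawned router.
text-generation-launcher \
    --model-id <model-with-custom-code> \
    --trust-remote-code
```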
@ -1,7 +1,12 @@
 {
+  lib,
   mkShell,
   black,
+  cmake,
   isort,
+  ninja,
+  which,
+  cudaPackages,
   openssl,
   pkg-config,
   protobuf,
@ -11,14 +16,17 @@
   ruff,
   rust-bin,
   server,
+
+  # Enable dependencies for building CUDA packages. Useful for e.g.
+  # developing marlin/moe-kernels in-place.
+  withCuda ? false,
 }:
 
 mkShell {
-  buildInputs =
+  nativeBuildInputs =
     [
       black
       isort
-      openssl.dev
       pkg-config
       (rust-bin.stable.latest.default.override {
         extensions = [
@ -31,6 +39,19 @@ mkShell {
       redocly
       ruff
     ]
+    ++ (lib.optionals withCuda [
+      cmake
+      ninja
+      which
+
+      # For most Torch-based extensions, setting CUDA_HOME is enough, but
+      # some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
+      cudaPackages.cuda_nvcc
+    ]);
+  buildInputs =
+    [
+      openssl.dev
+    ]
     ++ (with python3.pkgs; [
       venvShellHook
       docker
@ -40,10 +61,29 @@ mkShell {
       pytest
       pytest-asyncio
       syrupy
-    ]);
+    ])
+    ++ (lib.optionals withCuda (
+      with cudaPackages;
+      [
+        cuda_cccl
+        cuda_cudart
+        cuda_nvrtc
+        cuda_nvtx
+        cuda_profiler_api
+        cudnn
+        libcublas
+        libcusolver
+        libcusparse
+      ]
+    ));
 
   inputsFrom = [ server ];
 
+  env = lib.optionalAttrs withCuda {
+    CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
+    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
+  };
+
   venvDir = "./.venv";
 
   postVenvCreation = ''
@ -51,6 +91,7 @@ mkShell {
     ( cd server ; python -m pip install --no-dependencies -e . )
     ( cd clients/python ; python -m pip install --no-dependencies -e . )
   '';
 
   postShellHook = ''
     unset SOURCE_DATE_EPOCH
     export PATH=$PATH:~/.cargo/bin
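The dev-shell change above moves build tools into `nativeBuildInputs` and gates the CUDA toolchain, CUDA runtime libraries, `CUDA_HOME` and `TORCH_CUDA_ARCH_LIST` behind the new `withCuda ? false` argument. How the flag is toggled depends on how this expression is instantiated; one hedged sketch, assuming a top-level `shell.nix` that forwards function arguments to it (an assumption, not shown here):

```bash
# Hypothetical wiring: enable the optional CUDA dependencies when entering the
# development shell for in-place kernel work.
nix-shell --arg withCuda true
```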
@ -150,6 +150,7 @@ pub enum Config {
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
+    Granite,
     Santacoder,
     Bloom,
     Mpt,
@ -8,6 +8,7 @@ pub mod validation;
 mod kserve;
 pub mod logging;
 
+mod sagemaker;
 pub mod usage_stats;
 mod vertex;
 
@ -1,748 +0,0 @@
(The standalone router binary entrypoint is deleted in full. The removed file contained: the clap `Args` struct — max_concurrent_requests, max_best_of, max_stop_sequences, max_top_n_tokens, max_input_tokens, max_total_tokens, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size, hostname, port, master_shard_uds_path, tokenizer_name, tokenizer_config_path, revision, validation_workers, json_output, otlp_endpoint, otlp_service_name, cors_allow_origin, api_key, the ngrok options, messages_api_enabled, disable_grammar_support, max_client_batch_size, disable_usage_stats, disable_crash_reports — the `PrintSchema` subcommand; the `main` function that validated those arguments, resolved the tokenizer, config and model info from a local path, the Hugging Face Hub API or the offline cache, reported usage-stats start/stop/error events, and invoked `server::run`; and the `init_logging`, `get_model_info`, `get_base_tokenizer`, `get_tokenizer_config` and `create_post_processor` helpers.)
if let Some(eos) = eos_token {
|
|
||||||
pair.push(format!("{}:1", eos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let post_processor = TemplateProcessing::builder()
|
|
||||||
.try_single(single)?
|
|
||||||
.try_pair(pair)?
|
|
||||||
.special_tokens(special_tokens)
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(post_processor)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
enum RouterError {
|
|
||||||
#[error("Argument validation error: {0}")]
|
|
||||||
ArgumentValidation(String),
|
|
||||||
#[error("WebServer error: {0}")]
|
|
||||||
WebServer(#[from] server::WebServerError),
|
|
||||||
#[error("Tokio runtime failed to start: {0}")]
|
|
||||||
Tokio(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use text_generation_router::TokenizerConfigToken;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_post_processor() {
|
|
||||||
let tokenizer_config = HubTokenizerConfig {
|
|
||||||
add_bos_token: None,
|
|
||||||
add_eos_token: None,
|
|
||||||
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
|
||||||
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
|
||||||
chat_template: None,
|
|
||||||
tokenizer_class: None,
|
|
||||||
completion_template: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokenizer =
|
|
||||||
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
|
|
||||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
|
||||||
|
|
||||||
let expected = TemplateProcessing::builder()
|
|
||||||
.try_single("<s>:0 $A:0")
|
|
||||||
.unwrap()
|
|
||||||
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
|
||||||
.unwrap()
|
|
||||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(post_processor, expected);
|
|
||||||
}
|
|
||||||
}
|
|
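For readers more used to the Python `tokenizers` API, the template that `create_post_processor` builds for the TinyLlama test case above can be expressed directly with `TemplateProcessing`. This is an illustrative sketch only (it mirrors the `<s>`/`$A`/`$B` template from the test, not a general rule, and assumes the model's BOS id resolves via `token_to_id`):

```python
# Sketch: Python-side equivalent of the post-processor template in the test above.
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing

tokenizer = Tokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
bos_id = tokenizer.token_to_id("<s>")  # 1 for this model, as the Rust test expects

tokenizer.post_processor = TemplateProcessing(
    single="<s>:0 $A:0",            # prepend BOS to a single sequence
    pair="<s>:0 $A:0 <s>:1 $B:1",   # and to both members of a pair
    special_tokens=[("<s>", bos_id)],
)

print(tokenizer.encode("Hello").tokens)  # starts with '<s>'
```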
82  router/src/sagemaker.rs  (new file)
@ -0,0 +1,82 @@
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
    ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
    CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;

#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
    Generate(CompatGenerateRequest),
    Chat(ChatRequest),
    Completion(CompletionRequest),
}

// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
    Generate(GenerateResponse),
    Chat(ChatCompletion),
    Completion(CompletionFinal),
}

// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
    Generate(StreamResponse),
    Chat(ChatCompletionChunk),
    Completion(Chunk),
}

/// Generate tokens from Sagemaker request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/invocations",
    request_body = SagemakerRequest,
    responses(
        (status = 200, description = "Generated Chat Completion",
        content(
            ("application/json" = SagemakerResponse),
            ("text/event-stream" = SagemakerStreamResponse),
        )),
        (status = 424, description = "Generation Error", body = ErrorResponse,
        example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
        example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
        example = json ! ({"error": "Input validation error", "error_type": "validation"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
        example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
    )
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
    default_return_full_text: Extension<bool>,
    infer: Extension<Infer>,
    compute_type: Extension<ComputeType>,
    info: Extension<Info>,
    Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    match req {
        SagemakerRequest::Generate(req) => {
            compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
        }
        SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
        SagemakerRequest::Completion(req) => {
            completions(infer, compute_type, info, Json(req)).await
        }
    }
}
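Because `SagemakerRequest` is `#[serde(untagged)]`, the new `/invocations` handler dispatches on the shape of the JSON body: a body with `inputs` deserializes as a generate request, one with `messages` as a chat completion, and one with `prompt` as a completion. A rough usage sketch follows; the host, port, and prompt text are assumptions and not part of this diff:

```python
# Sketch: exercising the new /invocations route against an assumed local TGI instance.
import requests

BASE = "http://127.0.0.1:8080"  # assumed address of a running TGI server

# Generate-style body -> SagemakerRequest::Generate -> compat_generate
generate = requests.post(f"{BASE}/invocations", json={
    "inputs": "What is Deep Learning?",
    "parameters": {"max_new_tokens": 32},
})

# Chat-style body -> SagemakerRequest::Chat -> chat_completions
chat = requests.post(f"{BASE}/invocations", json={
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_tokens": 32,
})

print(generate.json())
print(chat.json())
```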
@ -7,6 +7,10 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::sagemaker::{
+    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
+    __path_sagemaker_compatibility,
+};
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(skip(infer, req))]
-async fn compat_generate(
+pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
@ -678,7 +682,7 @@ time_per_token,
 seed,
 )
 )]
-async fn completions(
+pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@ -1202,7 +1206,7 @@ time_per_token,
 seed,
 )
 )]
-async fn chat_completions(
+pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@ -1513,11 +1517,13 @@ completions,
 tokenize,
 metrics,
 openai_get_model_info,
+sagemaker_compatibility,
 ),
 components(
 schemas(
 Info,
 CompatGenerateRequest,
+SagemakerRequest,
 GenerateRequest,
 GrammarType,
 ChatRequest,
@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
 ChatCompletion,
 CompletionRequest,
 CompletionComplete,
+SagemakerResponse,
+SagemakerStreamResponse,
 Chunk,
 Completion,
 CompletionFinal,
@ -1601,13 +1609,13 @@ pub async fn run(
     tokenizer_name: String,
     tokenizer_config_path: Option<String>,
     revision: Option<String>,
+    trust_remote_code: bool,
     hostname: String,
     port: u16,
     cors_allow_origin: Option<Vec<String>>,
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
@ -1761,10 +1769,13 @@ pub async fn run(
     let auto = transformers.getattr("AutoTokenizer")?;
     let from_pretrained = auto.getattr("from_pretrained")?;
     let args = (tokenizer_name.to_string(),);
-    let kwargs = [(
-        "revision",
-        revision.clone().unwrap_or_else(|| "main".to_string()),
-    )]
+    let kwargs = [
+        (
+            "revision",
+            (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
+        ),
+        ("trust_remote_code", trust_remote_code.into_py(py)),
+    ]
     .into_py_dict_bound(py);
     let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
     let save = tokenizer.getattr("save_pretrained")?;
@ -1836,7 +1847,6 @@ pub async fn run(
     // max_batch_size,
     revision.clone(),
     validation_workers,
-    messages_api_enabled,
     disable_grammar_support,
     max_client_batch_size,
     usage_stats_level,
@ -1878,7 +1888,6 @@ pub async fn run(
     ngrok,
     _ngrok_authtoken,
     _ngrok_edge,
-    messages_api_enabled,
     disable_grammar_support,
     max_client_batch_size,
     model_info,
@ -1938,7 +1947,6 @@ async fn start(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     model_info: HubModelInfo,
@ -2253,6 +2261,7 @@ async fn start(
     .route("/v1/chat/completions", post(chat_completions))
     .route("/v1/completions", post(completions))
     .route("/vertex", post(vertex_compatibility))
+    .route("/invocations", post(sagemaker_compatibility))
     .route("/tokenize", post(tokenize));

 if let Some(api_key) = api_key {
@ -2288,13 +2297,6 @@ async fn start(
     .route("/metrics", get(metrics))
     .route("/v1/models", get(openai_get_model_info));

-// Conditional AWS Sagemaker route
-let aws_sagemaker_route = if messages_api_enabled {
-    Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
-} else {
-    Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
-};
-
 let compute_type =
     ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@ -2302,8 +2304,7 @@ async fn start(
 let mut app = Router::new()
     .merge(swagger_ui)
     .merge(base_routes)
-    .merge(info_routes)
-    .merge(aws_sagemaker_route);
+    .merge(info_routes);

 #[cfg(feature = "google")]
 {
@ -93,7 +93,6 @@ pub struct Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
     validation_workers: usize,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
@ -117,7 +116,6 @@ impl Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
     validation_workers: usize,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
@ -138,7 +136,6 @@ impl Args {
     // max_batch_size,
     revision,
     validation_workers,
-    messages_api_enabled,
     disable_grammar_support,
     max_client_batch_size,
     usage_stats_level,
@ -31,7 +31,7 @@ install: install-cuda
 	echo "Installed server"

 install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
-	pip install -e ".[bnb]"
+	pip install -e ".[bnb,marlin,moe]"
 	pip install nvidia-nccl-cu12==2.22.3

 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

1379  server/poetry.lock  (generated)
File diff suppressed because it is too large.
@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
 numpy = "^1.26"

 marlin-kernels = [
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+  { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
 moe-kernels = [
   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
@ -1,5 +1,5 @@
 certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
-charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
-rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
(The same dependency version bumps are repeated verbatim for two further requirements files in this commit; the duplicated hunks are omitted here.)
@ -28,10 +28,11 @@ else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

 # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache
+from .kv_cache import KVCache, get_kv_scales

 __all__ = [
     "attention",
+    "get_kv_scales",
     "paged_attention",
     "SUPPORTS_WINDOWING",
     "KVCache",
@ -1,5 +1,5 @@
 import torch
-from text_generation_server.layers.attention.kv_cache import KVCache
+from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import (
     ATTENTION,
@ -8,6 +8,7 @@ from text_generation_server.models.globals import (
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional


 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
@ -21,6 +22,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
+    *,
+    kv_scales: KVScales,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -46,6 +49,8 @@ def paged_attention(
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

+    can_scale = kv_cache.can_scale(kv_scales)
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@ -55,10 +60,13 @@ def paged_attention(
         from text_generation_server.layers.attention.flashinfer import decode_state

         return decode_state.get().forward(
+            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
             query.contiguous(),
             paged_kv_cache=(kv_cache.key, kv_cache.value),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
+            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
+            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
         )
     elif ATTENTION == "flashdecoding":
         max_q = 1
@ -204,6 +212,7 @@ def attention(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: KVCache,
+    kv_scales: KVScales,
     seqlen: Seqlen,
     block_tables: torch.Tensor,
     softmax_scale: float,
@ -211,6 +220,8 @@ def attention(
     causal: bool = True,
     softcap: Optional[float] = None,
 ):
+    can_scale = kv_cache.can_scale(kv_scales)
+
     if ATTENTION == "flashinfer":
         from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
@ -220,12 +231,15 @@ def attention(
         softcap = 0.0

         return prefill_with_paged_kv_state.get().forward(
+            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
             query.contiguous(),
             causal=causal,
             paged_kv_cache=(kv_cache.key, kv_cache.value),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
             window_left=window_size_left,
+            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
+            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
         )

     # If we are using flashdecoding or paged, we always use flash-attn for
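The `paged_attention` hunk above keeps the upstream comment about choosing between PagedAttention V1 and V2; the deciding quantity is the number of partitions a sequence is split into. A small sketch of that computation follows. The partition size of 512 comes from `_PARTITION_SIZE` in this file; the remainder of the upstream heuristic is cut off in the hunk, so only the partition count itself is shown:

```python
# Sketch of the partition count used by the V1/V2 heuristic mentioned above.
_PARTITION_SIZE = 512  # value from the module shown in the hunk

def max_num_partitions(max_s: int) -> int:
    """Number of _PARTITION_SIZE-wide partitions needed to cover max_s tokens."""
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

# With a single partition the V1 kernel can be used, avoiding the extra
# reduction pass that V2 needs to combine per-partition results.
assert max_num_partitions(200) == 1
assert max_num_partitions(1300) == 3
```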
@ -204,6 +204,7 @@ def use_decode_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
+    kv_cache_dtype: torch.dtype,
     dtype: torch.dtype,
     window_left: int,
 ):
@ -240,7 +241,7 @@ def use_decode_state(
         num_kv_heads=num_kv_heads,
         head_dim=head_size,
         page_size=page_size,
-        data_type=dtype,
+        data_type=kv_cache_dtype,
         q_data_type=dtype,
         window_left=window_left,
     )
@ -1,6 +1,6 @@
 import intel_extension_for_pytorch as ipex
 import torch
-from text_generation_server.layers.attention.kv_cache import KVCache
+from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
@ -14,6 +14,7 @@ def attention(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: KVCache,
+    kv_scales: KVScales,
     seqlen: Seqlen,
     block_tables: torch.Tensor,
     softmax_scale: float,
@ -55,6 +56,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
+    *,
+    kv_scales: KVScales,
     softcap: Optional[float] = None,
 ):
     if softcap is not None:
@ -1,8 +1,38 @@
 from typing import Tuple
+from dataclasses import dataclass, field

+from loguru import logger
 import torch

+from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weights
+
+
+@dataclass
+class KVScales:
+    """
+    Key-value scales for FP8 KV cache.
+
+    This data class stores key and value scales both as a GPU tensor and
+    as a CPU float. This inconvenience is necessary because some functions
+    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
+    (e.g. flashinfer) take scales as a CPU scalar.
+    """
+
+    key_scale: torch.Tensor
+    value_scale: torch.Tensor
+    key_scale_cpu: float = field(init=False)
+    value_scale_cpu: float = field(init=False)
+
+    def __post_init__(self):
+        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
+            raise ValueError("Key and value scales must be scalar tensors.")
+
+        self.key_scale_cpu = self.key_scale.item()
+        self.value_scale_cpu = self.value_scale.item()
+
+
 class KVCache:
@ -76,6 +106,33 @@ class KVCache:
             ),
         )

+    def can_scale(self, kv_scales: KVScales) -> bool:
+        """Check if the cache can be scaled by the given scales."""
+        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
+            return False
+        elif (
+            self.dtype == torch.float8_e4m3fn
+            and ATTENTION == "flashinfer"
+            and SYSTEM == "cuda"
+        ):
+            log_once(
+                logger.info,
+                "Using FP8 KV cache scales",
+            )
+            return True
+        else:
+            # We have scales, but not the correct FP8 cache type, so warn once.
+            log_once(
+                logger.info,
+                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+            )
+            return False
+
+    @property
+    def dtype(self):
+        """Get the data type of the cache."""
+        return self.kv_cache[0].dtype
+
     @property
     def key(self):
         """Get the key cache."""
@ -94,17 +151,33 @@ class KVCache:
         key: torch.Tensor,
         value: torch.Tensor,
         slots: torch.Tensor,
+        kv_scales: KVScales,
     ):
         """Store the key and value at the given slots."""

         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]

+        if self.can_scale(kv_scales):
+            if kv_scales.key_scale_cpu != 1.0:
+                key = fp8_quantize(
+                    key.float(),
+                    scale=kv_scales.key_scale,
+                    qdtype=self.dtype,
+                    scalar=True,
+                )[0]
+            if kv_scales.value_scale_cpu != 1.0:
+                value = fp8_quantize(
+                    value.float(),
+                    scale=kv_scales.value_scale,
+                    qdtype=self.dtype,
+                    scalar=True,
+                )[0]
+
         if ATTENTION in {"flashdecoding", "flashinfer"}:
-            # TODO: add scale
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
-            if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
                 # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
                 # put as raw data instead.
                 key_cache = key_cache.view(torch.uint8)
@ -151,5 +224,23 @@ def paged_reshape_and_cache(
         )
     else:
         raise NotImplementedError(
-            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention"
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
         )
+
+
+def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
+    """Load KV cache scales."""
+
+    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
+    value_scale = key_scale
+    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
+        f"{prefix}.v_scale"
+    ):
+        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
+        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
+    elif weights.has_tensor(f"{prefix}.kv_scale"):
+        # Fall back to older more coarse-grained scale when available.
+        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
+        value_scale = key_scale
+
+    return KVScales(key_scale=key_scale, value_scale=value_scale)
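`get_kv_scales` above resolves a per-layer `KVScales` from either fine-grained `k_scale`/`v_scale` checkpoint tensors or an older coarse `kv_scale` fallback; the model hunks later in this commit then thread the result into `KVCache.store` and the attention kernels. The sketch below mirrors that resolution logic with a plain dict instead of the `Weights` helper; the tensor names and values are hypothetical and chosen only for illustration:

```python
# Hypothetical sketch of per-layer KV-scale resolution in the spirit of get_kv_scales.
import torch

checkpoint = {
    "model.layers.0.self_attn.k_scale": torch.tensor(0.023),
    "model.layers.0.self_attn.v_scale": torch.tensor(0.031),
    "model.layers.1.self_attn.kv_scale": torch.tensor(0.027),  # older, coarse format
}

def resolve_scales(prefix: str) -> tuple:
    key = checkpoint.get(f"{prefix}.k_scale")
    value = checkpoint.get(f"{prefix}.v_scale")
    if key is not None and value is not None:
        return key.item(), value.item()
    kv = checkpoint.get(f"{prefix}.kv_scale")
    if kv is not None:              # fall back to the single coarse scale
        return kv.item(), kv.item()
    return 1.0, 1.0                 # no calibration data: scaling stays disabled

print(resolve_scales("model.layers.0.self_attn"))  # separate key/value scales
print(resolve_scales("model.layers.1.self_attn"))  # same scale for both
print(resolve_scales("model.layers.2.self_attn"))  # (1.0, 1.0)
```

Scales of exactly 1.0 are treated as "no calibration data", which is why `KVCache.can_scale` returns `False` in that case.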
@ -1,7 +1,7 @@
 import os
 from typing import Optional
 import torch
-from text_generation_server.layers.attention.kv_cache import KVCache
+from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
@ -36,6 +36,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
+    *,
+    kv_scales: KVScales,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -210,6 +212,7 @@ def attention(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: KVCache,
+    kv_scales: KVScales,
     seqlen: Seqlen,
     block_tables: torch.Tensor,
     softmax_scale: float,
@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
         return False


+try:
+    import marlin_kernels
+except ImportError:
+    marlin_kernels = None
+
+
 if is_fbgemm_gpu_available():
     if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
@ -94,6 +100,17 @@ def fp8_quantize(
         )
         return qweight, scale

+    if marlin_kernels is not None:
+        shape = weight.shape
+        qweight, scale = marlin_kernels.scaled_fp8_quant(
+            weight.reshape(-1, shape[-1]),
+            dtype=qdtype,
+            scale=scale,
+            scale_ub=scale_upper_bound,
+        )
+
+        return qweight.reshape(shape), scale
+
     # weight, scale = quant_weights(weight, torch.int8, False)
     finfo = torch.finfo(qdtype)
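With this change, `fp8_quantize` prefers `marlin_kernels.scaled_fp8_quant` when the wheel is installed and otherwise falls back to a plain PyTorch path (only the first line of that fallback, `finfo = torch.finfo(qdtype)`, is visible in the hunk). As a reference for what per-tensor FP8 quantization computes, here is a generic sketch; it is not a copy of TGI's fallback and ignores details such as scale upper bounds:

```python
# Generic per-tensor FP8 quantization, for reference only.
import torch

def fp8_quantize_reference(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # One scale for the whole tensor: map the largest magnitude onto finfo.max.
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Dequantization later multiplies by `scale`, so it is returned alongside.
    return qweight, scale

w = torch.randn(128, 256)
qw, s = fp8_quantize_reference(w)
print(qw.dtype, s.item())                           # float8 weights plus their scale
print((qw.to(torch.float32) * s - w).abs().max())   # worst-case quantization error
```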
@ -195,6 +195,11 @@ class ModelType(enum.Enum):
         "name": "Phi 3",
         "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
     }
+    GRANITE = {
+        "type": "granite",
+        "name": "Granite",
+        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
+    }
     GEMMA = {
         "type": "gemma",
         "name": "Gemma",
@ -862,7 +867,12 @@ def get_model(
         trust_remote_code=trust_remote_code,
     )

-    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+    elif (
+        model_type == LLAMA
+        or model_type == BAICHUAN
+        or model_type == PHI3
+        or model_type == GRANITE
+    ):
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@ -876,7 +886,9 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
         else:
             return CausalLM.fallback(
                 model_id,
@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     Seqlen,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module):
         )

         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module):

         self.rotary_emb(query, key, cos, sin)

-        kv_cache.store(key=key, value=value, slots=slots)
+        kv_cache.store(
+            key=key,
+            value=value,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module):
             key=key,
             value=value,
             kv_cache=kv_cache,
+            kv_scales=self.kv_scales,
             seqlen=seqlen,
             block_tables=block_tables,
             softmax_scale=self.softmax_scale,
@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module):
             block_tables,
             seqlen,
             max_s,
+            kv_scales=self.kv_scales,
         )

         return self.o_proj(
@ -20,6 +20,7 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple, Any
|
from typing import Optional, List, Tuple, Any
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
if SYSTEM != "ipex":
|
if SYSTEM != "ipex":
|
||||||
@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
kv_cache.store(
|
||||||
|
key=kv[:, 0],
|
||||||
|
value=kv[:, 1],
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
key=kv[:, 0],
|
key=kv[:, 0],
|
||||||
value=kv[:, 1],
|
value=kv[:, 1],
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
seqlen=seqlen,
|
seqlen=seqlen,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
softmax_scale=self.softmax_scale,
|
softmax_scale=self.softmax_scale,
|
||||||
@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
block_tables,
|
block_tables,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module):
),
)

+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
-kv_cache: Tuple[torch.Tensor, torch.Tensor],
+kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)

-kv_cache.store(key=key, value=value, slots=slots)
+kv_cache.store(
+key=key,
+value=value,
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -329,6 +337,7 @@ class DeepseekV2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

# Remove padding.
@ -39,6 +39,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
],
process_group=weights.process_group,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

o_proj = TensorParallelRowLinear.load(
config,
@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module):

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+kv_cache.store(
+key=kv[:, 0],
+value=kv[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module):
seqlen,
max_s,
softcap=self.softcap,
+kv_scales=self.kv_scales,
)

return self.o_proj(
@ -37,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
)

self.query_key_value = load_attention(config, prefix, weights)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = TensorParallelRowLinear.load(
config,
@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module):

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+kv_cache.store(
+key=kv[:, 0],
+value=kv[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -36,6 +36,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales


def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size,
num_heads=self.num_heads,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = load_row(
config,
@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)

-kv_cache.store(key=key, value=value, slots=slots)
+kv_cache.store(
+key=key,
+value=value,
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -24,6 +24,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix,
weights=weights,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = load_row(
config,
@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)

-kv_cache.store(key=key, value=value, slots=slots)
+kv_cache.store(
+key=key,
+value=value,
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -27,7 +27,10 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN

-from text_generation_server.layers.attention import KVCache
+from text_generation_server.layers.attention import (
+KVCache,
+get_kv_scales,
+)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -156,7 +159,10 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device,
)

-self.softmax_scale = self.head_size**-0.5
+# `config.attention_multiplier` is used in Granite
+self.softmax_scale = getattr(
+config, "attention_multiplier", self.head_size**-0.5
+)

if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@ -176,11 +182,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index

+self.kv_scales = get_kv_scales(weights, f"{prefix}")

o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
-bias=False,
+bias=getattr(config, "attention_bias", False),
)

self.o_proj = TensorParallelAdapterRowLinear.load(
@ -221,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+kv_cache.store(
+key=kv[:, 0],
+value=kv[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -230,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
+kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -245,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(
@ -436,6 +451,11 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)

+# Used in Granite
+# This could eventually be baked into the weights like we do for the embeddings/lm_head
+# but this would mean modifying the lora code
+self.residual_multiplier = getattr(config, "residual_multiplier", None)

def forward(
self,
hidden_states,
@ -466,13 +486,16 @@ class FlashLlamaLayer(nn.Module):
max_s,
adapter_data,
)
+if self.residual_multiplier is not None:
+attn_output *= self.residual_multiplier

-# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)

mlp_output = self.dense(normed_attn_res_output, adapter_data)
+if self.residual_multiplier is not None:
+mlp_output *= self.residual_multiplier

return mlp_output, attn_res

@ -624,6 +647,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
suffix = "lm_head"

+# Used in Granite
+embedding_multiplier = getattr(config, "embedding_multiplier", None)
+if embedding_multiplier is not None:
+self.embed_tokens.weight.data *= embedding_multiplier

with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
@ -631,6 +659,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)

+# Used in Granite
+self.logits_scaling = getattr(config, "logits_scaling", None)
+if self.logits_scaling is not None and self.lm_head.head is not None:
+try:
+# Scale the weights directly
+self.lm_head.head.linear.weight.data /= self.logits_scaling
+self.logits_scaled = True
+except Exception:
+self.logits_scaled = False

def forward(
self,
input_ids: torch.Tensor,
@ -664,4 +702,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)

+# Used in Granite
+if self.logits_scaling is not None and not self.logits_scaled:
+logits /= self.logits_scaling
+if speculative_logits is not None:
+speculative_logits /= self.logits_scaling

return logits, speculative_logits
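Besides the KV-scale plumbing, the Llama hunks above fold in Granite-style scaling knobs: `attention_multiplier` replaces the default `head_size**-0.5` softmax scale, `attention_bias` optionally adds a bias to `o_proj`, `residual_multiplier` scales the attention and MLP branches before their residual adds, `embedding_multiplier` is baked into the embedding weights, and `logits_scaling` is divided into the `lm_head` weights when possible, with a runtime fallback that divides the logits instead. The weight-folding branch works because scaling a linear layer's weights scales its outputs by the same factor; a small self-contained check (sizes and scale value are made up):

```python
import torch

torch.manual_seed(0)
hidden, vocab, logits_scaling = 16, 32, 8.0  # made-up sizes and scale
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
h = torch.randn(4, hidden)

# Runtime fallback: divide the logits after the matmul.
logits_runtime = lm_head(h) / logits_scaling

# Folded variant: divide the weights once up front, then use them as-is.
folded = torch.nn.Linear(hidden, vocab, bias=False)
with torch.no_grad():
    folded.weight.copy_(lm_head.weight / logits_scaling)
logits_folded = folded(h)

print(torch.allclose(logits_runtime, logits_folded, atol=1e-6))  # True
```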
@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
],
process_group=weights.process_group,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

o_proj = TensorParallelRowLinear.load(
config,
@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv

-kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+kv_cache.store(
+key=kv_to_cache[:, 0],
+value=kv_to_cache[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(
@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
)

self.query_key_value = load_attention(config, prefix, weights)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = TensorParallelRowLinear.load(
config,
@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv

-kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+kv_cache.store(
+key=kv_to_cache[:, 0],
+value=kv_to_cache[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size,
hidden_size=self.hidden_size,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

-kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
+kv_cache.store(
+key=qkv[:, 1],
+value=qkv[:, 2],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -18,6 +18,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
)

self.query_key_value = load_attention(config, prefix, weights)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
)

# Reshape key and value and cache
-kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+kv_cache.store(
+key=kv[:, 0],
+value=kv[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
+kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -16,6 +16,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):

self.query_key_value = load_attention(config, prefix, weights)

+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv

-kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+kv_cache.store(
+key=kv_to_cache[:, 0],
+value=kv_to_cache[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -12,6 +12,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+kv_cache.store(
+key=kv[:, 0],
+value=kv[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

kv_cache.store(
-key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
+key=kv[:, :, 0].contiguous(),
+value=kv[:, :, 1].contiguous(),
+slots=slots,
+kv_scales=self.kv_scales,
)

# Prefill
@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
key=kv[:, :, 0],
value=kv[:, :, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.dense(
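Note the `.contiguous()` calls in the multi-query `FlashRWLargeAttention` store above: slicing the fused KV tensor along a middle dimension yields a strided view, and making it contiguous before the cache write is presumably what the store kernel expects. A tiny runnable demonstration of why the view is not contiguous (the 5-D shape here is made up for illustration):

```python
import torch

# Made-up shape: (tokens, kv_groups, 2, heads_per_group, head_size)
kv = torch.randn(3, 4, 2, 5, 8)

key_view = kv[:, :, 0]  # drop the key/value axis in the middle
print(key_view.is_contiguous())                # False: strides still skip over the value half
print(key_view.contiguous().is_contiguous())   # True: .contiguous() materializes a packed copy
```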
@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)

-kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
+kv_cache.store(
+key=key_value[:, 0],
+value=key_value[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
key=key_value[:, 0],
value=key_value[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
)

self.query_key_value = load_attention(config, prefix, weights)
+self.kv_scales = get_kv_scales(weights, f"{prefix}")

self.o_proj = TensorParallelRowLinear.load(
config,
@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv

-kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+kv_cache.store(
+key=kv_to_cache[:, 0],
+value=kv_to_cache[:, 1],
+slots=slots,
+kv_scales=self.kv_scales,
+)

# Prefill
if cu_seqlen_prefill is not None:
@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
+kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
+kv_scales=self.kv_scales,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -2283,6 +2283,7 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
+kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype,
window_left=self.sliding_window,
)
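`FlashCausalLM` now passes `kv_cache_dtype` alongside the model `dtype` when setting up the attention state, keeping the two decoupled once the cache is stored in a narrower format than the activations. A back-of-the-envelope sizing with invented dimensions shows why the distinction is worth the plumbing:

```python
# Rough KV-cache sizing for an invented configuration; the point is only the
# 2x difference between a 16-bit and an 8-bit cache dtype.
num_layers, num_kv_heads, head_size = 32, 8, 128
num_blocks, block_size = 4096, 16
elems = 2 * num_layers * num_blocks * block_size * num_kv_heads * head_size  # keys + values

for name, bytes_per_elem in [("bfloat16 cache", 2), ("8-bit cache", 1)]:
    print(f"{name}: {elems * bytes_per_elem / 2**30:.1f} GiB")
```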
@ -207,7 +207,9 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()

-def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
+def get_tensor(
+self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
@ -172,6 +172,8 @@ def check_openapi(check: bool):
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"lint",
+"--skip-rule",
+"security-defined",
filename,
],
capture_output=True,