diff --git a/.github/workflows/delete_doc_comment.yml b/.github/workflows/delete_doc_comment.yml
deleted file mode 100644
index 1cad807b..00000000
--- a/.github/workflows/delete_doc_comment.yml
+++ /dev/null
@@ -1,12 +0,0 @@
-name: Delete doc comment
-
-on:
- pull_request:
- types: [ closed ]
-
-
-jobs:
- delete:
- uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
- with:
- pr_number: ${{ github.event.number }}
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index 9048105e..689dc0ae 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1398,9 +1398,9 @@ dependencies = [
[[package]]
name = "minijinja"
-version = "1.0.10"
+version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "208758577ef2c86cf5dd3e85730d161413ec3284e2d73b2ef65d9a24d9971bcb"
+checksum = "fb5c5e3d2b4c0a6832bd3d571f7c19a7c1c1f05f11a6e85ae1a29f76be5f9455"
dependencies = [
"serde",
]
@@ -2811,7 +2811,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
-version = "1.3.4"
+version = "1.4.0"
dependencies = [
"average",
"clap",
@@ -2832,7 +2832,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
-version = "1.3.4"
+version = "1.4.0"
dependencies = [
"futures",
"grpc-metadata",
@@ -2849,7 +2849,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
-version = "1.3.4"
+version = "1.4.0"
dependencies = [
"clap",
"ctrlc",
@@ -2865,7 +2865,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
-version = "1.3.4"
+version = "1.4.0"
dependencies = [
"async-stream",
"axum",
diff --git a/Cargo.toml b/Cargo.toml
index 80e6e145..a328a368 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
-version = "1.3.4"
+version = "1.4.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
diff --git a/docs/openapi.json b/docs/openapi.json
index 9a9ed116..da3969df 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1 +1,1293 @@
-{"openapi":"3.0.3","info":{"title":"Text Generation Inference","description":"Text Generation Webserver","contact":{"name":"Olivier Dehaene"},"license":{"name":"Apache 2.0","url":"https://www.apache.org/licenses/LICENSE-2.0"},"version":"1.3.4"},"paths":{"/":{"post":{"tags":["Text Generation Inference"],"summary":"Generate tokens if `stream == false` or a stream of token if `stream == true`","description":"Generate tokens if `stream == false` or a stream of token if `stream == true`","operationId":"compat_generate","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/CompatGenerateRequest"}}},"required":true},"responses":{"200":{"description":"Generated Text","content":{"application/json":{"schema":{"$ref":"#/components/schemas/GenerateResponse"}},"text/event-stream":{"schema":{"$ref":"#/components/schemas/StreamResponse"}}}},"422":{"description":"Input validation error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Input validation error"}}}},"424":{"description":"Generation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Request failed during generation"}}}},"429":{"description":"Model is overloaded","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Model is overloaded"}}}},"500":{"description":"Incomplete generation","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Incomplete generation"}}}}}}},"/generate":{"post":{"tags":["Text Generation Inference"],"summary":"Generate tokens","description":"Generate tokens","operationId":"generate","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/GenerateRequest"}}},"required":true},"responses":{"200":{"description":"Generated Text","content":{"application/json":{"schema":{"$ref":"#/components/schemas/GenerateResponse"}}}},"422":{"description":"Input validation error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Input validation error"}}}},"424":{"description":"Generation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Request failed during generation"}}}},"429":{"description":"Model is overloaded","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Model is overloaded"}}}},"500":{"description":"Incomplete generation","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Incomplete generation"}}}}}}},"/generate_stream":{"post":{"tags":["Text Generation Inference"],"summary":"Generate a stream of token using Server-Sent Events","description":"Generate a stream of token using Server-Sent Events","operationId":"generate_stream","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/GenerateRequest"}}},"required":true},"responses":{"200":{"description":"Generated Text","content":{"text/event-stream":{"schema":{"$ref":"#/components/schemas/StreamResponse"}}}},"422":{"description":"Input validation error","content":{"text/event-stream":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Input validation error"}}}},"424":{"description":"Generation Error","content":{"text/event-stream":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Request failed during generation"}}}},"429":{"description":"Model is overloaded","content":{"text/event-stream":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Model is overloaded"}}}},"500":{"description":"Incomplete generation","content":{"text/event-stream":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Incomplete generation"}}}}}}},"/health":{"get":{"tags":["Text Generation Inference"],"summary":"Health check method","description":"Health check method","operationId":"health","responses":{"200":{"description":"Everything is working fine"},"503":{"description":"Text generation inference is down","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"unhealthy","error_type":"healthcheck"}}}}}}},"/info":{"get":{"tags":["Text Generation Inference"],"summary":"Text Generation Inference endpoint info","description":"Text Generation Inference endpoint info","operationId":"get_model_info","responses":{"200":{"description":"Served model info","content":{"application/json":{"schema":{"$ref":"#/components/schemas/Info"}}}}}}},"/metrics":{"get":{"tags":["Text Generation Inference"],"summary":"Prometheus metrics scrape endpoint","description":"Prometheus metrics scrape endpoint","operationId":"metrics","responses":{"200":{"description":"Prometheus Metrics","content":{"text/plain":{"schema":{"type":"string"}}}}}}},"/tokenize":{"post":{"tags":["Text Generation Inference"],"summary":"Tokenize inputs","description":"Tokenize inputs","operationId":"tokenize","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/GenerateRequest"}}},"required":true},"responses":{"200":{"description":"Tokenized ids","content":{"application/json":{"schema":{"$ref":"#/components/schemas/TokenizeResponse"}}}},"404":{"description":"No tokenizer found","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"No fast tokenizer available"}}}}}}},"/v1/chat/completions":{"post":{"tags":["Text Generation Inference"],"summary":"Generate tokens","description":"Generate tokens","operationId":"chat_completions","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ChatRequest"}}},"required":true},"responses":{"200":{"description":"Generated Text","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ChatCompletionChunk"}}}},"422":{"description":"Input validation error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Input validation error"}}}},"424":{"description":"Generation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Request failed during generation"}}}},"429":{"description":"Model is overloaded","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Model is overloaded"}}}},"500":{"description":"Incomplete generation","content":{"application/json":{"schema":{"$ref":"#/components/schemas/ErrorResponse"},"example":{"error":"Incomplete generation"}}}}}}}},"components":{"schemas":{"BestOfSequence":{"type":"object","required":["generated_text","finish_reason","generated_tokens","prefill","tokens"],"properties":{"finish_reason":{"$ref":"#/components/schemas/FinishReason"},"generated_text":{"type":"string","example":"test"},"generated_tokens":{"type":"integer","format":"int32","example":1,"minimum":0},"prefill":{"type":"array","items":{"$ref":"#/components/schemas/PrefillToken"}},"seed":{"type":"integer","format":"int64","example":42,"nullable":true,"minimum":0},"tokens":{"type":"array","items":{"$ref":"#/components/schemas/Token"}},"top_tokens":{"type":"array","items":{"type":"array","items":{"$ref":"#/components/schemas/Token"}}}}},"ChatCompletion":{"type":"object","required":["id","object","created","model","system_fingerprint","choices","usage"],"properties":{"choices":{"type":"array","items":{"$ref":"#/components/schemas/ChatCompletionComplete"}},"created":{"type":"integer","format":"int64","example":"1706270835","minimum":0},"id":{"type":"string"},"model":{"type":"string","example":"mistralai/Mistral-7B-Instruct-v0.2"},"object":{"type":"string"},"system_fingerprint":{"type":"string"},"usage":{"$ref":"#/components/schemas/Usage"}}},"ChatCompletionChoice":{"type":"object","required":["index","delta"],"properties":{"delta":{"$ref":"#/components/schemas/ChatCompletionDelta"},"finish_reason":{"type":"string","nullable":true},"index":{"type":"integer","format":"int32","minimum":0},"logprobs":{"type":"number","format":"float","nullable":true}}},"ChatCompletionChunk":{"type":"object","required":["id","object","created","model","system_fingerprint","choices"],"properties":{"choices":{"type":"array","items":{"$ref":"#/components/schemas/ChatCompletionChoice"}},"created":{"type":"integer","format":"int64","example":"1706270978","minimum":0},"id":{"type":"string"},"model":{"type":"string","example":"mistralai/Mistral-7B-Instruct-v0.2"},"object":{"type":"string"},"system_fingerprint":{"type":"string"}}},"ChatCompletionDelta":{"type":"object","required":["role","content"],"properties":{"content":{"type":"string","example":"What is Deep Learning?"},"role":{"type":"string","example":"user"}}},"ChatRequest":{"type":"object","required":["model"],"properties":{"frequency_penalty":{"type":"number","format":"float","description":"Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.","example":"1.0","nullable":true},"logit_bias":{"type":"array","items":{"type":"number","format":"float"},"description":"UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.","nullable":true},"logprobs":{"type":"boolean","description":"Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.","example":"false","nullable":true},"max_tokens":{"type":"integer","format":"int32","description":"The maximum number of tokens that can be generated in the chat completion.","example":"32","nullable":true,"minimum":0},"messages":{"type":"array","items":{"$ref":"#/components/schemas/Message"},"description":"A list of messages comprising the conversation so far."},"model":{"type":"string","description":"UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.","example":"mistralai/Mistral-7B-Instruct-v0.2"},"n":{"type":"integer","format":"int32","description":"UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.","example":"2","nullable":true,"minimum":0},"presence_penalty":{"type":"number","format":"float","description":"UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics","example":0.1,"nullable":true},"seed":{"type":"integer","format":"int64","example":42,"nullable":true,"minimum":0},"stream":{"type":"boolean"},"temperature":{"type":"number","format":"float","description":"What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.","example":1.0,"nullable":true},"top_logprobs":{"type":"integer","format":"int32","description":"UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.","example":"5","nullable":true,"minimum":0},"top_p":{"type":"number","format":"float","description":"An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.","example":0.95,"nullable":true}}},"CompatGenerateRequest":{"type":"object","required":["inputs"],"properties":{"inputs":{"type":"string","example":"My name is Olivier and I"},"parameters":{"$ref":"#/components/schemas/GenerateParameters"},"stream":{"type":"boolean","default":"false"}}},"Details":{"type":"object","required":["finish_reason","generated_tokens","prefill","tokens"],"properties":{"best_of_sequences":{"type":"array","items":{"$ref":"#/components/schemas/BestOfSequence"},"nullable":true},"finish_reason":{"$ref":"#/components/schemas/FinishReason"},"generated_tokens":{"type":"integer","format":"int32","example":1,"minimum":0},"prefill":{"type":"array","items":{"$ref":"#/components/schemas/PrefillToken"}},"seed":{"type":"integer","format":"int64","example":42,"nullable":true,"minimum":0},"tokens":{"type":"array","items":{"$ref":"#/components/schemas/Token"}},"top_tokens":{"type":"array","items":{"type":"array","items":{"$ref":"#/components/schemas/Token"}}}}},"ErrorResponse":{"type":"object","required":["error","error_type"],"properties":{"error":{"type":"string"},"error_type":{"type":"string"}}},"FinishReason":{"type":"string","enum":["length","eos_token","stop_sequence"],"example":"Length"},"GenerateParameters":{"type":"object","properties":{"best_of":{"type":"integer","default":"null","example":1,"nullable":true,"minimum":0,"exclusiveMinimum":0},"decoder_input_details":{"type":"boolean","default":"true"},"details":{"type":"boolean","default":"true"},"do_sample":{"type":"boolean","default":"false","example":true},"max_new_tokens":{"type":"integer","format":"int32","default":"100","example":"20","nullable":true,"minimum":0},"repetition_penalty":{"type":"number","format":"float","default":"null","example":1.03,"nullable":true,"exclusiveMinimum":0},"return_full_text":{"type":"boolean","default":"null","example":false,"nullable":true},"seed":{"type":"integer","format":"int64","default":"null","example":"null","nullable":true,"minimum":0,"exclusiveMinimum":0},"stop":{"type":"array","items":{"type":"string"},"example":["photographer"],"maxItems":4},"temperature":{"type":"number","format":"float","default":"null","example":0.5,"nullable":true,"exclusiveMinimum":0},"top_k":{"type":"integer","format":"int32","default":"null","example":10,"nullable":true,"exclusiveMinimum":0},"top_n_tokens":{"type":"integer","format":"int32","default":"null","example":5,"nullable":true,"minimum":0,"exclusiveMinimum":0},"top_p":{"type":"number","format":"float","default":"null","example":0.95,"nullable":true,"maximum":1,"exclusiveMinimum":0},"truncate":{"type":"integer","default":"null","example":"null","nullable":true,"minimum":0},"typical_p":{"type":"number","format":"float","default":"null","example":0.95,"nullable":true,"maximum":1,"exclusiveMinimum":0},"watermark":{"type":"boolean","default":"false","example":true}}},"GenerateRequest":{"type":"object","required":["inputs"],"properties":{"inputs":{"type":"string","example":"My name is Olivier and I"},"parameters":{"$ref":"#/components/schemas/GenerateParameters"}}},"GenerateResponse":{"type":"object","required":["generated_text"],"properties":{"details":{"allOf":[{"$ref":"#/components/schemas/Details"}],"nullable":true},"generated_text":{"type":"string","example":"test"}}},"Info":{"type":"object","required":["model_id","model_dtype","model_device_type","max_concurrent_requests","max_best_of","max_stop_sequences","max_input_length","max_total_tokens","waiting_served_ratio","max_batch_total_tokens","max_waiting_tokens","validation_workers","version"],"properties":{"docker_label":{"type":"string","example":"null","nullable":true},"max_batch_total_tokens":{"type":"integer","format":"int32","example":"32000","minimum":0},"max_best_of":{"type":"integer","example":"2","minimum":0},"max_concurrent_requests":{"type":"integer","description":"Router Parameters","example":"128","minimum":0},"max_input_length":{"type":"integer","example":"1024","minimum":0},"max_stop_sequences":{"type":"integer","example":"4","minimum":0},"max_total_tokens":{"type":"integer","example":"2048","minimum":0},"max_waiting_tokens":{"type":"integer","example":"20","minimum":0},"model_device_type":{"type":"string","example":"cuda"},"model_dtype":{"type":"string","example":"torch.float16"},"model_id":{"type":"string","description":"Model info","example":"bigscience/blomm-560m"},"model_pipeline_tag":{"type":"string","example":"text-generation","nullable":true},"model_sha":{"type":"string","example":"e985a63cdc139290c5f700ff1929f0b5942cced2","nullable":true},"sha":{"type":"string","example":"null","nullable":true},"validation_workers":{"type":"integer","example":"2","minimum":0},"version":{"type":"string","description":"Router Info","example":"0.5.0"},"waiting_served_ratio":{"type":"number","format":"float","example":"1.2"}}},"Message":{"type":"object","required":["role","content"],"properties":{"content":{"type":"string","example":"My name is David and I"},"role":{"type":"string","example":"user"}}},"PrefillToken":{"type":"object","required":["id","text","logprob"],"properties":{"id":{"type":"integer","format":"int32","example":0,"minimum":0},"logprob":{"type":"number","format":"float","example":-0.34,"nullable":true},"text":{"type":"string","example":"test"}}},"SimpleToken":{"type":"object","required":["id","text","start","stop"],"properties":{"id":{"type":"integer","format":"int32","example":0,"minimum":0},"start":{"type":"integer","example":0,"minimum":0},"stop":{"type":"integer","example":2,"minimum":0},"text":{"type":"string","example":"test"}}},"StreamDetails":{"type":"object","required":["finish_reason","generated_tokens"],"properties":{"finish_reason":{"$ref":"#/components/schemas/FinishReason"},"generated_tokens":{"type":"integer","format":"int32","example":1,"minimum":0},"seed":{"type":"integer","format":"int64","example":42,"nullable":true,"minimum":0}}},"StreamResponse":{"type":"object","required":["index","token"],"properties":{"details":{"allOf":[{"$ref":"#/components/schemas/StreamDetails"}],"default":"null","nullable":true},"generated_text":{"type":"string","default":"null","example":"test","nullable":true},"index":{"type":"integer","format":"int32","minimum":0},"token":{"$ref":"#/components/schemas/Token"},"top_tokens":{"type":"array","items":{"$ref":"#/components/schemas/Token"}}}},"Token":{"type":"object","required":["id","text","logprob","special"],"properties":{"id":{"type":"integer","format":"int32","example":0,"minimum":0},"logprob":{"type":"number","format":"float","example":-0.34,"nullable":true},"special":{"type":"boolean","example":"false"},"text":{"type":"string","example":"test"}}},"TokenizeResponse":{"type":"array","items":{"$ref":"#/components/schemas/SimpleToken"}}}},"tags":[{"name":"Text Generation Inference","description":"Hugging Face Text Generation Inference API"}]}
\ No newline at end of file
+{
+ "openapi": "3.0.3",
+ "info": {
+ "title": "Text Generation Inference",
+ "description": "Text Generation Webserver",
+ "contact": {
+ "name": "Olivier Dehaene"
+ },
+ "license": {
+ "name": "Apache 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0"
+ },
+ "version": "1.4.0"
+ },
+ "paths": {
+ "/": {
+ "post": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+ "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+ "operationId": "compat_generate",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/CompatGenerateRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Generated Text",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateResponse"
+ }
+ },
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/StreamResponse"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Input validation error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Input validation error"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Generation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Request failed during generation"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Model is overloaded"
+ }
+ }
+ }
+ },
+ "500": {
+ "description": "Incomplete generation",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Incomplete generation"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/generate": {
+ "post": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Generate tokens",
+ "description": "Generate tokens",
+ "operationId": "generate",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Generated Text",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateResponse"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Input validation error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Input validation error"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Generation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Request failed during generation"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Model is overloaded"
+ }
+ }
+ }
+ },
+ "500": {
+ "description": "Incomplete generation",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Incomplete generation"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/generate_stream": {
+ "post": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Generate a stream of token using Server-Sent Events",
+ "description": "Generate a stream of token using Server-Sent Events",
+ "operationId": "generate_stream",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Generated Text",
+ "content": {
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/StreamResponse"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Input validation error",
+ "content": {
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Input validation error"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Generation Error",
+ "content": {
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Request failed during generation"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Model is overloaded"
+ }
+ }
+ }
+ },
+ "500": {
+ "description": "Incomplete generation",
+ "content": {
+ "text/event-stream": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Incomplete generation"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/health": {
+ "get": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Health check method",
+ "description": "Health check method",
+ "operationId": "health",
+ "responses": {
+ "200": {
+ "description": "Everything is working fine"
+ },
+ "503": {
+ "description": "Text generation inference is down",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "unhealthy",
+ "error_type": "healthcheck"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/info": {
+ "get": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Text Generation Inference endpoint info",
+ "description": "Text Generation Inference endpoint info",
+ "operationId": "get_model_info",
+ "responses": {
+ "200": {
+ "description": "Served model info",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Info"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/metrics": {
+ "get": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Prometheus metrics scrape endpoint",
+ "description": "Prometheus metrics scrape endpoint",
+ "operationId": "metrics",
+ "responses": {
+ "200": {
+ "description": "Prometheus Metrics",
+ "content": {
+ "text/plain": {
+ "schema": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/tokenize": {
+ "post": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Tokenize inputs",
+ "description": "Tokenize inputs",
+ "operationId": "tokenize",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Tokenized ids",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/TokenizeResponse"
+ }
+ }
+ }
+ },
+ "404": {
+ "description": "No tokenizer found",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "No fast tokenizer available"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "/v1/chat/completions": {
+ "post": {
+ "tags": [
+ "Text Generation Inference"
+ ],
+ "summary": "Generate tokens",
+ "description": "Generate tokens",
+ "operationId": "chat_completions",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ChatRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Generated Text",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ChatCompletionChunk"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Input validation error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Input validation error"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Generation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Request failed during generation"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Model is overloaded"
+ }
+ }
+ }
+ },
+ "500": {
+ "description": "Incomplete generation",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Incomplete generation"
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "BestOfSequence": {
+ "type": "object",
+ "required": [
+ "generated_text",
+ "finish_reason",
+ "generated_tokens",
+ "prefill",
+ "tokens"
+ ],
+ "properties": {
+ "finish_reason": {
+ "$ref": "#/components/schemas/FinishReason"
+ },
+ "generated_text": {
+ "type": "string",
+ "example": "test"
+ },
+ "generated_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "example": 1,
+ "minimum": 0
+ },
+ "prefill": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/PrefillToken"
+ }
+ },
+ "seed": {
+ "type": "integer",
+ "format": "int64",
+ "example": 42,
+ "nullable": true,
+ "minimum": 0
+ },
+ "tokens": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Token"
+ }
+ },
+ "top_tokens": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Token"
+ }
+ }
+ }
+ }
+ },
+ "ChatCompletion": {
+ "type": "object",
+ "required": [
+ "id",
+ "object",
+ "created",
+ "model",
+ "system_fingerprint",
+ "choices",
+ "usage"
+ ],
+ "properties": {
+ "choices": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ChatCompletionComplete"
+ }
+ },
+ "created": {
+ "type": "integer",
+ "format": "int64",
+ "example": "1706270835",
+ "minimum": 0
+ },
+ "id": {
+ "type": "string"
+ },
+ "model": {
+ "type": "string",
+ "example": "mistralai/Mistral-7B-Instruct-v0.2"
+ },
+ "object": {
+ "type": "string"
+ },
+ "system_fingerprint": {
+ "type": "string"
+ },
+ "usage": {
+ "$ref": "#/components/schemas/Usage"
+ }
+ }
+ },
+ "ChatCompletionChoice": {
+ "type": "object",
+ "required": [
+ "index",
+ "delta"
+ ],
+ "properties": {
+ "delta": {
+ "$ref": "#/components/schemas/ChatCompletionDelta"
+ },
+ "finish_reason": {
+ "type": "string",
+ "nullable": true
+ },
+ "index": {
+ "type": "integer",
+ "format": "int32",
+ "minimum": 0
+ },
+ "logprobs": {
+ "type": "number",
+ "format": "float",
+ "nullable": true
+ }
+ }
+ },
+ "ChatCompletionChunk": {
+ "type": "object",
+ "required": [
+ "id",
+ "object",
+ "created",
+ "model",
+ "system_fingerprint",
+ "choices"
+ ],
+ "properties": {
+ "choices": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ChatCompletionChoice"
+ }
+ },
+ "created": {
+ "type": "integer",
+ "format": "int64",
+ "example": "1706270978",
+ "minimum": 0
+ },
+ "id": {
+ "type": "string"
+ },
+ "model": {
+ "type": "string",
+ "example": "mistralai/Mistral-7B-Instruct-v0.2"
+ },
+ "object": {
+ "type": "string"
+ },
+ "system_fingerprint": {
+ "type": "string"
+ }
+ }
+ },
+ "ChatCompletionDelta": {
+ "type": "object",
+ "required": [
+ "role",
+ "content"
+ ],
+ "properties": {
+ "content": {
+ "type": "string",
+ "example": "What is Deep Learning?"
+ },
+ "role": {
+ "type": "string",
+ "example": "user"
+ }
+ }
+ },
+ "ChatRequest": {
+ "type": "object",
+ "required": [
+ "model"
+ ],
+ "properties": {
+ "frequency_penalty": {
+ "type": "number",
+ "format": "float",
+ "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
+ "example": "1.0",
+ "nullable": true
+ },
+ "logit_bias": {
+ "type": "array",
+ "items": {
+ "type": "number",
+ "format": "float"
+ },
+ "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+ "nullable": true
+ },
+ "logprobs": {
+ "type": "boolean",
+ "description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.",
+ "example": "false",
+ "nullable": true
+ },
+ "max_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "description": "The maximum number of tokens that can be generated in the chat completion.",
+ "example": "32",
+ "nullable": true,
+ "minimum": 0
+ },
+ "messages": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Message"
+ },
+ "description": "A list of messages comprising the conversation so far."
+ },
+ "model": {
+ "type": "string",
+ "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+ "example": "mistralai/Mistral-7B-Instruct-v0.2"
+ },
+ "n": {
+ "type": "integer",
+ "format": "int32",
+ "description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.",
+ "example": "2",
+ "nullable": true,
+ "minimum": 0
+ },
+ "presence_penalty": {
+ "type": "number",
+ "format": "float",
+ "description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
+ "example": 0.1,
+ "nullable": true
+ },
+ "seed": {
+ "type": "integer",
+ "format": "int64",
+ "example": 42,
+ "nullable": true,
+ "minimum": 0
+ },
+ "stream": {
+ "type": "boolean"
+ },
+ "temperature": {
+ "type": "number",
+ "format": "float",
+ "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.",
+ "example": 1.0,
+ "nullable": true
+ },
+ "top_logprobs": {
+ "type": "integer",
+ "format": "int32",
+ "description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
+ "example": "5",
+ "nullable": true,
+ "minimum": 0
+ },
+ "top_p": {
+ "type": "number",
+ "format": "float",
+ "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
+ "example": 0.95,
+ "nullable": true
+ }
+ }
+ },
+ "CompatGenerateRequest": {
+ "type": "object",
+ "required": [
+ "inputs"
+ ],
+ "properties": {
+ "inputs": {
+ "type": "string",
+ "example": "My name is Olivier and I"
+ },
+ "parameters": {
+ "$ref": "#/components/schemas/GenerateParameters"
+ },
+ "stream": {
+ "type": "boolean",
+ "default": "false"
+ }
+ }
+ },
+ "Details": {
+ "type": "object",
+ "required": [
+ "finish_reason",
+ "generated_tokens",
+ "prefill",
+ "tokens"
+ ],
+ "properties": {
+ "best_of_sequences": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/BestOfSequence"
+ },
+ "nullable": true
+ },
+ "finish_reason": {
+ "$ref": "#/components/schemas/FinishReason"
+ },
+ "generated_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "example": 1,
+ "minimum": 0
+ },
+ "prefill": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/PrefillToken"
+ }
+ },
+ "seed": {
+ "type": "integer",
+ "format": "int64",
+ "example": 42,
+ "nullable": true,
+ "minimum": 0
+ },
+ "tokens": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Token"
+ }
+ },
+ "top_tokens": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Token"
+ }
+ }
+ }
+ }
+ },
+ "ErrorResponse": {
+ "type": "object",
+ "required": [
+ "error",
+ "error_type"
+ ],
+ "properties": {
+ "error": {
+ "type": "string"
+ },
+ "error_type": {
+ "type": "string"
+ }
+ }
+ },
+ "FinishReason": {
+ "type": "string",
+ "enum": [
+ "length",
+ "eos_token",
+ "stop_sequence"
+ ],
+ "example": "Length"
+ },
+ "GenerateParameters": {
+ "type": "object",
+ "properties": {
+ "best_of": {
+ "type": "integer",
+ "default": "null",
+ "example": 1,
+ "nullable": true,
+ "minimum": 0,
+ "exclusiveMinimum": 0
+ },
+ "decoder_input_details": {
+ "type": "boolean",
+ "default": "true"
+ },
+ "details": {
+ "type": "boolean",
+ "default": "true"
+ },
+ "do_sample": {
+ "type": "boolean",
+ "default": "false",
+ "example": true
+ },
+ "max_new_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "default": "100",
+ "example": "20",
+ "nullable": true,
+ "minimum": 0
+ },
+ "repetition_penalty": {
+ "type": "number",
+ "format": "float",
+ "default": "null",
+ "example": 1.03,
+ "nullable": true,
+ "exclusiveMinimum": 0
+ },
+ "return_full_text": {
+ "type": "boolean",
+ "default": "null",
+ "example": false,
+ "nullable": true
+ },
+ "seed": {
+ "type": "integer",
+ "format": "int64",
+ "default": "null",
+ "example": "null",
+ "nullable": true,
+ "minimum": 0,
+ "exclusiveMinimum": 0
+ },
+ "stop": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "example": [
+ "photographer"
+ ],
+ "maxItems": 4
+ },
+ "temperature": {
+ "type": "number",
+ "format": "float",
+ "default": "null",
+ "example": 0.5,
+ "nullable": true,
+ "exclusiveMinimum": 0
+ },
+ "top_k": {
+ "type": "integer",
+ "format": "int32",
+ "default": "null",
+ "example": 10,
+ "nullable": true,
+ "exclusiveMinimum": 0
+ },
+ "top_n_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "default": "null",
+ "example": 5,
+ "nullable": true,
+ "minimum": 0,
+ "exclusiveMinimum": 0
+ },
+ "top_p": {
+ "type": "number",
+ "format": "float",
+ "default": "null",
+ "example": 0.95,
+ "nullable": true,
+ "maximum": 1,
+ "exclusiveMinimum": 0
+ },
+ "truncate": {
+ "type": "integer",
+ "default": "null",
+ "example": "null",
+ "nullable": true,
+ "minimum": 0
+ },
+ "typical_p": {
+ "type": "number",
+ "format": "float",
+ "default": "null",
+ "example": 0.95,
+ "nullable": true,
+ "maximum": 1,
+ "exclusiveMinimum": 0
+ },
+ "watermark": {
+ "type": "boolean",
+ "default": "false",
+ "example": true
+ }
+ }
+ },
+ "GenerateRequest": {
+ "type": "object",
+ "required": [
+ "inputs"
+ ],
+ "properties": {
+ "inputs": {
+ "type": "string",
+ "example": "My name is Olivier and I"
+ },
+ "parameters": {
+ "$ref": "#/components/schemas/GenerateParameters"
+ }
+ }
+ },
+ "GenerateResponse": {
+ "type": "object",
+ "required": [
+ "generated_text"
+ ],
+ "properties": {
+ "details": {
+ "allOf": [
+ {
+ "$ref": "#/components/schemas/Details"
+ }
+ ],
+ "nullable": true
+ },
+ "generated_text": {
+ "type": "string",
+ "example": "test"
+ }
+ }
+ },
+ "Info": {
+ "type": "object",
+ "required": [
+ "model_id",
+ "model_dtype",
+ "model_device_type",
+ "max_concurrent_requests",
+ "max_best_of",
+ "max_stop_sequences",
+ "max_input_length",
+ "max_total_tokens",
+ "waiting_served_ratio",
+ "max_batch_total_tokens",
+ "max_waiting_tokens",
+ "validation_workers",
+ "version"
+ ],
+ "properties": {
+ "docker_label": {
+ "type": "string",
+ "example": "null",
+ "nullable": true
+ },
+ "max_batch_total_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "example": "32000",
+ "minimum": 0
+ },
+ "max_best_of": {
+ "type": "integer",
+ "example": "2",
+ "minimum": 0
+ },
+ "max_concurrent_requests": {
+ "type": "integer",
+ "description": "Router Parameters",
+ "example": "128",
+ "minimum": 0
+ },
+ "max_input_length": {
+ "type": "integer",
+ "example": "1024",
+ "minimum": 0
+ },
+ "max_stop_sequences": {
+ "type": "integer",
+ "example": "4",
+ "minimum": 0
+ },
+ "max_total_tokens": {
+ "type": "integer",
+ "example": "2048",
+ "minimum": 0
+ },
+ "max_waiting_tokens": {
+ "type": "integer",
+ "example": "20",
+ "minimum": 0
+ },
+ "model_device_type": {
+ "type": "string",
+ "example": "cuda"
+ },
+ "model_dtype": {
+ "type": "string",
+ "example": "torch.float16"
+ },
+ "model_id": {
+ "type": "string",
+ "description": "Model info",
+ "example": "bigscience/blomm-560m"
+ },
+ "model_pipeline_tag": {
+ "type": "string",
+ "example": "text-generation",
+ "nullable": true
+ },
+ "model_sha": {
+ "type": "string",
+ "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
+ "nullable": true
+ },
+ "sha": {
+ "type": "string",
+ "example": "null",
+ "nullable": true
+ },
+ "validation_workers": {
+ "type": "integer",
+ "example": "2",
+ "minimum": 0
+ },
+ "version": {
+ "type": "string",
+ "description": "Router Info",
+ "example": "0.5.0"
+ },
+ "waiting_served_ratio": {
+ "type": "number",
+ "format": "float",
+ "example": "1.2"
+ }
+ }
+ },
+ "Message": {
+ "type": "object",
+ "required": [
+ "role",
+ "content"
+ ],
+ "properties": {
+ "content": {
+ "type": "string",
+ "example": "My name is David and I"
+ },
+ "role": {
+ "type": "string",
+ "example": "user"
+ }
+ }
+ },
+ "PrefillToken": {
+ "type": "object",
+ "required": [
+ "id",
+ "text",
+ "logprob"
+ ],
+ "properties": {
+ "id": {
+ "type": "integer",
+ "format": "int32",
+ "example": 0,
+ "minimum": 0
+ },
+ "logprob": {
+ "type": "number",
+ "format": "float",
+ "example": -0.34,
+ "nullable": true
+ },
+ "text": {
+ "type": "string",
+ "example": "test"
+ }
+ }
+ },
+ "SimpleToken": {
+ "type": "object",
+ "required": [
+ "id",
+ "text",
+ "start",
+ "stop"
+ ],
+ "properties": {
+ "id": {
+ "type": "integer",
+ "format": "int32",
+ "example": 0,
+ "minimum": 0
+ },
+ "start": {
+ "type": "integer",
+ "example": 0,
+ "minimum": 0
+ },
+ "stop": {
+ "type": "integer",
+ "example": 2,
+ "minimum": 0
+ },
+ "text": {
+ "type": "string",
+ "example": "test"
+ }
+ }
+ },
+ "StreamDetails": {
+ "type": "object",
+ "required": [
+ "finish_reason",
+ "generated_tokens"
+ ],
+ "properties": {
+ "finish_reason": {
+ "$ref": "#/components/schemas/FinishReason"
+ },
+ "generated_tokens": {
+ "type": "integer",
+ "format": "int32",
+ "example": 1,
+ "minimum": 0
+ },
+ "seed": {
+ "type": "integer",
+ "format": "int64",
+ "example": 42,
+ "nullable": true,
+ "minimum": 0
+ }
+ }
+ },
+ "StreamResponse": {
+ "type": "object",
+ "required": [
+ "index",
+ "token"
+ ],
+ "properties": {
+ "details": {
+ "allOf": [
+ {
+ "$ref": "#/components/schemas/StreamDetails"
+ }
+ ],
+ "default": "null",
+ "nullable": true
+ },
+ "generated_text": {
+ "type": "string",
+ "default": "null",
+ "example": "test",
+ "nullable": true
+ },
+ "index": {
+ "type": "integer",
+ "format": "int32",
+ "minimum": 0
+ },
+ "token": {
+ "$ref": "#/components/schemas/Token"
+ },
+ "top_tokens": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Token"
+ }
+ }
+ }
+ },
+ "Token": {
+ "type": "object",
+ "required": [
+ "id",
+ "text",
+ "logprob",
+ "special"
+ ],
+ "properties": {
+ "id": {
+ "type": "integer",
+ "format": "int32",
+ "example": 0,
+ "minimum": 0
+ },
+ "logprob": {
+ "type": "number",
+ "format": "float",
+ "example": -0.34,
+ "nullable": true
+ },
+ "special": {
+ "type": "boolean",
+ "example": "false"
+ },
+ "text": {
+ "type": "string",
+ "example": "test"
+ }
+ }
+ },
+ "TokenizeResponse": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/SimpleToken"
+ }
+ }
+ }
+ },
+ "tags": [
+ {
+ "name": "Text Generation Inference",
+ "description": "Hugging Face Text Generation Inference API"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md
index 1437717f..060d177d 100644
--- a/docs/source/basic_tutorials/gated_model_access.md
+++ b/docs/source/basic_tutorials/gated_model_access.md
@@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \
- -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \
+ -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
--model-id $model
```
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
index e9a33f04..78ebb8e2 100644
--- a/docs/source/quicktour.md
+++ b/docs/source/quicktour.md
@@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
@@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
```bash
-docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model
+docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
```
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
-docker run ghcr.io/huggingface/text-generation-inference:1.3 --help
+docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
```
diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py
index 6391f2a1..0987b3a1 100644
--- a/integration-tests/models/test_flash_phi.py
+++ b/integration-tests/models/test_flash_phi.py
@@ -21,7 +21,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
)
assert response.details.generated_tokens == 10
- assert response.generated_text == ": {request}\")\n response = self"
+ assert response.generated_text == ': {request}")\n response = self'
assert response == response_snapshot
@@ -52,14 +52,12 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
- responses = await generate_load(
- flash_phi, "Test request", max_new_tokens=10, n=4
- )
+ responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
- assert responses[0].generated_text == ": {request}\")\n response = self"
+ assert responses[0].generated_text == ': {request}")\n response = self'
assert responses == response_snapshot
diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml
index f6929587..f0c5add9 100644
--- a/integration-tests/pyproject.toml
+++ b/integration-tests/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-integration-tests"
-version = "1.3.4"
+version = "1.4.0"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry "]
diff --git a/server/poetry.lock b/server/poetry.lock
index 360eeb36..16a28a01 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1812,13 +1812,13 @@ xmp = ["defusedxml"]
[[package]]
name = "pluggy"
-version = "1.4.0"
+version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"},
- {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"},
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
]
[package.extras]
@@ -1886,51 +1886,51 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "pyarrow"
-version = "15.0.2"
+version = "16.0.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"},
- {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"},
- {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"},
- {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"},
- {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"},
- {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"},
- {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"},
- {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"},
- {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"},
- {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"},
- {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"},
- {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"},
- {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"},
- {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"},
- {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"},
- {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"},
- {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"},
- {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"},
- {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"},
- {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"},
- {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"},
- {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"},
- {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"},
- {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"},
- {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"},
- {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"},
- {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"},
- {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"},
- {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"},
- {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"},
- {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"},
- {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"},
- {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"},
- {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"},
- {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"},
- {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"},
+ {file = "pyarrow-16.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:22a1fdb1254e5095d629e29cd1ea98ed04b4bbfd8e42cc670a6b639ccc208b60"},
+ {file = "pyarrow-16.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:574a00260a4ed9d118a14770edbd440b848fcae5a3024128be9d0274dbcaf858"},
+ {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0815d0ddb733b8c1b53a05827a91f1b8bde6240f3b20bf9ba5d650eb9b89cdf"},
+ {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df0080339387b5d30de31e0a149c0c11a827a10c82f0c67d9afae3981d1aabb7"},
+ {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edf38cce0bf0dcf726e074159c60516447e4474904c0033f018c1f33d7dac6c5"},
+ {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:91d28f9a40f1264eab2af7905a4d95320ac2f287891e9c8b0035f264fe3c3a4b"},
+ {file = "pyarrow-16.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:99af421ee451a78884d7faea23816c429e263bd3618b22d38e7992c9ce2a7ad9"},
+ {file = "pyarrow-16.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d22d0941e6c7bafddf5f4c0662e46f2075850f1c044bf1a03150dd9e189427ce"},
+ {file = "pyarrow-16.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:266ddb7e823f03733c15adc8b5078db2df6980f9aa93d6bb57ece615df4e0ba7"},
+ {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cc23090224b6594f5a92d26ad47465af47c1d9c079dd4a0061ae39551889efe"},
+ {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56850a0afe9ef37249d5387355449c0f94d12ff7994af88f16803a26d38f2016"},
+ {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:705db70d3e2293c2f6f8e84874b5b775f690465798f66e94bb2c07bab0a6bb55"},
+ {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5448564754c154997bc09e95a44b81b9e31ae918a86c0fcb35c4aa4922756f55"},
+ {file = "pyarrow-16.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:729f7b262aa620c9df8b9967db96c1575e4cfc8c25d078a06968e527b8d6ec05"},
+ {file = "pyarrow-16.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fb8065dbc0d051bf2ae2453af0484d99a43135cadabacf0af588a3be81fbbb9b"},
+ {file = "pyarrow-16.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ce707d9aa390593ea93218b19d0eadab56390311cb87aad32c9a869b0e958c"},
+ {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5823275c8addbbb50cd4e6a6839952682a33255b447277e37a6f518d6972f4e1"},
+ {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab8b9050752b16a8b53fcd9853bf07d8daf19093533e990085168f40c64d978"},
+ {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42e56557bc7c5c10d3e42c3b32f6cff649a29d637e8f4e8b311d334cc4326730"},
+ {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a7abdee4a4a7cfa239e2e8d721224c4b34ffe69a0ca7981354fe03c1328789b"},
+ {file = "pyarrow-16.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:ef2f309b68396bcc5a354106741d333494d6a0d3e1951271849787109f0229a6"},
+ {file = "pyarrow-16.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ed66e5217b4526fa3585b5e39b0b82f501b88a10d36bd0d2a4d8aa7b5a48e2df"},
+ {file = "pyarrow-16.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc8814310486f2a73c661ba8354540f17eef51e1b6dd090b93e3419d3a097b3a"},
+ {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c2f5e239db7ed43e0ad2baf46a6465f89c824cc703f38ef0fde927d8e0955f7"},
+ {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f293e92d1db251447cb028ae12f7bc47526e4649c3a9924c8376cab4ad6b98bd"},
+ {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:dd9334a07b6dc21afe0857aa31842365a62eca664e415a3f9536e3a8bb832c07"},
+ {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d91073d1e2fef2c121154680e2ba7e35ecf8d4969cc0af1fa6f14a8675858159"},
+ {file = "pyarrow-16.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:71d52561cd7aefd22cf52538f262850b0cc9e4ec50af2aaa601da3a16ef48877"},
+ {file = "pyarrow-16.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:b93c9a50b965ee0bf4fef65e53b758a7e8dcc0c2d86cebcc037aaaf1b306ecc0"},
+ {file = "pyarrow-16.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d831690844706e374c455fba2fb8cfcb7b797bfe53ceda4b54334316e1ac4fa4"},
+ {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35692ce8ad0b8c666aa60f83950957096d92f2a9d8d7deda93fb835e6053307e"},
+ {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dd3151d098e56f16a8389c1247137f9e4c22720b01c6f3aa6dec29a99b74d80"},
+ {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bd40467bdb3cbaf2044ed7a6f7f251c8f941c8b31275aaaf88e746c4f3ca4a7a"},
+ {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:00a1dcb22ad4ceb8af87f7bd30cc3354788776c417f493089e0a0af981bc8d80"},
+ {file = "pyarrow-16.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fda9a7cebd1b1d46c97b511f60f73a5b766a6de4c5236f144f41a5d5afec1f35"},
+ {file = "pyarrow-16.0.0.tar.gz", hash = "sha256:59bb1f1edbbf4114c72415f039f1359f1a57d166a331c3229788ccbfbb31689a"},
]
[package.dependencies]
-numpy = ">=1.16.6,<2"
+numpy = ">=1.16.6"
[[package]]
name = "pyarrow-hotfix"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index d6806848..60bd399a 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
-version = "1.3.4"
+version = "1.4.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene "]
diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt
index 694242e1..e9267512 100644
--- a/server/requirements_cuda.txt
+++ b/server/requirements_cuda.txt
@@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
+regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt
index e0495fde..053429c9 100644
--- a/server/requirements_rocm.txt
+++ b/server/requirements_rocm.txt
@@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
-numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
-pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
+regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py
index 0a9fecd1..93a0e982 100644
--- a/server/tests/utils/test_layers.py
+++ b/server/tests/utils/test_layers.py
@@ -3,24 +3,27 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding,
)
+
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
- def size(self)->int:
+ def size(self) -> int:
return self.world_size
- def rank(self)->int:
+ def rank(self) -> int:
return self._rank
+
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
- self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim)
+ self.weight = (
+ torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
+ )
self.process_group = ProcessGroup(rank, world_size)
-
- def get_partial_sharded(self, name:str, dim: int):
+ def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
@@ -35,10 +38,11 @@ class Weights:
def get_shape(self, name: str):
return self.weight.shape
+
def test_weight_hub_files_offline_error():
- vocab_size= 17
- weights = Weights(rank=0, world_size=1, vocab_size = vocab_size,hidden_dim = 256)
+ vocab_size = 17
+ weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
@@ -47,18 +51,27 @@ def test_weight_hub_files_offline_error():
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
- weights_0_2 = Weights(rank=0, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
- weights_1_2 = Weights(rank=1, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
+ weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
+ weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
- torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float())
+ torch.testing.assert_close(
+ embeddings_0_2.weight,
+ torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
+ .view(10, 256)
+ .float(),
+ )
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
- torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float())
+ torch.testing.assert_close(
+ embeddings_1_2.weight,
+ torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
+ .view(9, 256)
+ .float(),
+ )
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)
-
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 9ec3ce20..905b7e69 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -270,7 +270,7 @@ def download_weights(
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
-
+
elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model
try:
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index d103973f..96701794 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -17,6 +17,7 @@ from text_generation_server.utils.layers import (
FastLayerNorm,
)
+
class PhiConfig(PretrainedConfig):
def __init__(
self,
@@ -25,15 +26,15 @@ class PhiConfig(PretrainedConfig):
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
- hidden_act="gelu_fast", # llama uses silu
- layer_norm_eps=1e-05, # rms in llama,
+ hidden_act="gelu_fast", # llama uses silu
+ layer_norm_eps=1e-05, # rms in llama,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
- resid_pdrop=0.1, # llama doesn't have this
- partial_rotary_factor=0.5, # important difference between llama and phi
+ resid_pdrop=0.1, # llama doesn't have this
+ partial_rotary_factor=0.5, # important difference between llama and phi
**kwargs,
):
self.vocab_size = vocab_size
@@ -55,6 +56,7 @@ class PhiConfig(PretrainedConfig):
**kwargs,
)
+
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
@@ -68,6 +70,7 @@ def load_attention(config, prefix, weights):
bias=True,
)
+
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
@@ -94,6 +97,7 @@ def _load_gqa(config, prefix: str, weights):
get_linear(weight, bias=True, quantize=config.quantize)
)
+
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
@@ -173,8 +177,7 @@ class FlashPhiAttention(torch.nn.Module):
#
# Apply partial positional embeddings in place
self.rotary_emb(
- query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
- cos, sin
+ query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
# Reshape key and value and cache
@@ -210,7 +213,8 @@ class FlashPhiAttention(torch.nn.Module):
max_s,
)
- return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
@@ -256,7 +260,9 @@ class FlashPhiLayer(nn.Module):
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
- prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
@@ -287,10 +293,13 @@ class FlashPhiLayer(nn.Module):
max_s,
)
- hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
+ hidden_states = self.resid_dropout(attn_output).add(
+ self.resid_dropout(self.mlp(hidden_states))
+ )
return hidden_states, res
+
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
@@ -361,6 +370,7 @@ class FlashPhiModel(torch.nn.Module):
return hidden_states
+
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
@@ -380,7 +390,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
- input_lengths: torch.Tensor,
+ input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index 1a9aef74..2c2fec48 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -54,9 +54,19 @@ def load_col(config, prefix, weights, bias):
bias_h = bias_h[0]
bias_block_size = bias_h // bias_size
- bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
- bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
- bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
+ bias_q_part = bias_slice_[
+ bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
+ ]
+ bias_k_part = bias_slice_[
+ bias_h
+ + bias_rank * bias_block_size : bias_h
+ + (bias_rank + 1) * bias_block_size
+ ]
+ bias_v_part = bias_slice_[
+ 2 * bias_h
+ + bias_rank * bias_block_size : 2 * bias_h
+ + (bias_rank + 1) * bias_block_size
+ ]
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
if bias.dtype != torch.int32:
@@ -352,8 +362,12 @@ class MultiheadAttention(nn.Module):
hidden_size = config.d_model
head_dim = hidden_size // self.n_heads
- self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
- self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
+ self.q_ln = LPLayerNorm(
+ d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
+ )
+ self.k_ln = LPLayerNorm(
+ self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
+ )
if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
elif self.attn_impl == "triton":
@@ -684,7 +698,6 @@ class LPLayerNorm(torch.nn.LayerNorm):
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
self.normalized_shape = self.weight.shape
-
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
@@ -798,7 +811,7 @@ class MPTModel(MPTPreTrainedModel):
self.wte = TensorParallelEmbedding("transformer.wte", weights)
if not self.alibi:
- self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
+ self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py
index f9999537..e5c09728 100644
--- a/server/text_generation_server/models/custom_modeling/phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py
@@ -62,14 +62,12 @@ class PhiConfig(PretrainedConfig):
**kwargs,
)
+
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
- inv_freq = [
- 1.0 / 10000.0 ** (i / dim)
- for i in range(0, dim, 2)
- ]
+ inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
@@ -131,6 +129,7 @@ class PhiCausalLMHead(nn.Module):
hidden_states = self.linear(hidden_states)
return hidden_states
+
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
@@ -172,19 +171,27 @@ class PhiMHA(nn.Module):
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
- attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
+ attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
- causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
+ causal_mask = torch.triu(
+ torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
+ 1,
+ )
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
-
+
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
- attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
+ attn_output = (
+ attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
+ .transpose(1, 2)
+ .flatten(-2)
+ )
return self.out_proj(attn_output), past_kv_cache
+
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
@@ -204,19 +211,22 @@ class PhiMLP(nn.Module):
bias=False,
)
self.activation = torch.nn.functional.gelu
-
+
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
+
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
- self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
+ self.layer_norm = nn.LayerNorm.load(
+ prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
+ )
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
@@ -228,11 +238,14 @@ class PhiBlock(nn.Module):
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
- attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
+ attn_outputs, past_kv_cache = self.mixer(
+ hidden_states, kv_cache, attention_mask
+ )
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
+
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
@@ -241,9 +254,12 @@ class PhiModel(nn.Module):
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
- )
+ )
self.blocks = nn.ModuleList(
- [PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
+ [
+ PhiBlock(f"transformer.h.{layer_id}", config, weights)
+ for layer_id in range(config.n_layer)
+ ]
)
def forward(
@@ -258,14 +274,19 @@ class PhiModel(nn.Module):
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
- past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
+ past_key_values = (
+ [None] * len(self.blocks) if past_key_values is None else past_key_values
+ )
for index, block in enumerate(self.blocks):
- hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
+ hidden_states, new_key_values = block(
+ hidden_states, past_key_values[index], mask
+ )
past_key_values[index] = new_key_values
return hidden_states, past_key_values
+
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
@@ -290,12 +311,15 @@ class PhiForCausalLM(torch.nn.Module):
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
- logits[:, :-1].view(-1, logits.size(-1)),
- labels[:, 1:].view(-1)
+ logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
- return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
+ return (
+ ((loss,) + (logits,) + model_output[1:])
+ if loss is not None
+ else (logits,) + model_output[1:]
+ )
return CausalLMOutputWithPast(
loss=loss,
@@ -304,5 +328,3 @@ class PhiForCausalLM(torch.nn.Module):
hidden_states=None,
attentions=None,
)
-
-
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 7be61906..94bd58f4 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -73,11 +73,11 @@ class FlashLlama(FlashCausalLM):
import json
import os
from pathlib import Path
-
- is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
- "WEIGHTS_CACHE_OVERRIDE", None
- ) is not None
-
+
+ is_local_model = (
+ Path(use_medusa).exists() and Path(use_medusa).is_dir()
+ ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
+
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
@@ -88,7 +88,7 @@ class FlashLlama(FlashCausalLM):
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
+
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 1c49f2a9..061b9740 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -63,11 +63,11 @@ class FlashPhi(FlashCausalLM):
import json
import os
from pathlib import Path
-
- is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
- "WEIGHTS_CACHE_OVERRIDE", None
- ) is not None
-
+
+ is_local_model = (
+ Path(use_medusa).exists() and Path(use_medusa).is_dir()
+ ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
+
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
@@ -78,7 +78,7 @@ class FlashPhi(FlashCausalLM):
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
+
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index d477478a..79aa3fb9 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
-from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
+from text_generation_server.models.custom_modeling.phi_modeling import (
+ PhiConfig,
+ PhiForCausalLM,
+)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
+
class Phi(CausalLM):
def __init__(
self,
@@ -60,4 +64,3 @@ class Phi(CausalLM):
dtype=dtype,
device=device,
)
-
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 6ddfd6f4..010d6143 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module):
block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
- self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size.
+ self.null_idx = weight.shape[
+ 0
+ ] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group
self.reduce = reduce