diff --git a/Dockerfile b/Dockerfile index b6c5b2ed..6818005f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile # Build specific version of vllm RUN make build-vllm-cuda +# Build mamba kernels +FROM kernel-builder as mamba-builder +WORKDIR /usr/src +COPY server/Makefile-selective-scan Makefile +RUN make build-all + # Build megablocks FROM kernel-builder as megablocks-builder @@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from mamba builder +COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 775e7a6c..1e25e1b1 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers): assert not response.details.tokens[0].special -def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", decoder_input_details=True) - - assert response.generated_text != "" - assert response.details.finish_reason == FinishReason.EndOfSequenceToken - assert response.details.generated_tokens > 1 - assert response.details.seed is None - assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) - assert len(response.details.tokens) > 1 - assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == " " - assert not response.details.tokens[0].special - - def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) response = client.generate( diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 63b5258d..0bf80f8c 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -62,7 +62,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -157,7 +157,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -312,7 +312,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -405,7 +405,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 7fa8033e..aa02d8d8 100644 
--- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -9,7 +9,7 @@ class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens - max_new_tokens: Optional[int] = None + max_new_tokens: int = 20 # The parameter for repetition penalty. 1.0 means no penalty. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. repetition_penalty: Optional[float] = None diff --git a/docs/source/basic_tutorials/using_cli.md b/docs/source/basic_tutorials/using_cli.md index 82c10e6b..a3a65f60 100644 --- a/docs/source/basic_tutorials/using_cli.md +++ b/docs/source/basic_tutorials/using_cli.md @@ -1,6 +1,6 @@ # Using TGI CLI -You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli). +You can use the TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli). `text-generation-server` lets you download the model with `download-weights` command like below 👇 diff --git a/docs/source/messages_api.md b/docs/source/messages_api.md index 1e342686..939850aa 100644 --- a/docs/source/messages_api.md +++ b/docs/source/messages_api.md @@ -4,6 +4,15 @@ Text Generation Inference (TGI) now supports the Messages API, which is fully co > **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature. +#### Table of Contents + +- [Making a Request](#making-a-request) +- [Streaming](#streaming) +- [Synchronous](#synchronous) +- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints) +- [Cloud Providers](#cloud-providers) + - [Amazon SageMaker](#amazon-sagemaker) + ## Making a Request You can make a request to TGI's Messages API using `curl`. Here's an example: @@ -81,6 +90,38 @@ chat_completion = client.chat.completions.create( print(chat_completion) ``` +## Hugging Face Inference Endpoints + +The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated). +Every endpoint that uses "Text Generation Inference" with an LLM that has a chat template can now be used. Below is an example of how to use Inference Endpoints with TGI through OpenAI's Python client library: + +> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key. + +```python +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + # replace with your endpoint url, make sure to include "v1/" at the end + base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/", + # replace with your API key + api_key="hf_XXX" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message.choices[0].delta.content, end="") +``` + ## Cloud Providers TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI.
Here's how you can deploy TGI on Amazon SageMaker: @@ -114,7 +155,7 @@ hub = { huggingface_model = HuggingFaceModel( image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"), env=hub, - role=role, + role=role, ) # deploy model to SageMaker Inference @@ -123,7 +164,7 @@ predictor = huggingface_model.deploy( instance_type="ml.g5.2xlarge", container_startup_health_check_timeout=300, ) - + # send request predictor.predict({ "messages": [ diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 53055e42..5e537bb7 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -1,193 +1,194 @@ { - "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L", "details": { "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 20, - "seed": null, "prefill": [ { "id": 589, - "text": "def", - "logprob": null + "logprob": null, + "text": "def" }, { "id": 3226, - "text": " ge", - "logprob": -9.0234375 + "logprob": -8.5859375, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -9.0859375 + "logprob": -7.5859375, + "text": "ometric" }, { "id": 81, - "text": "_", - "logprob": -0.25878906 + "logprob": -0.2668457, + "text": "_" }, { "id": 6009, - "text": "mean", - "logprob": -2.2109375 + "logprob": -1.6416016, + "text": "mean" }, { "id": 26, - "text": "(", - "logprob": -0.30371094 + "logprob": -0.22705078, + "text": "(" }, { "id": 62, - "text": "L", - "logprob": -5.6054688 + "logprob": -5.2304688, + "text": "L" }, { "id": 44, - "text": ":", - "logprob": -3.0722656 + "logprob": -3.0976562, + "text": ":" }, { "id": 1682, - "text": " List", - "logprob": -0.6879883 + "logprob": -1.1044922, + "text": " List" }, { "id": 77, - "text": "[", - "logprob": -0.38500977 + "logprob": -0.14294434, + "text": "[" }, { "id": 1808, - "text": "float", - "logprob": -0.984375 + "logprob": -0.32299805, + "text": "float" }, { "id": 10794, - "text": "]):", - "logprob": -2.5351562 + "logprob": -2.8164062, + "text": "]):" } ], + "seed": null, "tokens": [ { "id": 284, - "text": "\n ", - "logprob": -1.1738281, - "special": false + "logprob": -0.1282959, + "special": false, + "text": "\n " }, { - "id": 442, - "text": " return", - "logprob": -0.95947266, - "special": false + "id": 1524, + "logprob": -0.97998047, + "special": false, + "text": " \"\"\"" }, { - "id": 3632, - "text": " sum", - "logprob": -1.4199219, - "special": false + "id": 284, + "logprob": -0.7006836, + "special": false, + "text": "\n " }, { - "id": 26, - "text": "(", - "logprob": -0.085876465, - "special": false + "id": 14883, + "logprob": -2.1933594, + "special": false, + "text": " Calculate" }, { - "id": 62, - "text": "L", - "logprob": -0.09875488, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.30517578, - "special": false - }, - { - "id": 517, - "text": " /", - "logprob": -0.42089844, - "special": false - }, - { - "id": 2069, - "text": " len", - "logprob": -0.042053223, - "special": false - }, - { - "id": 26, - "text": "(", - "logprob": -0.0011806488, - "special": false - }, - { - "id": 62, - "text": "L", - "logprob": -0.0005259514, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.0017633438, - "special": false - }, - { - "id": 478, - "text": "\n\n", - 
"logprob": -0.69189453, - "special": false - }, - { - "id": 203, - "text": "\n", - "logprob": -0.041870117, - "special": false - }, - { - "id": 589, - "text": "def", - "logprob": -0.27856445, - "special": false + "id": 322, + "logprob": -0.2697754, + "special": false, + "text": " the" }, { "id": 3226, - "text": " ge", - "logprob": -1.7255859, - "special": false + "logprob": -0.0836792, + "special": false, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -0.011291504, - "special": false + "logprob": -0.018737793, + "special": false, + "text": "ometric" }, { - "id": 81, - "text": "_", - "logprob": -0.008430481, - "special": false + "id": 5651, + "logprob": -0.028640747, + "special": false, + "text": " mean" }, { - "id": 6009, - "text": "mean", - "logprob": -0.025787354, - "special": false + "id": 432, + "logprob": -0.29467773, + "special": false, + "text": " of" }, { - "id": 26, - "text": "(", - "logprob": -0.073913574, - "special": false + "id": 312, + "logprob": -0.31518555, + "special": false, + "text": " a" }, { - "id": 62, - "text": "L", - "logprob": -0.09967041, - "special": false + "id": 1149, + "logprob": -0.20605469, + "special": false, + "text": " list" + }, + { + "id": 432, + "logprob": -0.23254395, + "special": false, + "text": " of" + }, + { + "id": 7515, + "logprob": -0.4489746, + "special": false, + "text": " numbers" + }, + { + "id": 32, + "logprob": -0.6044922, + "special": false, + "text": "." + }, + { + "id": 446, + "logprob": -0.63964844, + "special": false, + "text": "\n\n " + }, + { + "id": 499, + "logprob": -1.1953125, + "special": false, + "text": " :" + }, + { + "id": 753, + "logprob": -0.03515625, + "special": false, + "text": "param" + }, + { + "id": 498, + "logprob": -0.06311035, + "special": false, + "text": " L" + }, + { + "id": 44, + "logprob": -0.003414154, + "special": false, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.3310547, + "special": false, + "text": " List" } - ] - } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index 1ace3814..bf0f5146 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5898438, "text": "ometric" }, { "id": 81, - "logprob": -0.25830078, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.1875, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30004883, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6171875, + "logprob": -5.2382812, "text": "L" }, { "id": 44, - "logprob": -3.078125, + "logprob": -3.0996094, "text": ":" }, { "id": 1682, - "logprob": -0.68066406, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.38745117, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.9453125, + "logprob": -0.32226562, "text": "float" }, { "id": 10794, - "logprob": -2.5371094, + "logprob": -2.8164062, "text": "]):" } ], @@ -69,19 +69,19 @@ "tokens": [ { 
"id": 284, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "\n " }, { "id": 442, - "logprob": 0.0, + "logprob": -1.3134766, "special": false, "text": " return" }, { "id": 11665, - "logprob": -1.2236328, + "logprob": -0.10021973, "special": false, "text": " reduce" }, @@ -129,7 +129,7 @@ }, { "id": 319, - "logprob": 0.0, + "logprob": -0.42871094, "special": false, "text": " *" }, @@ -158,36 +158,37 @@ "text": ")" }, { - "id": 203, - "logprob": -0.12695312, - "special": false, - "text": "\n" - }, - { - "id": 203, + "id": 1115, "logprob": 0.0, "special": false, - "text": "\n" + "text": " **" }, { - "id": 589, + "id": 308, "logprob": 0.0, "special": false, - "text": "def" + "text": " (" }, { - "id": 3226, + "id": 35, "logprob": 0.0, "special": false, - "text": " ge" + "text": "1" }, { - "id": 21017, + "id": 32, + "logprob": -0.31323242, + "special": false, + "text": "." + }, + { + "id": 34, "logprob": 0.0, "special": false, - "text": "ometric" + "text": "0" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric" + "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 5381ce5a..46a21ed8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5820312, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26708984, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22717285, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1015625, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1083984, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -70,67 +70,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12817383, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91796875, + "id": 1524, + "logprob": -0.9863281, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3291016, + "id": 284, + "logprob": -0.7011719, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.097717285, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.29003906, + "id": 3226, + "logprob": -0.08465576, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.03829956, + "id": 5651, + 
"logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011987686, + "id": 432, + "logprob": -0.29418945, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -145,57 +146,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.59375, "text": "ometric" }, { "id": 81, - "logprob": -0.25878906, + "logprob": -0.26953125, "text": "_" }, { "id": 6009, - "logprob": -2.2109375, + "logprob": -1.640625, "text": "mean" }, { "id": 26, - "logprob": -0.30371094, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6054688, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0722656, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6879883, + "logprob": -1.1123047, "text": " List" }, { "id": 77, - "logprob": -0.38500977, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.984375, + "logprob": -0.32299805, "text": "float" }, { "id": 10794, - "logprob": -2.5351562, + "logprob": -2.8164062, "text": "]):" } ], @@ -203,67 +204,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1738281, + "logprob": -0.12854004, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9584961, + "id": 1524, + "logprob": -0.9897461, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.4169922, + "id": 284, + "logprob": -0.69970703, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.085876465, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.0982666, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.3022461, + "id": 3226, + "logprob": -0.08496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.40504883, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.041656494, + "id": 5651, + "logprob": -0.029037476, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011844635, + "id": 432, + "logprob": -0.2939453, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005264282, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -278,57 +280,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22766113, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.2265625, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.0976562, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + 
"logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.1427002, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -336,67 +338,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.13012695, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9165039, + "id": 1524, + "logprob": -0.98046875, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.328125, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.07946777, + "id": 14883, + "logprob": -2.1992188, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09820557, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28930664, + "id": 3226, + "logprob": -0.083496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34592773, + "id": 21017, + "logprob": -0.01902771, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038330078, + "id": 5651, + "logprob": -0.029006958, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011940002, + "id": 432, + "logprob": -0.29248047, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -411,57 +414,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26904297, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1074219, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14477539, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.3256836, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8027344, "text": "]):" } ], @@ -469,66 +472,67 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12915039, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91259766, + "id": 1524, + "logprob": -0.98535156, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3251953, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2011719, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09906006, + "id": 322, + "logprob": -0.26708984, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28979492, + "id": 3226, + "logprob": -0.08502197, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": 
-0.35958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038604736, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011901855, + "id": 432, + "logprob": -0.29589844, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005078316, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" } ] diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json new file mode 100644 index 00000000..4435f215 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.140625, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5551758, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59033203, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.70654297, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0410156, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026435852, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2841797, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json new file mode 100644 index 00000000..052c1c69 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2502, + "logprob": null, + "text": " red" + }, + { + "id": 13, + "logprob": -2.5234375, + "text": "," + }, + { + "id": 8862, + "logprob": -3.4433594, + "text": " yellow" + }, + { + "id": 13, + "logprob": -0.43017578, + "text": "," + }, + { + "id": 209, + "logprob": -8.21875, + "text": " " + } + ], + "seed": 0, + "tokens": [ + { + "id": 187, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 395, + "logprob": -0.46411133, + "special": false, + "text": "and" + }, + { + "id": 13735, + "logprob": -2.1132812, + "special": false, + "text": " orange" + }, + { + "id": 313, + "logprob": -1.2128906, + "special": false, + "text": " (" + }, + { + "id": 249, + "logprob": -2.3671875, + "special": false, + "text": "in" + }, + { + "id": 253, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 1340, + "logprob": -1.640625, + "special": false, + "text": " order" + }, + { + "id": 597, + "logprob": -0.5488281, + "special": false, + "text": " they" + }, + { + "id": 3176, + "logprob": -0.48608398, + 
"special": false, + "text": " appear" + }, + { + "id": 275, + "logprob": 0.0, + "special": false, + "text": " in" + } + ], + "top_tokens": null + }, + "generated_text": "blue, red, yellow, \nand orange (in the order they appear in" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json new file mode 100644 index 00000000..014210b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.8125, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.828125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -3.0, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1484375, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1279297, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5595703, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.60253906, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7050781, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0488281, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3808594, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026416779, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + } +] diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py new file mode 100644 index 00000000..d86faeff --- /dev/null +++ b/integration-tests/models/test_mamba.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def fused_kernel_mamba_handle(launcher): + with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def fused_kernel_mamba(fused_kernel_mamba_handle): + await fused_kernel_mamba_handle.health(300) + return fused_kernel_mamba_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "What is Deep Learning?", max_new_tokens=10 + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "\n\nDeep learning is a new type of machine" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "blue, red, yellow, ", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): + responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" + + assert responses == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index f6f16dae..1a7ceb70 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { version = "0.14.0", features = ["http"] } +tokenizers = { version = "0.15.1", features = ["http"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.14" tower-http = { 
version = "0.4.4", features = ["cors"] } diff --git a/router/src/infer.rs b/router/src/infer.rs index 5f078ba0..4da0da0a 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -198,6 +198,7 @@ impl Infer { messages, eos_token: eos_token.as_deref(), bos_token: bos_token.as_deref(), + add_generation_prompt: true, }) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); @@ -806,21 +807,14 @@ mod tests { ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!( result, - r#"### User: -Hi! - -### Assistant: -Hello how can I help?### User: -What is Deep Learning? - -### Assistant: -magic!"# + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" ); } @@ -878,6 +872,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -943,9 +938,60 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + Message { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + Message { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index fc5670a0..e85519cc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -37,7 +37,7 @@ pub struct HubTokenizerConfig { } impl HubTokenizerConfig { - pub fn from_file(filename: &str) -> Self { + pub fn from_file(filename: &std::path::Path) -> Self { let content = std::fs::read_to_string(filename).unwrap(); serde_json::from_str(&content).unwrap_or_default() } @@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, + add_generation_prompt: bool, } #[derive(Clone, Deserialize, ToSchema, Serialize)] diff --git a/router/src/main.rs b/router/src/main.rs index 495fd5bc..2a080468 100644 --- a/router/src/main.rs +++ 
b/router/src/main.rs @@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> { let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - // Load tokenizer config - // This will be used to format the chat template - let local_tokenizer_config_path = - tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); - let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); - // Shared API builder initialization let api_builder = || { let mut builder = ApiBuilder::new() @@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer config if found locally, or check if we can get it from the API if needed - let tokenizer_config = if local_tokenizer_config { + let tokenizer_config = if let Some(path) = tokenizer_config_path { + tracing::info!("Using local tokenizer config from user specified path"); + HubTokenizerConfig::from_file(&std::path::PathBuf::from(path)) + } else if local_model { tracing::info!("Using local tokenizer config"); - HubTokenizerConfig::from_file(&local_tokenizer_config_path) - } else if let Some(api) = api { - tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); - get_tokenizer_config(&api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.unwrap_or_else(|| "main".to_string()), - ))) - .await - .unwrap_or_else(|| { - tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); - HubTokenizerConfig::default() - }) + HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json")) } else { - tracing::warn!("Could not find tokenizer config locally and no revision specified"); - HubTokenizerConfig::default() + match api { + Some(api) => { + tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); + let repo = Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.unwrap_or("main".to_string()), + ); + get_tokenizer_config(&api.repo(repo)) + .await + .unwrap_or_else(|| { + tracing::warn!( + "Could not retrieve tokenizer config from the Hugging Face hub." 
+ ); + HubTokenizerConfig::default() + }) + } + None => { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + } + } }; if tokenizer.is_none() { diff --git a/router/src/server.rs b/router/src/server.rs index 52ed03df..b4d26158 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -936,6 +936,7 @@ pub async fn run( // Define base and health routes let base_routes = Router::new() .route("/", post(compat_generate)) + .route("/", get(health)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) diff --git a/server/.gitignore b/server/.gitignore index dcb8fe67..576746ee 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -161,3 +161,4 @@ flash-attention-v2/ vllm/ llm-awq/ eetq/ +mamba/ diff --git a/server/Makefile b/server/Makefile index b1926828..31d55c41 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq +include Makefile-selective-scan unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-selective-scan b/server/Makefile-selective-scan new file mode 100644 index 00000000..f4dec868 --- /dev/null +++ b/server/Makefile-selective-scan @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan depends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan \ No newline at end of file diff --git a/server/poetry.lock b/server/poetry.lock index 64b1b74f..32031f89 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] name = "accelerate" @@ -1589,30 +1589,32 @@ xml = ["lxml (>=4.9.2)"] [[package]] name = "peft" -version = "0.4.0" +version = "0.8.2" description = "Parameter-Efficient Fine-Tuning (PEFT)" optional = true python-versions = ">=3.8.0" files = [ - {file = "peft-0.4.0-py3-none-any.whl", hash = "sha256:2cf992772a6d703814477e0bdcdadd68cb8ea388111ce2d793dd2ff0e438f357"}, - {file = "peft-0.4.0.tar.gz", hash = "sha256:e768fa22d6e9f32aa7e891f0d06f355960278ca4dc0cdd96bff71f6f06269207"}, + {file = "peft-0.8.2-py3-none-any.whl", hash = "sha256:4a9c81c38e689fd4043b2757cd0e2b526a9b8b8fd04f8442df2c4824b32c2505"}, + {file = "peft-0.8.2.tar.gz", hash = "sha256:bbdf61db2d8ca503e894edc64016038e6f34b7b522374bad09a22af41882e7ac"}, ] [package.dependencies] -accelerate = "*" +accelerate = ">=0.21.0" +huggingface-hub = ">=0.17.0" numpy = ">=1.17" packaging = ">=20.0" psutil = "*" pyyaml = "*" safetensors = "*" torch = ">=1.13.0" +tqdm = "*" transformers = "*" [package.extras] dev = ["black (>=22.0,<23.0)", "hf-doc-builder", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] docs-specific = ["hf-doc-builder"] quality = ["black (>=22.0,<23.0)", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] -test = ["black (>=22.0,<23.0)", "datasets", "diffusers", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] +test = ["black (>=22.0,<23.0)", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "scipy", "urllib3 (<=2.0.0)"] [[package]] name = "pillow" @@ -1893,6 +1895,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2962,4 +2965,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "33d533d21d14c258678a8c4bb28e2a15e8ebe5ca35d8589cbfe4a7b7d2e79a90" +content-hash = "f7529125bdd7ce142082ce4969edbda5d9b67b6209f199194c54198829f5dc64" diff --git a/server/pyproject.toml b/server/pyproject.toml index 72a7afb0..b8ebf2e3 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -30,7 +30,7 @@ transformers = "^4.37.1" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } -peft = { version = "^0.4.0", optional = true } +peft = { version = "^0.8.2", optional = true } torch = { version = "^2.1.1", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 68096709..a952f060 100644 --- 
a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -76,6 +76,15 @@ if FLASH_ATTENTION: __all__.append(FlashMixtral) __all__.append(FlashPhi) +MAMBA_AVAILABLE = True +try: + from text_generation_server.models.mamba import Mamba +except ImportError as e: + logger.warning(f"Could not import Mamba: {e}") + MAMBA_AVAILABLE = False + +if MAMBA_AVAILABLE: + __all__.append(Mamba) def get_model( model_id: str, @@ -164,7 +173,25 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict["model_type"] + model_type = config_dict.get("model_type", None) + if model_type is None: + # TODO: fix how we determine model type for Mamba + if "ssm_cfg" in config_dict: + # *only happens in Mamba case + model_type = "ssm" + else: + raise RuntimeError( + f"Could not determine model type for {model_id} revision {revision}" + ) + + if model_type == "ssm": + return Mamba( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "gpt_bigcode": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 22d03adf..81041046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -69,9 +69,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - g_idx = g_idx.to(device=weights.device) - bits, groupsize, _ = weights._get_gptq_params() + bits, groupsize, _, quant_method, = weights._get_gptq_params() + if quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = g_idx.to(device=weights.device) + elif quant_method == "awq": + g_idx = None + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) from text_generation_server.utils.layers import HAS_EXLLAMA diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py new file mode 100644 index 00000000..1773f04d --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -0,0 +1,194 @@ +import torch +import torch.distributed + +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.utils.generation import InferenceParams +from torch import nn +from typing import Optional, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +import torch.nn.functional as F + +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, + FastRMSNorm, + FastLinear, +) + +from einops import rearrange +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +import math + +class MambaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=50280, + d_model=768, + d_state=16, + n_layer=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + expand=2, + dt_rank="auto", + **kwargs, + ): + self.vocab_size = vocab_size + self.n_layer = n_layer + 
self.layer_norm_epsilon = layer_norm_epsilon + self.d_model = d_model + self.d_inner = d_model * 2 + self.d_conv = 4 + self.d_state = d_state + self.expand = expand + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +class MambaBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.layer_idx = int(prefix.split(".")[2]) + self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) + self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) + self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) + self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False) + self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False) + self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) + self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) + self.D = weights.get_tensor(f"{prefix}.D") + self.activation = "silu" + self.dt_rank = config.dt_rank + self.d_state = config.d_state + self.d_conv = config.d_conv + self.act = nn.SiLU() + + # inference_params + def forward(self, hidden_states: torch.Tensor, inference_params=None): + _, seqlen, _ = hidden_states.shape + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + + if inference_params.seqlen_offset > 0: + out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) + return out, conv_state, ssm_state + + projected_states = self.in_proj(hidden_states).transpose(1,2) + x, z = projected_states.chunk(2, dim=1) + conv_state = F.pad(x, (self.d_conv - seqlen, 0)) + x = causal_conv1d_fn( + x=x, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
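+        # Shape walkthrough (descriptive note; b = batch, l = seqlen, d = d_inner): x here is (b, d, l).
+        # x_dbl below is (b*l, dt_rank + 2*d_state); dt_proj.weight @ dt.t() gives (d, b*l), which is
+        # rearranged back to (b, d, l), while B and C become contiguous (b, d_state, l) views for the scan.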
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y, last_state = selective_scan_fn( + x, + dt, + self.negA, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, + ) + y = rearrange(y, "b d l -> b l d") + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, last_state + + def step(self, hidden_states, conv_state, ssm_state): + _xz = self.in_proj(hidden_states) + _x, _z = _xz.chunk(2, dim=-1) # (B D) + conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) + conv_out = causal_conv1d_fn( + x=conv_state_new, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation + ) + conv_state = conv_state_new[:, :, 1:] + bsz, seqlen, dim = hidden_states.shape + output_tensor = torch.zeros( + (bsz, seqlen, dim), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + for i in range(0, bsz): + x = conv_out[i:i+1,:,-1] + z = _z[i:i+1, -1, :] + x_db = self.x_proj(x) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + y = selective_state_update( + ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + out = self.out_proj(y) + output_tensor[i] = out + + return output_tensor, conv_state, ssm_state + + + +class ResidualBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + ): + residual = (hidden_states + residual) if residual is not None else hidden_states + shape = residual.shape + hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) + hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params) + return hidden_states, residual, conv_state, last_ssm_state + +class MambaModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + prefix = "backbone" + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + self.blocks = nn.ModuleList( + [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] + ) + self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon) + self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False) + self.config = config + + def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + hidden_states = self.embed_tokens(input_ids) + for block in self.blocks: + hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params) + inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) + + hidden_states = hidden_states + residual if residual is not None else 
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, layer_id, config, weights):
+        super().__init__()
+        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
+        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+        inference_params: Optional[Any] = None,
+    ):
+        residual = (hidden_states + residual) if residual is not None else hidden_states
+        shape = residual.shape
+        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        return hidden_states, residual, conv_state, last_ssm_state
+
+
+class MambaModel(nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+        prefix = "backbone"
+        self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
+        self.blocks = nn.ModuleList(
+            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
+        )
+        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
+        # the lm_head weights are tied to the token embeddings
+        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
+        self.config = config
+
+    def forward(
+        self, input_ids: torch.Tensor, inference_params=None, residual=None
+    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+        hidden_states = self.embed_tokens(input_ids)
+        for block in self.blocks:
+            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
+
+        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+        hidden_states = hidden_states.view(residual.shape)
+        logits = self.lm_head(hidden_states)
+
+        # update the offset for the next inference using these params
+        inference_params.seqlen_offset += input_ids.size(1)
+        return logits, input_ids, inference_params
\ No newline at end of file
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
new file mode 100644
index 00000000..c10910aa
--- /dev/null
+++ b/server/text_generation_server/models/mamba.py
@@ -0,0 +1,656 @@
+import torch
+import torch.distributed
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaConfig,
+    MambaModel,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+    NextTokenChooser,
+    StoppingCriteria,
+    Sampling,
+)
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    Tokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.utils.tokens import batch_top_tokens
+from mamba_ssm.utils.generation import InferenceParams
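+
+# InferenceParams (from mamba_ssm) plays the role a KV cache plays for
+# attention models: key_value_memory_dict maps layer_idx -> (conv_state,
+# ssm_state) and seqlen_offset tracks how many tokens each sequence has
+# already consumed. A minimal sketch of how it is wired up (shapes follow the
+# allocation code in generate_token below; the numbers are illustrative):
+#
+#   params = InferenceParams(
+#       max_seqlen=64,
+#       max_batch_size=2,
+#       seqlen_offset=0,
+#       key_value_memory_dict={},
+#       lengths_per_sample=torch.zeros(2, dtype=torch.int32),
+#   )
+#   conv_state = torch.zeros(2, d_model * expand, d_conv)   # one per layer
+#   ssm_state = torch.zeros(2, d_model * expand, d_state)   # one per layer
+#   params.key_value_memory_dict[0] = (conv_state, ssm_state)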
+
+
+@dataclass
+class MambaBatch(Batch):
+    batch_id: int
+    requests: List[generate_pb2.Request]
+    requests_idx_mapping: Dict[int, int]
+
+    # Decoder values
+    input_ids: torch.Tensor
+
+    # All tokens
+    all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
+
+    # Generation helpers
+    next_token_choosers: List[NextTokenChooser]
+    stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
+
+    # Metadata used for padding
+    max_input_length: int
+    padding_right_offset: int
+
+    # Maximum number of tokens this batch will grow to
+    max_tokens: int
+
+    # Past metadata
+    keys_head_dim_last: bool = True
+
+    # Inference params
+    inference_params: Optional[Dict[str, Any]] = None
+
+    def to_pb(self) -> generate_pb2.CachedBatch:
+        return generate_pb2.CachedBatch(
+            id=self.batch_id,
+            request_ids=[r.id for r in self.requests],
+            size=len(self),
+            max_tokens=self.max_tokens,
+        )
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "MambaBatch":
+        inputs = []
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+        prefix_offsets = []
+        read_offsets = []
+        requests_idx_mapping = {}
+
+        # Parse batch
+        max_truncation = 0
+        padding_right_offset = 0
+        max_decode_tokens = 0
+        for i, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = i
+            inputs.append(r.inputs)
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
+            max_truncation = max(max_truncation, r.truncate)
+            max_decode_tokens += stopping_criteria.max_new_tokens
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )
+
+        tokenized_inputs = tokenizer(
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
+        ).to(device)
+        input_len = tokenized_inputs["input_ids"].shape[1]
+        for _ in pb.requests:
+            prefix_offsets.append(input_len - 5)
+            read_offsets.append(input_len)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+        input_ids = tokenized_inputs["input_ids"]
+        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+        return cls(
+            batch_id=pb.id,
+            requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            all_input_ids=list(all_input_ids),
+            input_lengths=input_lengths.tolist(),
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            max_input_length=max_input_length.item(),
+            padding_right_offset=padding_right_offset,
+            max_tokens=max_tokens,
+        )
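+
+    # Worked example of the padding bookkeeping above (illustrative numbers,
+    # not from the source): with two requests of tokenized lengths 3 and 7 and
+    # max_new_tokens of 10 and 4, left padding gives max_input_length = 7,
+    # padding_right_offset = max(10, 4) = 10, max_decode_tokens = 10 + 4 = 14,
+    # and the batch reserves max_tokens = 2 * (7 + 14) = 42 token slots in the
+    # scheduler's budget.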
+
+    def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+        if len(request_ids) == len(self):
+            return self
+
+        keep_indices = []
+
+        # New values after filtering
+        requests_idx_mapping = {}
+        requests = []
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        max_input_length = 0
+
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+
+        total_remaining_decode_tokens = 0
+        new_padding_right_offset = 0
+
+        indices = []
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            requests_idx_mapping[request_id] = i
+            keep_indices.append(idx)
+
+            requests.append(self.requests[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
+            all_input_ids.append(self.all_input_ids[idx])
+
+            request_input_length = self.input_lengths[idx]
+            input_lengths.append(request_input_length)
+            max_input_length = max(max_input_length, request_input_length)
+            indices.append(idx)
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criteria = self.stopping_criterias[idx]
+            stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
+            remaining_decode_tokens = (
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            )
+            total_remaining_decode_tokens += remaining_decode_tokens
+            new_padding_right_offset = max(
+                new_padding_right_offset, remaining_decode_tokens
+            )
+
+        # Apply indices to input_ids and the other tensors that are cached
+        input_ids = self.input_ids[keep_indices]
+
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
+        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids = input_ids
+        self.all_input_ids = all_input_ids
+        self.input_lengths = input_lengths
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
+        self.next_token_choosers = next_token_choosers
+        self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
+        self.max_input_length = max_input_length
+        self.padding_right_offset = new_padding_right_offset
+        self.max_tokens = max_tokens
+
+        # TODO: only the per-layer conv/ssm states are filtered here; other
+        # fields on inference_params (e.g. max_batch_size, lengths_per_sample)
+        # may also need to be narrowed to the kept requests.
+        key_value_memory_dict = {}
+        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+            key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
+        self.inference_params.key_value_memory_dict = key_value_memory_dict
+
+        return self
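+
+    # A minimal sketch of the state filtering above, assuming a hypothetical
+    # batch of 3 requests where only the requests mapping to rows [0, 2]
+    # survive:
+    #
+    #   conv_state, ssm_state = self.inference_params.key_value_memory_dict[0]
+    #   conv_state[[0, 2]].shape  # (2, d_model * expand, d_conv)
+    #
+    # Indexing with a list of indices copies the kept rows into new tensors,
+    # so the dropped request's state is released with the old ones.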
+
+    @classmethod
+    def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
+        # Used for padding
+        total_batch_size = 0
+        max_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += len(batch)
+            max_input_length = max(max_input_length, batch.max_input_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+        # Batch attributes
+        requests = []
+        requests_idx_mapping = {}
+        input_lengths = []
+        prefix_offsets = []
+        read_offsets = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+        top_n_tokens = []
+        max_tokens = 0
+        max_seqlen = 0
+        batch_size = 0
+        seqlen_offset = 0
+
+        # Batch tensors
+        input_ids = None
+        top_n_tokens_tensor = None
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            input_lengths.extend(batch.input_lengths)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
+
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                # We need to offset the mapping for each batch by the cumulative batch size
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
+            # Slicing end index for this batch
+            end_index = start_index + len(batch)
+
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
+            # We do not need to pad it
+            if input_ids is None:
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+            # Copy to correct indices
+            input_ids[start_index:end_index] = batch.input_ids
+
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)
+
+            max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
+            seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
+            batch_size += batch.inference_params.max_batch_size
+
+            start_index = end_index
+
+        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
+        n_blocks = len(batches[0].inference_params.key_value_memory_dict)
+        dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
+        device = batches[0].inference_params.key_value_memory_dict[0][0].device
+
+        key_value_memory_dict = {}
+        for i in range(n_blocks):
+            conv_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_conv,
+                device=device,
+                dtype=dtype,
+            )
+            ssm_state = torch.zeros(
+                batch_size,
+                d_model,
+                d_state,
+                device=device,
+                dtype=dtype,
+            )
+            key_value_memory_dict[i] = (conv_state, ssm_state)
+        lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+        inference_params = InferenceParams(
+            max_seqlen=max_seqlen,
+            max_batch_size=batch_size,
+            seqlen_offset=seqlen_offset,
+            key_value_memory_dict=key_value_memory_dict,
+            lengths_per_sample=lengths_per_sample,
+        )
+
+        current_batch = 0
+        for batch in batches:
+            bs = batch.inference_params.max_batch_size
+            for i in range(n_blocks):
+                conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
+                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + bs] = conv_state
+                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + bs] = ssm_state
+            inference_params.lengths_per_sample[current_batch:current_batch + bs] = batch.inference_params.lengths_per_sample
+            current_batch += bs
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
+            max_input_length=max_input_length,
+            padding_right_offset=padding_right_offset,
+            keys_head_dim_last=batches[0].keys_head_dim_last,
+            max_tokens=max_tokens,
+            inference_params=inference_params,
+        )
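+
+    # Illustrative flow of the concatenation above, assuming two running
+    # batches of sizes 2 and 3: fresh zero tensors with batch dimension 5 are
+    # allocated per layer, batch 0's states land in rows 0:2 and batch 1's in
+    # rows 2:5, mirroring how the token-level tensors are stacked. The merged
+    # InferenceParams keeps the largest max_seqlen and seqlen_offset seen,
+    # since all sequences then step together one token at a time.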
+
+    def __len__(self):
+        return len(self.requests)
+
+
+class Mamba(Model):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.process_group, _rank, _world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            "EleutherAI/gpt-neox-20b",
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        config = MambaConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
+        tokenizer.bos_token_id = config.bos_token_id
+        tokenizer.eos_token_id = config.eos_token_id
+        tokenizer.pad_token = tokenizer.eos_token
+
+        config.quantize = quantize
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        model = MambaModel(config, weights)
+        torch.distributed.barrier(group=self.process_group)
+        super(Mamba, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+        )
+
+    @property
+    def batch_type(self) -> Type[MambaBatch]:
+        return MambaBatch
+
+    def warmup(self, batch) -> Optional[int]:
+        # TODO: implement warmup for Mamba if needed
+        return None
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        past: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.model(
+            input_ids,
+            past=past,
+        )
+
+    def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
+        start = time.time_ns()
+        input_ids = batch.input_ids
+
+        batch_size = input_ids.shape[0]
+        max_seqlen = input_ids.shape[1]
+        dtype = input_ids.dtype
+
+        # Inference params
+        seqlen_og = 0
+        inf_cache = {}
+        lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
+
+        if batch.inference_params is None:
+            inference_params = InferenceParams(
+                max_seqlen=max_seqlen,
+                max_batch_size=batch_size,
+                seqlen_offset=seqlen_og,
+                key_value_memory_dict=inf_cache,
+                lengths_per_sample=lengths_per_sample,
+            )
+
+            # Allocate inference cache
+            for res_block in self.model.blocks:
+                block = res_block.mamba_block
+                conv_state = torch.zeros(
+                    batch_size,
+                    self.model.config.d_model * self.model.config.expand,
+                    self.model.config.d_conv,
+                    device=block.conv1d.weight.device,
+                    dtype=block.conv1d.weight.dtype,
+                )
+                ssm_state = torch.zeros(
+                    batch_size,
+                    self.model.config.d_model * self.model.config.expand,
+                    self.model.config.d_state,
+                    device=block.dt_proj.weight.device,
+                    dtype=block.dt_proj.weight.dtype,
+                )
+                inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
+            batch.inference_params = inference_params
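+
+        # Prefill runs the full selective scan over the prompt and seeds the
+        # per-layer (conv_state, ssm_state) cache; every later call takes the
+        # single-token step path, so each generated token costs a fixed amount
+        # of work regardless of how long the sequence has grown.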
+
+        # Forward pass
+        logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
+
+        batch.inference_params = new_inference_params
+        # Results
+        generations: List[Generation] = []
+        stopped = True
+
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.log_softmax(logits[:, -1], -1),
+            accepted_ids,
+        )
+
+        start_decode = time.time_ns()
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.input_lengths,
+            batch.prefix_offsets,
+            batch.read_offsets,
+            logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            prefix_offset,
+            read_offset,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits[-1:, :]
+            )
+
+            # Append next token to all tokens
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
+            new_input_length = input_length + 1
+
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_input_ids[:, 0], prefix_offset, read_offset
+            )
+
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(
+                next_token_id_squeezed,
+                next_token_text,
+            )
+
+            if not stop:
+                stopped = False
+
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens
+                        - 1,
+                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                        skip_special_tokens=True,
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+                else:
+                    generated_text = None
+
+                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = Tokens(
+                        prefill_token_ids,
+                        prefill_logprobs,
+                        prefill_texts,
+                        is_special=[],
+                    )
+                else:
+                    prefill_tokens = None
+
+                if top_n_tokens > 0:
+                    toptoken_texts = self.tokenizer.batch_decode(
+                        top_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    special_toptokens = [
+                        token_id in self.all_special_ids for token_id in top_token_ids
+                    ]
+                    top_tokens = Tokens(
+                        top_token_ids,
+                        top_token_logprobs,
+                        toptoken_texts,
+                        special_toptokens,
+                    )
+                else:
+                    top_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
+                    generated_text,
+                    top_tokens,
+                )
+
+                generations.append(generation)
+
+            # Update values
+            batch.input_ids[i, 0] = next_token_id
+            batch.all_input_ids[i] = all_input_ids
+            batch.input_lengths[i] = new_input_length
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+        # We finished all generations in the batch; there is no next batch
+        if stopped:
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)
+
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
+
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
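+
+# Rough lifecycle of this model inside the server, as a usage sketch (the
+# router drives these calls; names follow this file, and the exact call
+# sequence is an assumption about the surrounding scheduler):
+#
+#   batch = MambaBatch.from_pb(pb_batch, tokenizer, dtype, device)
+#   generations, batch, timings = model.generate_token(batch)    # prefill
+#   while batch is not None:                                     # decode steps
+#       generations, batch, timings = model.generate_token(batch)
+#
+# Finished requests are dropped between steps via batch.filter(request_ids),
+# and concurrent batches are merged with MambaBatch.concatenate([...]).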