Merge branch 'rocm-awq-support' of https://github.com/huggingface/text-generation-inference into rocm-awq-support

Author: IlyasMoutawwakil
Date: 2024-02-08 16:03:17 +01:00
Commit: e29fb799cb
27 changed files with 2028 additions and 387 deletions

@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda
# Build mamba kernels
FROM kernel-builder as mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
# Build megablocks
FROM kernel-builder as megablocks-builder
@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
assert not response.details.tokens[0].special
def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", decoder_input_details=True)
assert response.generated_text != ""
assert response.details.finish_reason == FinishReason.EndOfSequenceToken
assert response.details.generated_tokens > 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) > 1
assert response.details.tokens[0].id == 3
assert response.details.tokens[0].text == " "
assert not response.details.tokens[0].special
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate(

@ -62,7 +62,7 @@ class Client:
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: Optional[int] = None,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
@ -157,7 +157,7 @@ class Client:
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: Optional[int] = None,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
@ -312,7 +312,7 @@ class AsyncClient:
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: Optional[int] = None,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
@ -405,7 +405,7 @@ class AsyncClient:
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: Optional[int] = None,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,

@ -9,7 +9,7 @@ class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
# Maximum number of generated tokens
max_new_tokens: Optional[int] = None
max_new_tokens: int = 20
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
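The default for `max_new_tokens` changes from unset to 20 across the client methods and the `Parameters` model. A minimal sketch of what that means for callers, assuming a TGI server is reachable at a placeholder URL:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# With no explicit value, the client now sends max_new_tokens=20
# instead of omitting the parameter.
print(client.generate("What is Deep Learning?").generated_text)

# Passing the parameter explicitly still overrides the default.
print(client.generate("What is Deep Learning?", max_new_tokens=64).generated_text)
```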

@ -1,6 +1,6 @@
# Using TGI CLI
You can use the TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli).
You can use the TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).
`text-generation-server` lets you download the model with `download-weights` command like below 👇

@ -4,6 +4,15 @@ Text Generation Inference (TGI) now supports the Messages API, which is fully co
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
#### Table of Contents
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)
## Making a Request
You can make a request to TGI's Messages API using `curl`. Here's an example:
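The curl example itself falls outside this hunk. As a rough sketch of the same request shape, here is the equivalent call with Python's `requests` against an assumed local endpoint (URL and payload are placeholders; the route is the OpenAI-compatible `/v1/chat/completions`):

```python
import requests

response = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",  # assumed local TGI endpoint
    headers={"Content-Type": "application/json"},
    json={
        "model": "tgi",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ],
        "stream": False,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```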
@ -81,6 +90,38 @@ chat_completion = client.chat.completions.create(
print(chat_completion)
```
## Hugging Face Inference Endpoints
The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
Every endpoint that uses "Text Generation Inference" with an LLM that has a chat template can now be used. Below is an example of how to use Inference Endpoints with TGI via OpenAI's Python client library:
> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.
```python
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(
# replace with your endpoint url, make sure to include "v1/" at the end
base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
# replace with your API key
api_key="hf_XXX"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message.choices[0].delta.content, end="")
```
## Cloud Providers
TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
@ -114,7 +155,7 @@ hub = {
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
env=hub,
role=role,
role=role,
)
# deploy model to SageMaker Inference
@ -123,7 +164,7 @@ predictor = huggingface_model.deploy(
instance_type="ml.g5.2xlarge",
container_startup_health_check_timeout=300,
)
# send request
predictor.predict({
"messages": [

@ -1,193 +1,194 @@
{
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 589,
"text": "def",
"logprob": null
"logprob": null,
"text": "def"
},
{
"id": 3226,
"text": " ge",
"logprob": -9.0234375
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"text": "ometric",
"logprob": -9.0859375
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"text": "_",
"logprob": -0.25878906
"logprob": -0.2668457,
"text": "_"
},
{
"id": 6009,
"text": "mean",
"logprob": -2.2109375
"logprob": -1.6416016,
"text": "mean"
},
{
"id": 26,
"text": "(",
"logprob": -0.30371094
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"text": "L",
"logprob": -5.6054688
"logprob": -5.2304688,
"text": "L"
},
{
"id": 44,
"text": ":",
"logprob": -3.0722656
"logprob": -3.0976562,
"text": ":"
},
{
"id": 1682,
"text": " List",
"logprob": -0.6879883
"logprob": -1.1044922,
"text": " List"
},
{
"id": 77,
"text": "[",
"logprob": -0.38500977
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"text": "float",
"logprob": -0.984375
"logprob": -0.32299805,
"text": "float"
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5351562
"logprob": -2.8164062,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"text": "\n ",
"logprob": -1.1738281,
"special": false
"logprob": -0.1282959,
"special": false,
"text": "\n "
},
{
"id": 442,
"text": " return",
"logprob": -0.95947266,
"special": false
"id": 1524,
"logprob": -0.97998047,
"special": false,
"text": " \"\"\""
},
{
"id": 3632,
"text": " sum",
"logprob": -1.4199219,
"special": false
"id": 284,
"logprob": -0.7006836,
"special": false,
"text": "\n "
},
{
"id": 26,
"text": "(",
"logprob": -0.085876465,
"special": false
"id": 14883,
"logprob": -2.1933594,
"special": false,
"text": " Calculate"
},
{
"id": 62,
"text": "L",
"logprob": -0.09875488,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.30517578,
"special": false
},
{
"id": 517,
"text": " /",
"logprob": -0.42089844,
"special": false
},
{
"id": 2069,
"text": " len",
"logprob": -0.042053223,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.0011806488,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.0005259514,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.0017633438,
"special": false
},
{
"id": 478,
"text": "\n\n",
"logprob": -0.69189453,
"special": false
},
{
"id": 203,
"text": "\n",
"logprob": -0.041870117,
"special": false
},
{
"id": 589,
"text": "def",
"logprob": -0.27856445,
"special": false
"id": 322,
"logprob": -0.2697754,
"special": false,
"text": " the"
},
{
"id": 3226,
"text": " ge",
"logprob": -1.7255859,
"special": false
"logprob": -0.0836792,
"special": false,
"text": " ge"
},
{
"id": 21017,
"text": "ometric",
"logprob": -0.011291504,
"special": false
"logprob": -0.018737793,
"special": false,
"text": "ometric"
},
{
"id": 81,
"text": "_",
"logprob": -0.008430481,
"special": false
"id": 5651,
"logprob": -0.028640747,
"special": false,
"text": " mean"
},
{
"id": 6009,
"text": "mean",
"logprob": -0.025787354,
"special": false
"id": 432,
"logprob": -0.29467773,
"special": false,
"text": " of"
},
{
"id": 26,
"text": "(",
"logprob": -0.073913574,
"special": false
"id": 312,
"logprob": -0.31518555,
"special": false,
"text": " a"
},
{
"id": 62,
"text": "L",
"logprob": -0.09967041,
"special": false
"id": 1149,
"logprob": -0.20605469,
"special": false,
"text": " list"
},
{
"id": 432,
"logprob": -0.23254395,
"special": false,
"text": " of"
},
{
"id": 7515,
"logprob": -0.4489746,
"special": false,
"text": " numbers"
},
{
"id": 32,
"logprob": -0.6044922,
"special": false,
"text": "."
},
{
"id": 446,
"logprob": -0.63964844,
"special": false,
"text": "\n\n "
},
{
"id": 499,
"logprob": -1.1953125,
"special": false,
"text": " :"
},
{
"id": 753,
"logprob": -0.03515625,
"special": false,
"text": "param"
},
{
"id": 498,
"logprob": -0.06311035,
"special": false,
"text": " L"
},
{
"id": 44,
"logprob": -0.003414154,
"special": false,
"text": ":"
},
{
"id": 1682,
"logprob": -1.3310547,
"special": false,
"text": " List"
}
]
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"
}

@ -11,57 +11,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5898438,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25830078,
"logprob": -0.26586914,
"text": "_"
},
{
"id": 6009,
"logprob": -2.1875,
"logprob": -1.6347656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30004883,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.6171875,
"logprob": -5.2382812,
"text": "L"
},
{
"id": 44,
"logprob": -3.078125,
"logprob": -3.0996094,
"text": ":"
},
{
"id": 1682,
"logprob": -0.68066406,
"logprob": -1.1025391,
"text": " List"
},
{
"id": 77,
"logprob": -0.38745117,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.9453125,
"logprob": -0.32226562,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5371094,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -69,19 +69,19 @@
"tokens": [
{
"id": 284,
"logprob": -0.051635742,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": 0.0,
"logprob": -1.3134766,
"special": false,
"text": " return"
},
{
"id": 11665,
"logprob": -1.2236328,
"logprob": -0.10021973,
"special": false,
"text": " reduce"
},
@ -129,7 +129,7 @@
},
{
"id": 319,
"logprob": 0.0,
"logprob": -0.42871094,
"special": false,
"text": " *"
},
@ -158,36 +158,37 @@
"text": ")"
},
{
"id": 203,
"logprob": -0.12695312,
"special": false,
"text": "\n"
},
{
"id": 203,
"id": 1115,
"logprob": 0.0,
"special": false,
"text": "\n"
"text": " **"
},
{
"id": 589,
"id": 308,
"logprob": 0.0,
"special": false,
"text": "def"
"text": " ("
},
{
"id": 3226,
"id": 35,
"logprob": 0.0,
"special": false,
"text": " ge"
"text": "1"
},
{
"id": 21017,
"id": 32,
"logprob": -0.31323242,
"special": false,
"text": "."
},
{
"id": 34,
"logprob": 0.0,
"special": false,
"text": "ometric"
"text": "0"
}
]
],
"top_tokens": null
},
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0"
}

@ -12,57 +12,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5820312,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26708984,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6386719,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22717285,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.1015625,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1083984,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.32592773,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -70,67 +70,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.12817383,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91796875,
"id": 1524,
"logprob": -0.9863281,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.3291016,
"id": 284,
"logprob": -0.7011719,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.08062744,
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.097717285,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.29003906,
"id": 3226,
"logprob": -0.08465576,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.34958984,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.03829956,
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011987686,
"id": 432,
"logprob": -0.29418945,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.00050878525,
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -145,57 +146,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.59375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25878906,
"logprob": -0.26953125,
"text": "_"
},
{
"id": 6009,
"logprob": -2.2109375,
"logprob": -1.640625,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30371094,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.6054688,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0722656,
"logprob": -3.1132812,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6879883,
"logprob": -1.1123047,
"text": " List"
},
{
"id": 77,
"logprob": -0.38500977,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.984375,
"logprob": -0.32299805,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5351562,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -203,67 +204,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1738281,
"logprob": -0.12854004,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9584961,
"id": 1524,
"logprob": -0.9897461,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.4169922,
"id": 284,
"logprob": -0.69970703,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.085876465,
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.0982666,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.3022461,
"id": 3226,
"logprob": -0.08496094,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.40504883,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.041656494,
"id": 5651,
"logprob": -0.029037476,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011844635,
"id": 432,
"logprob": -0.2939453,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.0005264282,
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -278,57 +280,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26586914,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6347656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22766113,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.2265625,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.0976562,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1025391,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.1427002,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.32592773,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -336,67 +338,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.13012695,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9165039,
"id": 1524,
"logprob": -0.98046875,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.328125,
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.07946777,
"id": 14883,
"logprob": -2.1992188,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.09820557,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.28930664,
"id": 3226,
"logprob": -0.083496094,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.34592773,
"id": 21017,
"logprob": -0.01902771,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.038330078,
"id": 5651,
"logprob": -0.029006958,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011940002,
"id": 432,
"logprob": -0.29248047,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.00050878525,
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -411,57 +414,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26904297,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6386719,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.1132812,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1074219,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.14477539,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.3256836,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8027344,
"text": "]):"
}
],
@ -469,66 +472,67 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.12915039,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91259766,
"id": 1524,
"logprob": -0.98535156,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.3251953,
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.08062744,
"id": 14883,
"logprob": -2.2011719,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.09906006,
"id": 322,
"logprob": -0.26708984,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.28979492,
"id": 3226,
"logprob": -0.08502197,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.35958984,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.038604736,
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011901855,
"id": 432,
"logprob": -0.29589844,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.0005078316,
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
}
]

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -0.3552246,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -0.38378906,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.140625,
"special": false,
"text": "Deep"
},
{
"id": 4715,
"logprob": -0.5551758,
"special": false,
"text": " learning"
},
{
"id": 310,
"logprob": -0.59033203,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.70654297,
"special": false,
"text": " a"
},
{
"id": 747,
"logprob": -2.0410156,
"special": false,
"text": " new"
},
{
"id": 1511,
"logprob": -2.3789062,
"special": false,
"text": " type"
},
{
"id": 273,
"logprob": -0.0026435852,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.2841797,
"special": false,
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new type of machine"
}

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2502,
"logprob": null,
"text": " red"
},
{
"id": 13,
"logprob": -2.5234375,
"text": ","
},
{
"id": 8862,
"logprob": -3.4433594,
"text": " yellow"
},
{
"id": 13,
"logprob": -0.43017578,
"text": ","
},
{
"id": 209,
"logprob": -8.21875,
"text": " "
}
],
"seed": 0,
"tokens": [
{
"id": 187,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 395,
"logprob": -0.46411133,
"special": false,
"text": "and"
},
{
"id": 13735,
"logprob": -2.1132812,
"special": false,
"text": " orange"
},
{
"id": 313,
"logprob": -1.2128906,
"special": false,
"text": " ("
},
{
"id": 249,
"logprob": -2.3671875,
"special": false,
"text": "in"
},
{
"id": 253,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 1340,
"logprob": -1.640625,
"special": false,
"text": " order"
},
{
"id": 597,
"logprob": -0.5488281,
"special": false,
"text": " they"
},
{
"id": 3176,
"logprob": -0.48608398,
"special": false,
"text": " appear"
},
{
"id": 275,
"logprob": 0.0,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "blue, red, yellow, \nand orange (in the order they appear in"
}

@ -0,0 +1,398 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -0.8125,
"text": " is"
},
{
"id": 18147,
"logprob": -12.828125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -3.0,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1484375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -0.3552246,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -0.38378906,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.1279297,
"special": false,
"text": "Deep"
},
{
"id": 4715,
"logprob": -0.5595703,
"special": false,
"text": " learning"
},
{
"id": 310,
"logprob": -0.60253906,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.7050781,
"special": false,
"text": " a"
},
{
"id": 747,
"logprob": -2.0488281,
"special": false,
"text": " new"
},
{
"id": 1511,
"logprob": -2.3808594,
"special": false,
"text": " type"
},
{
"id": 273,
"logprob": -0.0026416779,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -0.38256836,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.1269531,
"special": false,
"text": "Deep"
},
{
"id": 4715,
"logprob": -0.54541016,
"special": false,
"text": " learning"
},
{
"id": 310,
"logprob": -0.59765625,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.7001953,
"special": false,
"text": " a"
},
{
"id": 747,
"logprob": -2.0585938,
"special": false,
"text": " new"
},
{
"id": 1511,
"logprob": -2.3789062,
"special": false,
"text": " type"
},
{
"id": 273,
"logprob": -0.0027446747,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -0.38256836,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.1269531,
"special": false,
"text": "Deep"
},
{
"id": 4715,
"logprob": -0.54541016,
"special": false,
"text": " learning"
},
{
"id": 310,
"logprob": -0.59765625,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.7001953,
"special": false,
"text": " a"
},
{
"id": 747,
"logprob": -2.0585938,
"special": false,
"text": " new"
},
{
"id": 1511,
"logprob": -2.3789062,
"special": false,
"text": " type"
},
{
"id": 273,
"logprob": -0.0027446747,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new type of machine"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -0.78027344,
"text": " is"
},
{
"id": 18147,
"logprob": -12.8203125,
"text": " Deep"
},
{
"id": 20727,
"logprob": -2.9902344,
"text": " Learning"
},
{
"id": 32,
"logprob": -1.1523438,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 187,
"logprob": -0.35351562,
"special": false,
"text": "\n"
},
{
"id": 187,
"logprob": -0.38256836,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.1269531,
"special": false,
"text": "Deep"
},
{
"id": 4715,
"logprob": -0.54541016,
"special": false,
"text": " learning"
},
{
"id": 310,
"logprob": -0.59765625,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.7001953,
"special": false,
"text": " a"
},
{
"id": 747,
"logprob": -2.0585938,
"special": false,
"text": " new"
},
{
"id": 1511,
"logprob": -2.3789062,
"special": false,
"text": " type"
},
{
"id": 273,
"logprob": -0.0027446747,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.2851562,
"special": false,
"text": " machine"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new type of machine"
}
]

@ -0,0 +1,59 @@
import pytest
@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
await fused_kernel_mamba_handle.health(300)
return fused_kernel_mamba_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"What is Deep Learning?", max_new_tokens=10
)
assert response.details.generated_tokens == 10
assert response.generated_text == "\n\nDeep learning is a new type of machine"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"blue, red, yellow, ",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
assert responses == response_snapshot

@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] }
tokenizers = { version = "0.15.1", features = ["http"] }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] }

@ -198,6 +198,7 @@ impl Infer {
messages,
eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(),
add_generation_prompt: true,
})
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
@ -806,21 +807,14 @@ mod tests {
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(
result,
r#"### User:
Hi!
### Assistant:
Hello how can I help?### User:
What is Deep Learning?
### Assistant:
magic!"#
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
);
}
@ -878,6 +872,7 @@ magic!"#
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@ -943,9 +938,60 @@ magic!"#
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
}
#[test]
fn test_chat_template_valid_with_add_generation_prompt() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let source = r#"
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|im_start|>assistant\n' }}
{% endif %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
role: "user".to_string(),
content: "Hi!".to_string(),
},
Message {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
Message {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
Message {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
}
}

@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
}
impl HubTokenizerConfig {
pub fn from_file(filename: &str) -> Self {
pub fn from_file(filename: &std::path::Path) -> Self {
let content = std::fs::read_to_string(filename).unwrap();
serde_json::from_str(&content).unwrap_or_default()
}
@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]

@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
// Load tokenizer config
// This will be used to format the chat template
let local_tokenizer_config_path =
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
};
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if local_tokenizer_config {
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
} else if let Some(api) = api {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
get_tokenizer_config(&api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main".to_string()),
)))
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
HubTokenizerConfig::default()
})
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else {
tracing::warn!("Could not find tokenizer config locally and no revision specified");
HubTokenizerConfig::default()
match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
}
};
if tokenizer.is_none() {

@ -936,6 +936,7 @@ pub async fn run(
// Define base and health routes
let base_routes = Router::new()
.route("/", post(compat_generate))
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))

server/.gitignore (vendored)

@ -161,3 +161,4 @@ flash-attention-v2/
vllm/
llm-awq/
eetq/
mamba/

@ -3,6 +3,7 @@ include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
unit-tests:
pytest -s -vv -m "not private" tests

@ -0,0 +1,28 @@
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
causal-conv1d:
rm -rf causal-conv1d
git clone https://github.com/Dao-AILab/causal-conv1d.git
build-causal-conv1d: causal-conv1d
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
install-causal-conv1d: build-causal-conv1d
pip uninstall causal-conv1d -y || true
cd causal-conv1d/ && pip install .
# selective-scan depends on causal-conv1d
selective-scan:
rm -rf mamba
git clone https://github.com/state-spaces/mamba.git mamba
build-selective-scan: selective-scan
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
cd mamba && python setup.py build
install-selective-scan: install-causal-conv1d build-selective-scan
pip uninstall selective-scan-cuda -y || true
cd mamba && pip install .
build-all: build-causal-conv1d build-selective-scan
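After `make install-causal-conv1d install-selective-scan` (or the Docker build above), a quick sanity check is that the kernel entry points used by the Mamba modeling code resolve. A minimal sketch with the same imports that appear in `mamba_modeling.py` later in this diff:

```python
# Same kernel entry points that mamba_modeling.py relies on.
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

print("causal-conv1d and selective-scan kernels are importable")
```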

server/poetry.lock (generated)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "accelerate"
@ -1589,30 +1589,32 @@ xml = ["lxml (>=4.9.2)"]
[[package]]
name = "peft"
version = "0.4.0"
version = "0.8.2"
description = "Parameter-Efficient Fine-Tuning (PEFT)"
optional = true
python-versions = ">=3.8.0"
files = [
{file = "peft-0.4.0-py3-none-any.whl", hash = "sha256:2cf992772a6d703814477e0bdcdadd68cb8ea388111ce2d793dd2ff0e438f357"},
{file = "peft-0.4.0.tar.gz", hash = "sha256:e768fa22d6e9f32aa7e891f0d06f355960278ca4dc0cdd96bff71f6f06269207"},
{file = "peft-0.8.2-py3-none-any.whl", hash = "sha256:4a9c81c38e689fd4043b2757cd0e2b526a9b8b8fd04f8442df2c4824b32c2505"},
{file = "peft-0.8.2.tar.gz", hash = "sha256:bbdf61db2d8ca503e894edc64016038e6f34b7b522374bad09a22af41882e7ac"},
]
[package.dependencies]
accelerate = "*"
accelerate = ">=0.21.0"
huggingface-hub = ">=0.17.0"
numpy = ">=1.17"
packaging = ">=20.0"
psutil = "*"
pyyaml = "*"
safetensors = "*"
torch = ">=1.13.0"
tqdm = "*"
transformers = "*"
[package.extras]
dev = ["black (>=22.0,<23.0)", "hf-doc-builder", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
docs-specific = ["hf-doc-builder"]
quality = ["black (>=22.0,<23.0)", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
test = ["black (>=22.0,<23.0)", "datasets", "diffusers", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
test = ["black (>=22.0,<23.0)", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "scipy", "urllib3 (<=2.0.0)"]
[[package]]
name = "pillow"
@ -1893,6 +1895,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -2962,4 +2965,4 @@ torch = ["torch"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
content-hash = "33d533d21d14c258678a8c4bb28e2a15e8ebe5ca35d8589cbfe4a7b7d2e79a90"
content-hash = "f7529125bdd7ce142082ce4969edbda5d9b67b6209f199194c54198829f5dc64"

@ -30,7 +30,7 @@ transformers = "^4.37.1"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
peft = { version = "^0.4.0", optional = true }
peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"

@ -76,6 +76,15 @@ if FLASH_ATTENTION:
__all__.append(FlashMixtral)
__all__.append(FlashPhi)
MAMBA_AVAILABLE = True
try:
from text_generation_server.models.mamba import Mamba
except ImportError as e:
logger.warning(f"Could not import Mamba: {e}")
MAMBA_AVAILABLE = False
if MAMBA_AVAILABLE:
__all__.append(Mamba)
def get_model(
model_id: str,
@ -164,7 +173,25 @@ def get_model(
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict["model_type"]
model_type = config_dict.get("model_type", None)
if model_type is None:
# TODO: fix how we determine model type for Mamba
if "ssm_cfg" in config_dict:
# *only happens in Mamba case
model_type = "ssm"
else:
raise RuntimeError(
f"Could not determine model type for {model_id} revision {revision}"
)
if model_type == "ssm":
return Mamba(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "gpt_bigcode":
if FLASH_ATTENTION:

@ -69,9 +69,17 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize, _ = weights._get_gptq_params()
bits, groupsize, _, quant_method, = weights._get_gptq_params()
if quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
elif quant_method == "awq":
g_idx = None
from text_generation_server.utils.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
from text_generation_server.utils.layers import HAS_EXLLAMA

@ -0,0 +1,194 @@
import torch
import torch.distributed
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.utils.generation import InferenceParams
from torch import nn
from typing import Optional, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.utils.layers import (
TensorParallelEmbedding,
FastRMSNorm,
FastLinear,
)
from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
class MambaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=50280,
d_model=768,
d_state=16,
n_layer=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
expand=2,
dt_rank="auto",
**kwargs,
):
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layer_norm_epsilon = layer_norm_epsilon
self.d_model = d_model
self.d_inner = d_model * 2
self.d_conv = 4
self.d_state = d_state
self.expand = expand
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.layer_idx = int(prefix.split(".")[2])
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
self.D = weights.get_tensor(f"{prefix}.D")
self.activation = "silu"
self.dt_rank = config.dt_rank
self.d_state = config.d_state
self.d_conv = config.d_conv
self.act = nn.SiLU()
# inference_params
def forward(self, hidden_states: torch.Tensor, inference_params=None):
_, seqlen, _ = hidden_states.shape
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
if inference_params.seqlen_offset > 0:
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
return out, conv_state, ssm_state
projected_states = self.in_proj(hidden_states).transpose(1,2)
x, z = projected_states.chunk(2, dim=1)
conv_state = F.pad(x, (self.d_conv - seqlen, 0))
x = causal_conv1d_fn(
x=x,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
y, last_state = selective_scan_fn(
x,
dt,
self.negA,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=True,
)
y = rearrange(y, "b d l -> b l d")
attn_outputs = self.out_proj(y)
return attn_outputs, conv_state, last_state
def step(self, hidden_states, conv_state, ssm_state):
_xz = self.in_proj(hidden_states)
_x, _z = _xz.chunk(2, dim=-1) # (B D)
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
conv_out = causal_conv1d_fn(
x=conv_state_new,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
bias=self.conv1d.bias,
activation=self.activation
)
conv_state = conv_state_new[:, :, 1:]
bsz, seqlen, dim = hidden_states.shape
output_tensor = torch.zeros(
(bsz, seqlen, dim),
device=hidden_states.device,
dtype=hidden_states.dtype
)
for i in range(0, bsz):
x = conv_out[i:i+1,:,-1]
z = _z[i:i+1, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = F.linear(dt, self.dt_proj.weight)
y = selective_state_update(
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
output_tensor[i] = out
return output_tensor, conv_state, ssm_state
class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None,
inference_params: Optional[Any] = None,
):
residual = (hidden_states + residual) if residual is not None else hidden_states
shape = residual.shape
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
return hidden_states, residual, conv_state, last_ssm_state
class MambaModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList(
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
)
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
self.config = config
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
hidden_states = self.embed_tokens(input_ids)
for block in self.blocks:
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
hidden_states = hidden_states + residual if residual is not None else hidden_states
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
hidden_states = hidden_states.view(residual.shape)
logits = self.lm_head(hidden_states)
# update the offset for the next inference using these params
inference_params.seqlen_offset += input_ids.size(1)
return logits, input_ids, inference_params
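`MambaModel.forward` expects the caller to have seeded `inference_params` with one `(conv_state, ssm_state)` pair per block. A minimal sketch of allocating that cache, assuming the channel dimension is `d_inner` (the conv1d/in_proj width); the constructor kwargs mirror the `InferenceParams` built in `MambaBatch.concatenate` further below:

```python
import torch
from mamba_ssm.utils.generation import InferenceParams

def make_inference_params(config, batch_size, max_seqlen, device, dtype):
    # One (conv_state, ssm_state) pair per residual block; seqlen_offset == 0
    # selects the prefill path in MambaBlock.forward, > 0 selects step().
    key_value_memory_dict = {
        i: (
            torch.zeros(batch_size, config.d_inner, config.d_conv, device=device, dtype=dtype),
            torch.zeros(batch_size, config.d_inner, config.d_state, device=device, dtype=dtype),
        )
        for i in range(config.n_layer)
    }
    return InferenceParams(
        max_seqlen=max_seqlen,
        max_batch_size=batch_size,
        seqlen_offset=0,
        key_value_memory_dict=key_value_memory_dict,
        lengths_per_sample=torch.zeros(batch_size, dtype=torch.int32, device=device),
    )
```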

@ -0,0 +1,656 @@
import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from mamba_ssm.utils.generation import InferenceParams
@dataclass
class MambaBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
# Inference params
inference_params: Optional[Dict[str, Any]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "MambaBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
# past_input_ids=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
indices.append(idx)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
# TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
key_value_memory_dict = {}
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
self.inference_params.key_value_memory_dict = key_value_memory_dict
return self
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
max_seqlen = 0
batch_size = 0
seqlen_offset = 0
# Batch tensors
input_ids = None
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
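            # Track the merged sequence-length, offset and batch-size metadata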
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
batch_size += batch.inference_params.max_batch_size
start_index = end_index
        first_memory_dict = batches[0].inference_params.key_value_memory_dict
        (_, d_model, d_conv) = first_memory_dict[0][0].shape
        (_, _, d_state) = first_memory_dict[0][1].shape
        n_blocks = len(first_memory_dict)
        dtype = first_memory_dict[0][0].dtype
        device = first_memory_dict[0][0].device
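        # Note: `d_model` read from the cached conv state is the expanded inner
        # dimension (config.d_model * config.expand). Allocate fresh conv/ssm
        # buffers sized for the concatenated batch.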
key_value_memory_dict = {}
for i in range(n_blocks):
conv_state = torch.zeros(
batch_size,
d_model,
d_conv,
device=device,
dtype=dtype,
)
ssm_state = torch.zeros(
batch_size,
d_model,
d_state,
device=device,
dtype=dtype,
)
key_value_memory_dict[i] = (conv_state, ssm_state)
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
key_value_memory_dict=key_value_memory_dict,
lengths_per_sample=lengths_per_sample,
)
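        # Copy each batch's cached conv/ssm states and per-sample lengths into
        # the freshly allocated merged buffers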
        current_batch = 0
        for batch in batches:
            batch_size = batch.inference_params.max_batch_size
            for i in range(n_blocks):
                conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
                inference_params.key_value_memory_dict[i][0][
                    current_batch : current_batch + batch_size
                ] = conv_state
                inference_params.key_value_memory_dict[i][1][
                    current_batch : current_batch + batch_size
                ] = ssm_state
            inference_params.lengths_per_sample[
                current_batch : current_batch + batch_size
            ] = batch.inference_params.lengths_per_sample
            current_batch += batch_size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
inference_params=inference_params
)
def __len__(self):
return len(self.requests)
class Mamba(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
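        # Mamba checkpoints are paired with the GPT-NeoX-20B tokenizer, so it is
        # loaded regardless of model_id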
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/gpt-neox-20b",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = MambaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[MambaBatch]:
return MambaBatch
def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed
return None
def forward(
self,
input_ids: torch.Tensor,
past: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.model(
input_ids,
past=past,
)
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
        input_ids = batch.input_ids
batch_size = input_ids.shape[0]
max_seqlen = input_ids.shape[1]
dtype = input_ids.dtype
# Inference params
seqlen_og = 0
inf_cache = {}
        lengths_per_sample = (
            torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
            * max_seqlen
        )
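        # First forward pass for this batch: build InferenceParams and allocate
        # the per-block conv/ssm caches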
if batch.inference_params is None:
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_og,
key_value_memory_dict=inf_cache,
lengths_per_sample=lengths_per_sample,
)
# Allocate inference cache
for res_block in self.model.blocks:
block = res_block.mamba_block
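                # Each Mamba block caches a conv state of shape
                # (batch, d_model * expand, d_conv) and an SSM state of shape
                # (batch, d_model * expand, d_state)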
conv_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_conv,
device=block.conv1d.weight.device,
dtype=block.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_state,
device=block.dt_proj.weight.device,
dtype=block.dt_proj.weight.dtype,
)
                inference_params.key_value_memory_dict[block.layer_idx] = (
                    conv_state,
                    ssm_state,
                )
batch.inference_params = inference_params
# Forward pass
        logits, past_input_ids, new_inference_params = self.model(
            input_ids, batch.inference_params
        )
batch.inference_params = new_inference_params
# Results
generations: List[Generation] = []
stopped = True
        # Speculative decoding is not active here, so each request accepts exactly one token
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
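        # Gather the requested top-n alternative tokens from the last-position log-probs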
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)