mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Merge branch 'rocm-awq-support' of https://github.com/huggingface/text-generation-inference into rocm-awq-support
This commit is contained in:
commit
e29fb799cb
10
Dockerfile
10
Dockerfile
@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile
|
|||||||
# Build specific version of vllm
|
# Build specific version of vllm
|
||||||
RUN make build-vllm-cuda
|
RUN make build-vllm-cuda
|
||||||
|
|
||||||
|
# Build mamba kernels
|
||||||
|
FROM kernel-builder as mamba-builder
|
||||||
|
WORKDIR /usr/src
|
||||||
|
COPY server/Makefile-selective-scan Makefile
|
||||||
|
RUN make build-all
|
||||||
|
|
||||||
# Build megablocks
|
# Build megablocks
|
||||||
FROM kernel-builder as megablocks-builder
|
FROM kernel-builder as megablocks-builder
|
||||||
|
|
||||||
@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
|||||||
# Copy builds artifacts from vllm builder
|
# Copy builds artifacts from vllm builder
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from mamba builder
|
||||||
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
# Install flash-attention dependencies
|
# Install flash-attention dependencies
|
||||||
RUN pip install einops --no-cache-dir
|
RUN pip install einops --no-cache-dir
|
||||||
|
|
||||||
|
@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers):
|
|||||||
assert not response.details.tokens[0].special
|
assert not response.details.tokens[0].special
|
||||||
|
|
||||||
|
|
||||||
def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
|
|
||||||
client = Client(flan_t5_xxl_url, hf_headers)
|
|
||||||
response = client.generate("test", decoder_input_details=True)
|
|
||||||
|
|
||||||
assert response.generated_text != ""
|
|
||||||
assert response.details.finish_reason == FinishReason.EndOfSequenceToken
|
|
||||||
assert response.details.generated_tokens > 1
|
|
||||||
assert response.details.seed is None
|
|
||||||
assert len(response.details.prefill) == 1
|
|
||||||
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
|
|
||||||
assert len(response.details.tokens) > 1
|
|
||||||
assert response.details.tokens[0].id == 3
|
|
||||||
assert response.details.tokens[0].text == " "
|
|
||||||
assert not response.details.tokens[0].special
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
||||||
client = Client(flan_t5_xxl_url, hf_headers)
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
response = client.generate(
|
response = client.generate(
|
||||||
|
@ -62,7 +62,7 @@ class Client:
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: int = 20,
|
||||||
best_of: Optional[int] = None,
|
best_of: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
@ -157,7 +157,7 @@ class Client:
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: int = 20,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
@ -312,7 +312,7 @@ class AsyncClient:
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: int = 20,
|
||||||
best_of: Optional[int] = None,
|
best_of: Optional[int] = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
@ -405,7 +405,7 @@ class AsyncClient:
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
do_sample: bool = False,
|
do_sample: bool = False,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: int = 20,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
return_full_text: bool = False,
|
return_full_text: bool = False,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
@ -9,7 +9,7 @@ class Parameters(BaseModel):
|
|||||||
# Activate logits sampling
|
# Activate logits sampling
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
# Maximum number of generated tokens
|
# Maximum number of generated tokens
|
||||||
max_new_tokens: Optional[int] = None
|
max_new_tokens: int = 20
|
||||||
# The parameter for repetition penalty. 1.0 means no penalty.
|
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
repetition_penalty: Optional[float] = None
|
repetition_penalty: Optional[float] = None
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Using TGI CLI
|
# Using TGI CLI
|
||||||
|
|
||||||
You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli).
|
You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).
|
||||||
|
|
||||||
`text-generation-server` lets you download the model with `download-weights` command like below 👇
|
`text-generation-server` lets you download the model with `download-weights` command like below 👇
|
||||||
|
|
||||||
|
@ -4,6 +4,15 @@ Text Generation Inference (TGI) now supports the Messages API, which is fully co
|
|||||||
|
|
||||||
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
|
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
|
||||||
|
|
||||||
|
#### Table of Contents
|
||||||
|
|
||||||
|
- [Making a Request](#making-a-request)
|
||||||
|
- [Streaming](#streaming)
|
||||||
|
- [Synchronous](#synchronous)
|
||||||
|
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
|
||||||
|
- [Cloud Providers](#cloud-providers)
|
||||||
|
- [Amazon SageMaker](#amazon-sagemaker)
|
||||||
|
|
||||||
## Making a Request
|
## Making a Request
|
||||||
|
|
||||||
You can make a request to TGI's Messages API using `curl`. Here's an example:
|
You can make a request to TGI's Messages API using `curl`. Here's an example:
|
||||||
@ -81,6 +90,38 @@ chat_completion = client.chat.completions.create(
|
|||||||
print(chat_completion)
|
print(chat_completion)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Hugging Face Inference Endpoints
|
||||||
|
|
||||||
|
The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).
|
||||||
|
Every endpoint that uses "Text Generation Inference" with an LLM, which has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library:
|
||||||
|
|
||||||
|
> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# init the client but point it to TGI
|
||||||
|
client = OpenAI(
|
||||||
|
# replace with your endpoint url, make sure to include "v1/" at the end
|
||||||
|
base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/",
|
||||||
|
# replace with your API key
|
||||||
|
api_key="hf_XXX"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_completion = client.chat.completions.create(
|
||||||
|
model="tgi",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant." },
|
||||||
|
{"role": "user", "content": "What is deep learning?"}
|
||||||
|
],
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# iterate and print stream
|
||||||
|
for message in chat_completion:
|
||||||
|
print(message.choices[0].delta.content, end="")
|
||||||
|
```
|
||||||
|
|
||||||
## Cloud Providers
|
## Cloud Providers
|
||||||
|
|
||||||
TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
|
TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:
|
||||||
@ -114,7 +155,7 @@ hub = {
|
|||||||
huggingface_model = HuggingFaceModel(
|
huggingface_model = HuggingFaceModel(
|
||||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
|
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
|
||||||
env=hub,
|
env=hub,
|
||||||
role=role,
|
role=role,
|
||||||
)
|
)
|
||||||
|
|
||||||
# deploy model to SageMaker Inference
|
# deploy model to SageMaker Inference
|
||||||
@ -123,7 +164,7 @@ predictor = huggingface_model.deploy(
|
|||||||
instance_type="ml.g5.2xlarge",
|
instance_type="ml.g5.2xlarge",
|
||||||
container_startup_health_check_timeout=300,
|
container_startup_health_check_timeout=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
# send request
|
# send request
|
||||||
predictor.predict({
|
predictor.predict({
|
||||||
"messages": [
|
"messages": [
|
||||||
|
@ -1,193 +1,194 @@
|
|||||||
{
|
{
|
||||||
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
|
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
"finish_reason": "length",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 20,
|
"generated_tokens": 20,
|
||||||
"seed": null,
|
|
||||||
"prefill": [
|
"prefill": [
|
||||||
{
|
{
|
||||||
"id": 589,
|
"id": 589,
|
||||||
"text": "def",
|
"logprob": null,
|
||||||
"logprob": null
|
"text": "def"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"text": " ge",
|
"logprob": -8.5859375,
|
||||||
"logprob": -9.0234375
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"text": "ometric",
|
"logprob": -7.5859375,
|
||||||
"logprob": -9.0859375
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"text": "_",
|
"logprob": -0.2668457,
|
||||||
"logprob": -0.25878906
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"text": "mean",
|
"logprob": -1.6416016,
|
||||||
"logprob": -2.2109375
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"text": "(",
|
"logprob": -0.22705078,
|
||||||
"logprob": -0.30371094
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"text": "L",
|
"logprob": -5.2304688,
|
||||||
"logprob": -5.6054688
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"text": ":",
|
"logprob": -3.0976562,
|
||||||
"logprob": -3.0722656
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"text": " List",
|
"logprob": -1.1044922,
|
||||||
"logprob": -0.6879883
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"text": "[",
|
"logprob": -0.14294434,
|
||||||
"logprob": -0.38500977
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"text": "float",
|
"logprob": -0.32299805,
|
||||||
"logprob": -0.984375
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"text": "]):",
|
"logprob": -2.8164062,
|
||||||
"logprob": -2.5351562
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"text": "\n ",
|
"logprob": -0.1282959,
|
||||||
"logprob": -1.1738281,
|
"special": false,
|
||||||
"special": false
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"text": " return",
|
"logprob": -0.97998047,
|
||||||
"logprob": -0.95947266,
|
"special": false,
|
||||||
"special": false
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"text": " sum",
|
"logprob": -0.7006836,
|
||||||
"logprob": -1.4199219,
|
"special": false,
|
||||||
"special": false
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"text": "(",
|
"logprob": -2.1933594,
|
||||||
"logprob": -0.085876465,
|
"special": false,
|
||||||
"special": false
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"text": "L",
|
"logprob": -0.2697754,
|
||||||
"logprob": -0.09875488,
|
"special": false,
|
||||||
"special": false
|
"text": " the"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 27,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -0.30517578,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 517,
|
|
||||||
"text": " /",
|
|
||||||
"logprob": -0.42089844,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2069,
|
|
||||||
"text": " len",
|
|
||||||
"logprob": -0.042053223,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 26,
|
|
||||||
"text": "(",
|
|
||||||
"logprob": -0.0011806488,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 62,
|
|
||||||
"text": "L",
|
|
||||||
"logprob": -0.0005259514,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 27,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -0.0017633438,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 478,
|
|
||||||
"text": "\n\n",
|
|
||||||
"logprob": -0.69189453,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 203,
|
|
||||||
"text": "\n",
|
|
||||||
"logprob": -0.041870117,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 589,
|
|
||||||
"text": "def",
|
|
||||||
"logprob": -0.27856445,
|
|
||||||
"special": false
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"text": " ge",
|
"logprob": -0.0836792,
|
||||||
"logprob": -1.7255859,
|
"special": false,
|
||||||
"special": false
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"text": "ometric",
|
"logprob": -0.018737793,
|
||||||
"logprob": -0.011291504,
|
"special": false,
|
||||||
"special": false
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 5651,
|
||||||
"text": "_",
|
"logprob": -0.028640747,
|
||||||
"logprob": -0.008430481,
|
"special": false,
|
||||||
"special": false
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 432,
|
||||||
"text": "mean",
|
"logprob": -0.29467773,
|
||||||
"logprob": -0.025787354,
|
"special": false,
|
||||||
"special": false
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 312,
|
||||||
"text": "(",
|
"logprob": -0.31518555,
|
||||||
"logprob": -0.073913574,
|
"special": false,
|
||||||
"special": false
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 1149,
|
||||||
"text": "L",
|
"logprob": -0.20605469,
|
||||||
"logprob": -0.09967041,
|
"special": false,
|
||||||
"special": false
|
"text": " list"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 432,
|
||||||
|
"logprob": -0.23254395,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7515,
|
||||||
|
"logprob": -0.4489746,
|
||||||
|
"special": false,
|
||||||
|
"text": " numbers"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -0.6044922,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 446,
|
||||||
|
"logprob": -0.63964844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 499,
|
||||||
|
"logprob": -1.1953125,
|
||||||
|
"special": false,
|
||||||
|
"text": " :"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 753,
|
||||||
|
"logprob": -0.03515625,
|
||||||
|
"special": false,
|
||||||
|
"text": "param"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 498,
|
||||||
|
"logprob": -0.06311035,
|
||||||
|
"special": false,
|
||||||
|
"text": " L"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 44,
|
||||||
|
"logprob": -0.003414154,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1682,
|
||||||
|
"logprob": -1.3310547,
|
||||||
|
"special": false,
|
||||||
|
"text": " List"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"
|
||||||
}
|
}
|
||||||
|
@ -11,57 +11,57 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5898438,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25830078,
|
"logprob": -0.26586914,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.1875,
|
"logprob": -1.6347656,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30004883,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.6171875,
|
"logprob": -5.2382812,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.078125,
|
"logprob": -3.0996094,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.68066406,
|
"logprob": -1.1025391,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.38745117,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.9453125,
|
"logprob": -0.32226562,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.5371094,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -69,19 +69,19 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -0.051635742,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 442,
|
||||||
"logprob": 0.0,
|
"logprob": -1.3134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " return"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 11665,
|
"id": 11665,
|
||||||
"logprob": -1.2236328,
|
"logprob": -0.10021973,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " reduce"
|
"text": " reduce"
|
||||||
},
|
},
|
||||||
@ -129,7 +129,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 319,
|
"id": 319,
|
||||||
"logprob": 0.0,
|
"logprob": -0.42871094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " *"
|
"text": " *"
|
||||||
},
|
},
|
||||||
@ -158,36 +158,37 @@
|
|||||||
"text": ")"
|
"text": ")"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 203,
|
"id": 1115,
|
||||||
"logprob": -0.12695312,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 203,
|
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " **"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 589,
|
"id": 308,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "def"
|
"text": " ("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 35,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " ge"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 32,
|
||||||
|
"logprob": -0.31323242,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 34,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ometric"
|
"text": "0"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
|
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0"
|
||||||
}
|
}
|
||||||
|
@ -12,57 +12,57 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5820312,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26708984,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6386719,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22717285,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.1015625,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1083984,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.32592773,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -70,67 +70,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.12817383,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.91796875,
|
"logprob": -0.9863281,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.3291016,
|
"logprob": -0.7011719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.08062744,
|
"logprob": -2.2050781,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.097717285,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.29003906,
|
"logprob": -0.08465576,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.34958984,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.03829956,
|
"logprob": -0.028625488,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011987686,
|
"logprob": -0.29418945,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.00050878525,
|
"logprob": -0.3161621,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -145,57 +146,57 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.59375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25878906,
|
"logprob": -0.26953125,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.2109375,
|
"logprob": -1.640625,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30371094,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.6054688,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0722656,
|
"logprob": -3.1132812,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6879883,
|
"logprob": -1.1123047,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.38500977,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.984375,
|
"logprob": -0.32299805,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.5351562,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -203,67 +204,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1738281,
|
"logprob": -0.12854004,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.9584961,
|
"logprob": -0.9897461,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.4169922,
|
"logprob": -0.69970703,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.085876465,
|
"logprob": -2.2050781,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.0982666,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.3022461,
|
"logprob": -0.08496094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.40504883,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.041656494,
|
"logprob": -0.029037476,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011844635,
|
"logprob": -0.2939453,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.0005264282,
|
"logprob": -0.31591797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -278,57 +280,57 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26586914,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6347656,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22766113,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.2265625,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.0976562,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1025391,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.1427002,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.32592773,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -336,67 +338,68 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.13012695,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.9165039,
|
"logprob": -0.98046875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.328125,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.07946777,
|
"logprob": -2.1992188,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.09820557,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.28930664,
|
"logprob": -0.083496094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.34592773,
|
"logprob": -0.01902771,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.038330078,
|
"logprob": -0.029006958,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011940002,
|
"logprob": -0.29248047,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.00050878525,
|
"logprob": -0.3161621,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
@ -411,57 +414,57 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26904297,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6386719,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.1132812,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1074219,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.14477539,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.3256836,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8027344,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -469,66 +472,67 @@
|
|||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.12915039,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.91259766,
|
"logprob": -0.98535156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.3251953,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.08062744,
|
"logprob": -2.2011719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.09906006,
|
"logprob": -0.26708984,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.28979492,
|
"logprob": -0.08502197,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.35958984,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.038604736,
|
"logprob": -0.028625488,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011901855,
|
"logprob": -0.29589844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.0005078316,
|
"logprob": -0.31591797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.3552246,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38378906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.5551758,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59033203,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.70654297,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0026435852,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2841797,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2502,
|
||||||
|
"logprob": null,
|
||||||
|
"text": " red"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.5234375,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8862,
|
||||||
|
"logprob": -3.4433594,
|
||||||
|
"text": " yellow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.43017578,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 209,
|
||||||
|
"logprob": -8.21875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 395,
|
||||||
|
"logprob": -0.46411133,
|
||||||
|
"special": false,
|
||||||
|
"text": "and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13735,
|
||||||
|
"logprob": -2.1132812,
|
||||||
|
"special": false,
|
||||||
|
"text": " orange"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 313,
|
||||||
|
"logprob": -1.2128906,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -2.3671875,
|
||||||
|
"special": false,
|
||||||
|
"text": "in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1340,
|
||||||
|
"logprob": -1.640625,
|
||||||
|
"special": false,
|
||||||
|
"text": " order"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 597,
|
||||||
|
"logprob": -0.5488281,
|
||||||
|
"special": false,
|
||||||
|
"text": " they"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3176,
|
||||||
|
"logprob": -0.48608398,
|
||||||
|
"special": false,
|
||||||
|
"text": " appear"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 275,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "blue, red, yellow, \nand orange (in the order they appear in"
|
||||||
|
}
|
@ -0,0 +1,398 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.8125,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.828125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -3.0,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1484375,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.3552246,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38378906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1279297,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.5595703,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.60253906,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7050781,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0488281,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3808594,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0026416779,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
}
|
||||||
|
]
|
59
integration-tests/models/test_mamba.py
Normal file
59
integration-tests/models/test_mamba.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def fused_kernel_mamba_handle(launcher):
|
||||||
|
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||||
|
await fused_kernel_mamba_handle.health(300)
|
||||||
|
return fused_kernel_mamba_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
|
response = await fused_kernel_mamba.generate(
|
||||||
|
"What is Deep Learning?", max_new_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "\n\nDeep learning is a new type of machine"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
|
response = await fused_kernel_mamba.generate(
|
||||||
|
"blue, red, yellow, ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] }
|
|||||||
serde = "1.0.188"
|
serde = "1.0.188"
|
||||||
serde_json = "1.0.107"
|
serde_json = "1.0.107"
|
||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { version = "0.14.0", features = ["http"] }
|
tokenizers = { version = "0.15.1", features = ["http"] }
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.14"
|
tokio-stream = "0.1.14"
|
||||||
tower-http = { version = "0.4.4", features = ["cors"] }
|
tower-http = { version = "0.4.4", features = ["cors"] }
|
||||||
|
@ -198,6 +198,7 @@ impl Infer {
|
|||||||
messages,
|
messages,
|
||||||
eos_token: eos_token.as_deref(),
|
eos_token: eos_token.as_deref(),
|
||||||
bos_token: bos_token.as_deref(),
|
bos_token: bos_token.as_deref(),
|
||||||
|
add_generation_prompt: true,
|
||||||
})
|
})
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
@ -806,21 +807,14 @@ mod tests {
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
r#"### User:
|
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
|
||||||
Hi!
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
Hello how can I help?### User:
|
|
||||||
What is Deep Learning?
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
magic!"#
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -878,6 +872,7 @@ magic!"#
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||||
@ -943,9 +938,60 @@ magic!"#
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_valid_with_add_generation_prompt() {
|
||||||
|
let mut env = Environment::new();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
let source = r#"
|
||||||
|
{% for message in messages %}
|
||||||
|
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
|
||||||
|
{% endfor %}
|
||||||
|
{% if add_generation_prompt %}
|
||||||
|
{{ '<|im_start|>assistant\n' }}
|
||||||
|
{% endif %}"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "Hello how can I help?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "magic!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ pub struct HubTokenizerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
pub fn from_file(filename: &str) -> Self {
|
pub fn from_file(filename: &std::path::Path) -> Self {
|
||||||
let content = std::fs::read_to_string(filename).unwrap();
|
let content = std::fs::read_to_string(filename).unwrap();
|
||||||
serde_json::from_str(&content).unwrap_or_default()
|
serde_json::from_str(&content).unwrap_or_default()
|
||||||
}
|
}
|
||||||
@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
|
add_generation_prompt: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let local_model = local_path.exists() && local_path.is_dir();
|
let local_model = local_path.exists() && local_path.is_dir();
|
||||||
|
|
||||||
// Load tokenizer config
|
|
||||||
// This will be used to format the chat template
|
|
||||||
let local_tokenizer_config_path =
|
|
||||||
tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string());
|
|
||||||
let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists();
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
// Shared API builder initialization
|
||||||
let api_builder = || {
|
let api_builder = || {
|
||||||
let mut builder = ApiBuilder::new()
|
let mut builder = ApiBuilder::new()
|
||||||
@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||||
let tokenizer_config = if local_tokenizer_config {
|
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||||
|
tracing::info!("Using local tokenizer config from user specified path");
|
||||||
|
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
|
||||||
|
} else if local_model {
|
||||||
tracing::info!("Using local tokenizer config");
|
tracing::info!("Using local tokenizer config");
|
||||||
HubTokenizerConfig::from_file(&local_tokenizer_config_path)
|
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
|
||||||
} else if let Some(api) = api {
|
|
||||||
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
|
||||||
get_tokenizer_config(&api.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.unwrap_or_else(|| "main".to_string()),
|
|
||||||
)))
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub.");
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("Could not find tokenizer config locally and no revision specified");
|
match api {
|
||||||
HubTokenizerConfig::default()
|
Some(api) => {
|
||||||
|
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
|
||||||
|
let repo = Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.unwrap_or("main".to_string()),
|
||||||
|
);
|
||||||
|
get_tokenizer_config(&api.repo(repo))
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
tracing::warn!(
|
||||||
|
"Could not retrieve tokenizer config from the Hugging Face hub."
|
||||||
|
);
|
||||||
|
HubTokenizerConfig::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if tokenizer.is_none() {
|
if tokenizer.is_none() {
|
||||||
|
@ -936,6 +936,7 @@ pub async fn run(
|
|||||||
// Define base and health routes
|
// Define base and health routes
|
||||||
let base_routes = Router::new()
|
let base_routes = Router::new()
|
||||||
.route("/", post(compat_generate))
|
.route("/", post(compat_generate))
|
||||||
|
.route("/", get(health))
|
||||||
.route("/info", get(get_model_info))
|
.route("/info", get(get_model_info))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
|
1
server/.gitignore
vendored
1
server/.gitignore
vendored
@ -161,3 +161,4 @@ flash-attention-v2/
|
|||||||
vllm/
|
vllm/
|
||||||
llm-awq/
|
llm-awq/
|
||||||
eetq/
|
eetq/
|
||||||
|
mamba/
|
||||||
|
@ -3,6 +3,7 @@ include Makefile-flash-att-v2
|
|||||||
include Makefile-vllm
|
include Makefile-vllm
|
||||||
include Makefile-awq
|
include Makefile-awq
|
||||||
include Makefile-eetq
|
include Makefile-eetq
|
||||||
|
include Makefile-selective-scan
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
pytest -s -vv -m "not private" tests
|
pytest -s -vv -m "not private" tests
|
||||||
|
28
server/Makefile-selective-scan
Normal file
28
server/Makefile-selective-scan
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
|
||||||
|
|
||||||
|
causal-conv1d:
|
||||||
|
rm -rf causal-conv1d
|
||||||
|
git clone https://github.com/Dao-AILab/causal-conv1d.git
|
||||||
|
|
||||||
|
build-causal-conv1d: causal-conv1d
|
||||||
|
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
|
||||||
|
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
|
||||||
|
|
||||||
|
install-causal-conv1d: build-causal-conv1d
|
||||||
|
pip uninstall causal-conv1d -y || true
|
||||||
|
cd causal-conv1d/ && pip install .
|
||||||
|
|
||||||
|
# selective-scan dependends on causal-conv1d
|
||||||
|
selective-scan:
|
||||||
|
rm -rf mamba
|
||||||
|
git clone https://github.com/state-spaces/mamba.git mamba
|
||||||
|
|
||||||
|
build-selective-scan: selective-scan
|
||||||
|
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
|
||||||
|
cd mamba && python setup.py build
|
||||||
|
|
||||||
|
install-selective-scan: install-causal-conv1d build-selective-scan
|
||||||
|
pip uninstall selective-scan-cuda -y || true
|
||||||
|
cd mamba && pip install .
|
||||||
|
|
||||||
|
build-all: build-causal-conv1d build-selective-scan
|
17
server/poetry.lock
generated
17
server/poetry.lock
generated
@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "accelerate"
|
name = "accelerate"
|
||||||
@ -1589,30 +1589,32 @@ xml = ["lxml (>=4.9.2)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "peft"
|
name = "peft"
|
||||||
version = "0.4.0"
|
version = "0.8.2"
|
||||||
description = "Parameter-Efficient Fine-Tuning (PEFT)"
|
description = "Parameter-Efficient Fine-Tuning (PEFT)"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.8.0"
|
python-versions = ">=3.8.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "peft-0.4.0-py3-none-any.whl", hash = "sha256:2cf992772a6d703814477e0bdcdadd68cb8ea388111ce2d793dd2ff0e438f357"},
|
{file = "peft-0.8.2-py3-none-any.whl", hash = "sha256:4a9c81c38e689fd4043b2757cd0e2b526a9b8b8fd04f8442df2c4824b32c2505"},
|
||||||
{file = "peft-0.4.0.tar.gz", hash = "sha256:e768fa22d6e9f32aa7e891f0d06f355960278ca4dc0cdd96bff71f6f06269207"},
|
{file = "peft-0.8.2.tar.gz", hash = "sha256:bbdf61db2d8ca503e894edc64016038e6f34b7b522374bad09a22af41882e7ac"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
accelerate = "*"
|
accelerate = ">=0.21.0"
|
||||||
|
huggingface-hub = ">=0.17.0"
|
||||||
numpy = ">=1.17"
|
numpy = ">=1.17"
|
||||||
packaging = ">=20.0"
|
packaging = ">=20.0"
|
||||||
psutil = "*"
|
psutil = "*"
|
||||||
pyyaml = "*"
|
pyyaml = "*"
|
||||||
safetensors = "*"
|
safetensors = "*"
|
||||||
torch = ">=1.13.0"
|
torch = ">=1.13.0"
|
||||||
|
tqdm = "*"
|
||||||
transformers = "*"
|
transformers = "*"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["black (>=22.0,<23.0)", "hf-doc-builder", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
|
dev = ["black (>=22.0,<23.0)", "hf-doc-builder", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
|
||||||
docs-specific = ["hf-doc-builder"]
|
docs-specific = ["hf-doc-builder"]
|
||||||
quality = ["black (>=22.0,<23.0)", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
|
quality = ["black (>=22.0,<23.0)", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
|
||||||
test = ["black (>=22.0,<23.0)", "datasets", "diffusers", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"]
|
test = ["black (>=22.0,<23.0)", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "scipy", "urllib3 (<=2.0.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pillow"
|
name = "pillow"
|
||||||
@ -1893,6 +1895,7 @@ files = [
|
|||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
@ -2962,4 +2965,4 @@ torch = ["torch"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<3.13"
|
python-versions = ">=3.9,<3.13"
|
||||||
content-hash = "33d533d21d14c258678a8c4bb28e2a15e8ebe5ca35d8589cbfe4a7b7d2e79a90"
|
content-hash = "f7529125bdd7ce142082ce4969edbda5d9b67b6209f199194c54198829f5dc64"
|
||||||
|
@ -30,7 +30,7 @@ transformers = "^4.37.1"
|
|||||||
einops = "^0.6.1"
|
einops = "^0.6.1"
|
||||||
texttable = { version = "^1.6.7", optional = true }
|
texttable = { version = "^1.6.7", optional = true }
|
||||||
datasets = { version = "^2.14.0", optional = true }
|
datasets = { version = "^2.14.0", optional = true }
|
||||||
peft = { version = "^0.4.0", optional = true }
|
peft = { version = "^0.8.2", optional = true }
|
||||||
torch = { version = "^2.1.1", optional = true }
|
torch = { version = "^2.1.1", optional = true }
|
||||||
scipy = "^1.11.1"
|
scipy = "^1.11.1"
|
||||||
pillow = "^10.0.0"
|
pillow = "^10.0.0"
|
||||||
|
@ -76,6 +76,15 @@ if FLASH_ATTENTION:
|
|||||||
__all__.append(FlashMixtral)
|
__all__.append(FlashMixtral)
|
||||||
__all__.append(FlashPhi)
|
__all__.append(FlashPhi)
|
||||||
|
|
||||||
|
MAMBA_AVAILABLE = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.mamba import Mamba
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import Mamba: {e}")
|
||||||
|
MAMBA_AVAILABLE = False
|
||||||
|
|
||||||
|
if MAMBA_AVAILABLE:
|
||||||
|
__all__.append(Mamba)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -164,7 +173,25 @@ def get_model(
|
|||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||||
|
|
||||||
model_type = config_dict["model_type"]
|
model_type = config_dict.get("model_type", None)
|
||||||
|
if model_type is None:
|
||||||
|
# TODO: fix how we determine model type for Mamba
|
||||||
|
if "ssm_cfg" in config_dict:
|
||||||
|
# *only happens in Mamba case
|
||||||
|
model_type = "ssm"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not determine model type for {model_id} revision {revision}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type == "ssm":
|
||||||
|
return Mamba(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == "gpt_bigcode":
|
if model_type == "gpt_bigcode":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
@ -69,9 +69,17 @@ def _load_multi_mqa_gptq(
|
|||||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||||
qzeros = qzeros.to(device=weights.device)
|
qzeros = qzeros.to(device=weights.device)
|
||||||
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
bits, groupsize, _, quant_method, = weights._get_gptq_params()
|
||||||
g_idx = g_idx.to(device=weights.device)
|
if quant_method == "gptq":
|
||||||
bits, groupsize, _ = weights._get_gptq_params()
|
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||||
|
g_idx = g_idx.to(device=weights.device)
|
||||||
|
elif quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
from text_generation_server.utils.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
|
|
||||||
|
@ -0,0 +1,194 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||||
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, Tuple, Any
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
FastRMSNorm,
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||||
|
import math
|
||||||
|
|
||||||
|
class MambaConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50280,
|
||||||
|
d_model=768,
|
||||||
|
d_state=16,
|
||||||
|
n_layer=32,
|
||||||
|
layer_norm_epsilon=1e-5,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
expand=2,
|
||||||
|
dt_rank="auto",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.d_model = d_model
|
||||||
|
self.d_inner = d_model * 2
|
||||||
|
self.d_conv = 4
|
||||||
|
self.d_state = d_state
|
||||||
|
self.expand = expand
|
||||||
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MambaBlock(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = int(prefix.split(".")[2])
|
||||||
|
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
|
||||||
|
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
|
||||||
|
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
|
||||||
|
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
|
||||||
|
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
|
||||||
|
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
|
||||||
|
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
|
||||||
|
self.D = weights.get_tensor(f"{prefix}.D")
|
||||||
|
self.activation = "silu"
|
||||||
|
self.dt_rank = config.dt_rank
|
||||||
|
self.d_state = config.d_state
|
||||||
|
self.d_conv = config.d_conv
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
# inference_params
|
||||||
|
def forward(self, hidden_states: torch.Tensor, inference_params=None):
|
||||||
|
_, seqlen, _ = hidden_states.shape
|
||||||
|
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||||
|
|
||||||
|
if inference_params.seqlen_offset > 0:
|
||||||
|
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
|
||||||
|
return out, conv_state, ssm_state
|
||||||
|
|
||||||
|
projected_states = self.in_proj(hidden_states).transpose(1,2)
|
||||||
|
x, z = projected_states.chunk(2, dim=1)
|
||||||
|
conv_state = F.pad(x, (self.d_conv - seqlen, 0))
|
||||||
|
x = causal_conv1d_fn(
|
||||||
|
x=x,
|
||||||
|
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We're careful here about the layout, to avoid extra transposes.
|
||||||
|
# We want dt to have d as the slowest moving dimension
|
||||||
|
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||||
|
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||||
|
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = self.dt_proj.weight @ dt.t()
|
||||||
|
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||||
|
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
y, last_state = selective_scan_fn(
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
self.negA,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
self.D.float(),
|
||||||
|
z=z,
|
||||||
|
delta_bias=self.dt_proj.bias.float(),
|
||||||
|
delta_softplus=True,
|
||||||
|
return_last_state=True,
|
||||||
|
)
|
||||||
|
y = rearrange(y, "b d l -> b l d")
|
||||||
|
attn_outputs = self.out_proj(y)
|
||||||
|
return attn_outputs, conv_state, last_state
|
||||||
|
|
||||||
|
def step(self, hidden_states, conv_state, ssm_state):
|
||||||
|
_xz = self.in_proj(hidden_states)
|
||||||
|
_x, _z = _xz.chunk(2, dim=-1) # (B D)
|
||||||
|
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
|
||||||
|
conv_out = causal_conv1d_fn(
|
||||||
|
x=conv_state_new,
|
||||||
|
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation
|
||||||
|
)
|
||||||
|
conv_state = conv_state_new[:, :, 1:]
|
||||||
|
bsz, seqlen, dim = hidden_states.shape
|
||||||
|
output_tensor = torch.zeros(
|
||||||
|
(bsz, seqlen, dim),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
for i in range(0, bsz):
|
||||||
|
x = conv_out[i:i+1,:,-1]
|
||||||
|
z = _z[i:i+1, -1, :]
|
||||||
|
x_db = self.x_proj(x)
|
||||||
|
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = F.linear(dt, self.dt_proj.weight)
|
||||||
|
y = selective_state_update(
|
||||||
|
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
||||||
|
)
|
||||||
|
out = self.out_proj(y)
|
||||||
|
output_tensor[i] = out
|
||||||
|
|
||||||
|
return output_tensor, conv_state, ssm_state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||||
|
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
inference_params: Optional[Any] = None,
|
||||||
|
):
|
||||||
|
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||||
|
shape = residual.shape
|
||||||
|
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
|
||||||
|
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
|
||||||
|
return hidden_states, residual, conv_state, last_ssm_state
|
||||||
|
|
||||||
|
class MambaModel(nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = "backbone"
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
||||||
|
)
|
||||||
|
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
||||||
|
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for block in self.blocks:
|
||||||
|
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
|
||||||
|
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + residual if residual is not None else hidden_states
|
||||||
|
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
|
||||||
|
hidden_states = hidden_states.view(residual.shape)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# update the offset for the next inference using these params
|
||||||
|
inference_params.seqlen_offset += input_ids.size(1)
|
||||||
|
return logits, input_ids, inference_params
|
656
server/text_generation_server/models/mamba.py
Normal file
656
server/text_generation_server/models/mamba.py
Normal file
@ -0,0 +1,656 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
from typing import Optional
|
||||||
|
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||||
|
MambaConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
|
||||||
|
from text_generation_server.models import Model
|
||||||
|
from typing import Any, List, Optional, Tuple, Type, Dict
|
||||||
|
from text_generation_server.models.types import (
|
||||||
|
Batch,
|
||||||
|
Tokens,
|
||||||
|
Generation,
|
||||||
|
GeneratedText,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaBatch(Batch):
|
||||||
|
batch_id: int
|
||||||
|
requests: List[generate_pb2.Request]
|
||||||
|
requests_idx_mapping: Dict[int, int]
|
||||||
|
|
||||||
|
# Decoder values
|
||||||
|
input_ids: torch.Tensor
|
||||||
|
|
||||||
|
# All tokens
|
||||||
|
all_input_ids: List[torch.Tensor]
|
||||||
|
|
||||||
|
# Lengths of all generations present in the batch
|
||||||
|
input_lengths: List[int]
|
||||||
|
prefix_offsets: List[int]
|
||||||
|
read_offsets: List[int]
|
||||||
|
|
||||||
|
# Generation helpers
|
||||||
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
stopping_criterias: List[StoppingCriteria]
|
||||||
|
top_n_tokens: List[int]
|
||||||
|
top_n_tokens_tensor: torch.Tensor
|
||||||
|
|
||||||
|
# Metadata used for padding
|
||||||
|
max_input_length: int
|
||||||
|
padding_right_offset: int
|
||||||
|
|
||||||
|
# Maximum number of tokens this batch will grow to
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
# Past metadata
|
||||||
|
keys_head_dim_last: bool = True
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
inference_params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def to_pb(self) -> generate_pb2.CachedBatch:
|
||||||
|
return generate_pb2.CachedBatch(
|
||||||
|
id=self.batch_id,
|
||||||
|
request_ids=[r.id for r in self.requests],
|
||||||
|
size=len(self),
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pb(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "MambaBatch":
|
||||||
|
inputs = []
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
# Parse batch
|
||||||
|
max_truncation = 0
|
||||||
|
padding_right_offset = 0
|
||||||
|
max_decode_tokens = 0
|
||||||
|
for i, r in enumerate(pb.requests):
|
||||||
|
requests_idx_mapping[r.id] = i
|
||||||
|
inputs.append(r.inputs)
|
||||||
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
|
r.stopping_parameters, tokenizer
|
||||||
|
)
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||||
|
padding_right_offset = max(
|
||||||
|
padding_right_offset, stopping_criteria.max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenized_inputs = tokenizer(
|
||||||
|
inputs,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_truncation,
|
||||||
|
).to(device)
|
||||||
|
for _ in pb.requests:
|
||||||
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
prefix_offsets.append(input_len - 5)
|
||||||
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
|
max_input_length = input_lengths.max()
|
||||||
|
input_ids = tokenized_inputs["input_ids"]
|
||||||
|
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||||
|
top_n_tokens_tensor = torch.tensor(
|
||||||
|
top_n_tokens, device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||||
|
return cls(
|
||||||
|
batch_id=pb.id,
|
||||||
|
requests=pb.requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
# past_input_ids=None,
|
||||||
|
all_input_ids=list(all_input_ids),
|
||||||
|
input_lengths=input_lengths.tolist(),
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
next_token_choosers=next_token_choosers,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
max_input_length=max_input_length.item(),
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
|
||||||
|
if len(request_ids) == 0:
|
||||||
|
raise ValueError("Batch must have at least one request")
|
||||||
|
if len(request_ids) == len(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
keep_indices = []
|
||||||
|
|
||||||
|
# New values after filtering
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
requests = []
|
||||||
|
input_lengths = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
all_input_ids = []
|
||||||
|
max_input_length = 0
|
||||||
|
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
|
||||||
|
total_remaining_decode_tokens = 0
|
||||||
|
new_padding_right_offset = 0
|
||||||
|
|
||||||
|
indices = []
|
||||||
|
for i, request_id in enumerate(request_ids):
|
||||||
|
idx = self.requests_idx_mapping[request_id]
|
||||||
|
requests_idx_mapping[request_id] = i
|
||||||
|
keep_indices.append(idx)
|
||||||
|
|
||||||
|
requests.append(self.requests[idx])
|
||||||
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
|
read_offsets.append(self.read_offsets[idx])
|
||||||
|
all_input_ids.append(self.all_input_ids[idx])
|
||||||
|
|
||||||
|
request_input_length = self.input_lengths[idx]
|
||||||
|
input_lengths.append(request_input_length)
|
||||||
|
max_input_length = max(max_input_length, request_input_length)
|
||||||
|
indices.append(idx)
|
||||||
|
|
||||||
|
next_token_choosers.append(self.next_token_choosers[idx])
|
||||||
|
stopping_criteria = self.stopping_criterias[idx]
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(self.top_n_tokens[idx])
|
||||||
|
remaining_decode_tokens = (
|
||||||
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
|
)
|
||||||
|
total_remaining_decode_tokens += remaining_decode_tokens
|
||||||
|
new_padding_right_offset = max(
|
||||||
|
new_padding_right_offset, remaining_decode_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||||
|
input_ids = self.input_ids[keep_indices]
|
||||||
|
|
||||||
|
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||||
|
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||||
|
|
||||||
|
self.requests = requests
|
||||||
|
self.requests_idx_mapping = requests_idx_mapping
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self.all_input_ids = all_input_ids
|
||||||
|
self.input_lengths = input_lengths
|
||||||
|
self.prefix_offsets = prefix_offsets
|
||||||
|
self.read_offsets = read_offsets
|
||||||
|
self.next_token_choosers = next_token_choosers
|
||||||
|
self.stopping_criterias = stopping_criterias
|
||||||
|
self.top_n_tokens = top_n_tokens
|
||||||
|
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.padding_right_offset = new_padding_right_offset
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||||
|
key_value_memory_dict = {}
|
||||||
|
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
|
||||||
|
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
||||||
|
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
|
||||||
|
# Used for padding
|
||||||
|
total_batch_size = 0
|
||||||
|
max_input_length = 0
|
||||||
|
padding_right_offset = 0
|
||||||
|
for batch in batches:
|
||||||
|
total_batch_size += len(batch)
|
||||||
|
max_input_length = max(max_input_length, batch.max_input_length)
|
||||||
|
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||||
|
|
||||||
|
# Batch attributes
|
||||||
|
requests = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
input_lengths = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
all_input_ids = []
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
max_tokens = 0
|
||||||
|
max_seqlen = 0
|
||||||
|
batch_size = 0
|
||||||
|
seqlen_offset = 0
|
||||||
|
|
||||||
|
# Batch tensors
|
||||||
|
input_ids = None
|
||||||
|
top_n_tokens_tensor = None
|
||||||
|
|
||||||
|
# Used for slicing correctly inside the tensors
|
||||||
|
# Equivalent to a cumsum on batch sizes
|
||||||
|
start_index = 0
|
||||||
|
for i, batch in enumerate(batches):
|
||||||
|
requests.extend(batch.requests)
|
||||||
|
input_lengths.extend(batch.input_lengths)
|
||||||
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
|
read_offsets.extend(batch.read_offsets)
|
||||||
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
top_n_tokens.extend(batch.top_n_tokens)
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
requests_idx_mapping = batch.requests_idx_mapping
|
||||||
|
else:
|
||||||
|
# We need to offset the mapping for each batch by the cumulative batch size
|
||||||
|
for k, v in batch.requests_idx_mapping.items():
|
||||||
|
requests_idx_mapping[k] = v + start_index
|
||||||
|
|
||||||
|
# Slicing end index for this batch
|
||||||
|
end_index = start_index + len(batch)
|
||||||
|
|
||||||
|
# Create empty tensor
|
||||||
|
# input_ids is always of shape [batch_size, 1]
|
||||||
|
# We do not need to pad it
|
||||||
|
if input_ids is None:
|
||||||
|
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
|
||||||
|
# Copy to correct indices
|
||||||
|
input_ids[start_index:end_index] = batch.input_ids
|
||||||
|
|
||||||
|
if top_n_tokens_tensor is None:
|
||||||
|
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||||
|
total_batch_size,
|
||||||
|
)
|
||||||
|
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||||
|
|
||||||
|
# Add eventual padding tokens that were added while concatenating
|
||||||
|
max_tokens += batch.max_tokens + (
|
||||||
|
max_input_length - batch.max_input_length
|
||||||
|
) * len(batch)
|
||||||
|
|
||||||
|
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
|
||||||
|
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
|
||||||
|
batch_size += batch.inference_params.max_batch_size
|
||||||
|
|
||||||
|
start_index = end_index
|
||||||
|
|
||||||
|
|
||||||
|
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||||
|
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
||||||
|
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
||||||
|
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
||||||
|
device = batches[0].inference_params.key_value_memory_dict[0][0].device
|
||||||
|
|
||||||
|
key_value_memory_dict = {}
|
||||||
|
for i in range(n_blocks):
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
d_model,
|
||||||
|
d_conv,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
ssm_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
d_model,
|
||||||
|
d_state,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
key_value_memory_dict[i] = (conv_state, ssm_state)
|
||||||
|
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
inference_params = InferenceParams(
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
seqlen_offset=seqlen_offset,
|
||||||
|
key_value_memory_dict=key_value_memory_dict,
|
||||||
|
lengths_per_sample=lengths_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_batch = 0
|
||||||
|
for batch in batches:
|
||||||
|
for i in range(n_blocks):
|
||||||
|
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
||||||
|
batch_size = batch.inference_params.max_batch_size
|
||||||
|
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
|
||||||
|
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
|
||||||
|
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
|
||||||
|
current_batch += batch_size
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
batch_id=batches[0].batch_id,
|
||||||
|
requests=requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
all_input_ids=all_input_ids,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
next_token_choosers=next_token_choosers,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
max_input_length=max_input_length,
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
inference_params=inference_params
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.requests)
|
||||||
|
|
||||||
|
class Mamba(Model):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
self.process_group, _rank, _world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"EleutherAI/gpt-neox-20b",
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
config = MambaConfig.from_pretrained(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.bos_token_id = config.bos_token_id
|
||||||
|
tokenizer.eos_token_id = config.eos_token_id
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
config.quantize = quantize
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
model = MambaModel(config, weights)
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(Mamba, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_type(self) -> Type[MambaBatch]:
|
||||||
|
return MambaBatch
|
||||||
|
|
||||||
|
def warmup(self, batch) -> Optional[int]:
|
||||||
|
# TODO: implement warmup for Mamba if needed
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
past: Optional[List[torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self.model(
|
||||||
|
input_ids,
|
||||||
|
past=past,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||||
|
start = time.time_ns()
|
||||||
|
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
max_seqlen = input_ids.shape[1]
|
||||||
|
dtype = input_ids.dtype
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
seqlen_og = 0
|
||||||
|
inf_cache = {}
|
||||||
|
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
|
||||||
|
|
||||||
|
if batch.inference_params is None:
|
||||||
|
inference_params = InferenceParams(
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
seqlen_offset=seqlen_og,
|
||||||
|
key_value_memory_dict=inf_cache,
|
||||||
|
lengths_per_sample=lengths_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allocate inference cache
|
||||||
|
for res_block in self.model.blocks:
|
||||||
|
block = res_block.mamba_block
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_conv,
|
||||||
|
device=block.conv1d.weight.device,
|
||||||
|
dtype=block.conv1d.weight.dtype,
|
||||||
|
)
|
||||||
|
ssm_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_state,
|
||||||
|
device=block.dt_proj.weight.device,
|
||||||
|
dtype=block.dt_proj.weight.dtype,
|
||||||
|
)
|
||||||
|
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
|
||||||
|
batch.inference_params = inference_params
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
|
||||||
|
|
||||||
|
batch.inference_params = new_inference_params
|
||||||
|
# Results
|
||||||
|
generations: List[Generation] = []
|
||||||
|
stopped = True
|
||||||
|
|
||||||
|
# Speculation is not active for causal
|
||||||
|
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||||
|
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||||
|
batch.top_n_tokens,
|
||||||
|
batch.top_n_tokens_tensor,
|
||||||
|
torch.log_softmax(logits[:, -1], -1),
|
||||||
|
accepted_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_decode = time.time_ns()
|
||||||
|
|
||||||
|
# Zipped iterator
|
||||||
|
iterator = zip(
|
||||||
|
batch.requests,
|
||||||
|
batch.input_lengths,
|
||||||
|
batch.prefix_offsets,
|
||||||
|
batch.read_offsets,
|
||||||
|
logits,
|
||||||
|
batch.next_token_choosers,
|
||||||
|
batch.stopping_criterias,
|
||||||
|
batch.all_input_ids,
|
||||||
|
batch.top_n_tokens,
|
||||||
|
batch_top_token_ids,
|
||||||
|
batch_top_token_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each member of the batch
|
||||||
|
for i, (
|
||||||
|
request,
|
||||||
|
input_length,
|
||||||
|
prefix_offset,
|
||||||
|
read_offset,
|
||||||
|
logits,
|
||||||
|
next_token_chooser,
|
||||||
|
stopping_criteria,
|
||||||
|
all_input_ids,
|
||||||
|
top_n_tokens,
|
||||||
|
top_token_ids,
|
||||||
|
top_token_logprobs,
|
||||||
|
) in enumerate(iterator):
|
||||||
|
# Select next token
|
||||||
|
next_token_id, logprobs = next_token_chooser(
|
||||||
|
all_input_ids.view(1, -1), logits[-1:, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append next token to all tokens
|
||||||
|
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||||
|
new_input_length = input_length + 1
|
||||||
|
|
||||||
|
# Generated token
|
||||||
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
|
all_input_ids[:, 0], prefix_offset, read_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate stopping criteria
|
||||||
|
stop, reason = stopping_criteria(
|
||||||
|
next_token_id_squeezed,
|
||||||
|
next_token_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stop:
|
||||||
|
stopped = False
|
||||||
|
|
||||||
|
# Shard generations
|
||||||
|
# All generations will be appended in the rust sharded client
|
||||||
|
if i % self.world_size == self.rank:
|
||||||
|
if stop:
|
||||||
|
# Decode generated tokens
|
||||||
|
output_text, _, _ = self.decode_token(
|
||||||
|
all_input_ids[:, 0],
|
||||||
|
prefix_offset=len(all_input_ids)
|
||||||
|
- stopping_criteria.current_tokens
|
||||||
|
- 1,
|
||||||
|
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
# Get seed
|
||||||
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
|
seed = next_token_chooser.choice.seed
|
||||||
|
else:
|
||||||
|
seed = None
|
||||||
|
|
||||||
|
generated_text = GeneratedText(
|
||||||
|
output_text, stopping_criteria.current_tokens, reason, seed
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
generated_text = None
|
||||||
|
|
||||||
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
|
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||||
|
logits, -1
|
||||||
|
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||||
|
-new_input_length:-1
|
||||||
|
].tolist()
|
||||||
|
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||||
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
|
prefill_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
prefill_tokens = Tokens(
|
||||||
|
prefill_token_ids,
|
||||||
|
prefill_logprobs,
|
||||||
|
prefill_texts,
|
||||||
|
is_special=[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_tokens = None
|
||||||
|
|
||||||
|
if top_n_tokens > 0:
|
||||||
|
toptoken_texts = self.tokenizer.batch_decode(
|
||||||
|
top_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
special_toptokens = [
|
||||||
|
token_id in self.all_special_ids for token_id in top_token_ids
|
||||||
|
]
|
||||||
|
top_tokens = Tokens(
|
||||||
|
top_token_ids,
|
||||||
|
top_token_logprobs,
|
||||||
|
toptoken_texts,
|
||||||
|
special_toptokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
top_tokens = None
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
request.id,
|
||||||
|
prefill_tokens,
|
||||||
|
Tokens(
|
||||||
|
[next_token_id_squeezed],
|
||||||
|
[next_token_logprob],
|
||||||
|
[next_token_text],
|
||||||
|
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||||
|
),
|
||||||
|
generated_text,
|
||||||
|
top_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations.append(generation)
|
||||||
|
|
||||||
|
# Update values
|
||||||
|
batch.input_ids[i, 0] = next_token_id
|
||||||
|
batch.all_input_ids[i] = all_input_ids
|
||||||
|
batch.input_lengths[i] = new_input_length
|
||||||
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
|
batch.read_offsets[i] = read_offset
|
||||||
|
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||||
|
|
||||||
|
# We finished all generations in the batch; there is no next batch
|
||||||
|
if stopped:
|
||||||
|
forward_ns = start_decode - start
|
||||||
|
decode_ns = time.time_ns() - start_decode
|
||||||
|
return generations, None, (forward_ns, decode_ns)
|
||||||
|
|
||||||
|
# Slice unused values from prefill
|
||||||
|
batch.input_ids = batch.input_ids[:, :1]
|
||||||
|
|
||||||
|
forward_ns = start_decode - start
|
||||||
|
decode_ns = time.time_ns() - start_decode
|
||||||
|
return generations, batch, (forward_ns, decode_ns)
|
Loading…
Reference in New Issue
Block a user