diff --git a/README-HF.md b/README-HF.md
new file mode 100644
index 00000000..effab42e
--- /dev/null
+++ b/README-HF.md
@@ -0,0 +1,275 @@
+# Text Generation Inference
+
+A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
+to power the LLM API inference widgets.
+
+## Table of contents
+
+- [Features](#features)
+- [Optimized Architectures](#optimized-architectures)
+- [Get Started](#get-started)
+ - [Docker](#docker)
+ - [API Documentation](#api-documentation)
+ - [Using a private or gated model](#using-a-private-or-gated-model)
+ - [A note on Shared Memory](#a-note-on-shared-memory-shm)
+ - [Distributed Tracing](#distributed-tracing)
+ - [Local Install](#local-install)
+ - [CUDA Kernels](#cuda-kernels)
+- [Run Falcon](#run-falcon)
+ - [Run](#run)
+ - [Quantization](#quantization)
+- [Develop](#develop)
+- [Testing](#testing)
+
+## Features
+
+- Serve the most popular Large Language Models with a simple launcher
+- Tensor Parallelism for faster inference on multiple GPUs
+- Token streaming using Server-Sent Events (SSE)
+- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
+- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
+- [Safetensors](https://github.com/huggingface/safetensors) weight loading
+- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+- Logits warper (temperature scaling, top-p, top-k, repetition penalty; see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor) for more details)
+- Stop sequences
+- Log probabilities
+- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
+
+## Optimized architectures
+
+- [BLOOM](https://huggingface.co/bigscience/bloom)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [Llama](https://github.com/facebookresearch/llama)
+- [OPT](https://huggingface.co/facebook/opt-66b)
+- [SantaCoder](https://huggingface.co/bigcode/santacoder)
+- [Starcoder](https://huggingface.co/bigcode/starcoder)
+- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
+- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
+- [MPT](https://huggingface.co/mosaicml/mpt-30b)
+- [Llama V2](https://huggingface.co/meta-llama)
+
+Other architectures are supported on a best-effort basis using:
+
+`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
+
+or
+
+`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
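+
+For example, a minimal sketch of that fallback path in plain `transformers` (the checkpoint name is only illustrative):
+
+```python
+# Sketch: load an arbitrary architecture with automatic device placement,
+# the same entry points the server falls back to for unoptimized models.
+# Requires `accelerate` to be installed for device_map="auto".
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "bigscience/bloom-560m"  # any causal LM checkpoint works here
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+inputs = tokenizer("What is Deep Learning?", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```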
+
+## Get started
+
+### Docker
+
+The easiest way of getting started is using the official Docker container:
+
+```shell
+model=tiiuae/falcon-7b-instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9.4 --model-id $model
+```
+**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
+
+To see all the options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the CLI):
+```shell
+text-generation-launcher --help
+```
+
+You can then query the model using either the `/generate` or `/generate_stream` routes:
+
+```shell
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+```shell
+curl 127.0.0.1:8080/generate_stream \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
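+
+The `/generate_stream` route answers with Server-Sent Events, one JSON payload per token. A minimal sketch of consuming the raw stream with `requests` (assuming the server above is running):
+
+```python
+# Sketch: read the SSE stream token by token; each event is a "data:{json}"
+# line whose payload carries the generated token.
+import json
+import requests
+
+payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}
+with requests.post("http://127.0.0.1:8080/generate_stream", json=payload, stream=True) as resp:
+    for line in resp.iter_lines():
+        if line.startswith(b"data:"):
+            event = json.loads(line[len(b"data:"):])
+            print(event["token"]["text"], end="", flush=True)
+print()
+```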
+
+Rather than parsing the stream by hand, you can use the official Python client:
+
+```shell
+pip install text-generation
+```
+
+```python
+from text_generation import Client
+
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
+
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
+ if not response.token.special:
+ text += response.token.text
+print(text)
+```
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Using a private or gated model
+
+You can use the `HUGGING_FACE_HUB_TOKEN` environment variable to configure the token used by
+`text-generation-inference`, giving it access to protected resources.
+
+For example, if you want to serve the gated Llama V2 model variants:
+
+1. Go to https://huggingface.co/settings/tokens
+2. Copy your CLI READ token
+3. Export `HUGGING_FACE_HUB_TOKEN=<your CLI READ token>`
+
+or with Docker:
+
+```shell
+model=meta-llama/Llama-2-7b-chat-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+token=<your CLI READ token>
+
+docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9.4 --model-id $model
+```
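+
+As a quick sanity check before launching, you can verify that the token actually grants access to the gated repository (a sketch using `huggingface_hub`, an extra dependency here):
+
+```python
+# Sketch: this call raises if the token cannot access the gated repo.
+import os
+from huggingface_hub import model_info
+
+info = model_info("meta-llama/Llama-2-7b-chat-hf", token=os.environ["HUGGING_FACE_HUB_TOKEN"])
+print(info.sha)
+```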
+
+### A note on Shared Memory (shm)
+
+[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
+`PyTorch` to do distributed training/inference. `text-generation-inference` makes
+use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
+
+In order to share data between the different devices of an `NCCL` group, `NCCL` might fall back to using the host memory if
+peer-to-peer communication using NVLink or PCI is not possible.
+
+To allow the container to use 1G of Shared Memory and support SHM sharing, we added `--shm-size 1g` in the command above.
+
+If you are running `text-generation-inference` inside `Kubernetes`, you can also add Shared Memory to the container by
+creating a volume with:
+
+```yaml
+- name: shm
+ emptyDir:
+ medium: Memory
+ sizeLimit: 1Gi
+```
+
+and mounting it to `/dev/shm`.
+
+Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
+this will impact performance.
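+
+To check whether the peer-to-peer path is available at all, here is a small sketch using PyTorch (already a dependency of the server):
+
+```python
+# Sketch: check GPU peer-to-peer access, which determines whether NCCL
+# needs the host-memory (SHM) fallback discussed above.
+import torch
+
+if torch.cuda.device_count() >= 2:
+    print("GPU 0 <-> GPU 1 peer access:", torch.cuda.can_device_access_peer(0, 1))
+else:
+    print("Fewer than two GPUs visible; tensor parallelism does not apply.")
+```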
+
+### Distributed Tracing
+
+`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
+by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
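+
+The server also exports Prometheus metrics. A quick sketch of scraping them (assuming the default `/metrics` route on the serving port used above):
+
+```python
+# Sketch: pull the Prometheus metrics exposed by the router and print the
+# TGI-specific ones (exact metric names may vary across versions).
+import requests
+
+for line in requests.get("http://127.0.0.1:8080/metrics").text.splitlines():
+    if line.startswith("tgi_"):
+        print(line)
+```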
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally.
+
+First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
+Python 3.9, e.g. using `conda`:
+
+```shell
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+
+conda create -n text-generation-inference python=3.9
+conda activate text-generation-inference
+```
+
+You may also need to install Protoc.
+
+On Linux:
+
+```shell
+PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
+curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
+sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
+sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
+rm -f $PROTOC_ZIP
+```
+
+On macOS, using Homebrew:
+
+```shell
+brew install protobuf
+```
+
+Then run:
+
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformers fork with CUDA kernels
+make run-falcon-7b-instruct
+```
+
+**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
+
+```shell
+sudo apt-get install libssl-dev gcc -y
+```
+
+### CUDA Kernels
+
+The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can disable
+them with the `DISABLE_CUSTOM_KERNELS=True` environment variable.
+
+Be aware that the official Docker image has them enabled by default.
+
+## Run Falcon
+
+### Run
+
+```shell
+make run-falcon-7b-instruct
+```
+
+### Quantization
+
+You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
+
+```shell
+make run-falcon-7b-instruct-quantize
+```
+
+## Develop
+
+```shell
+make server-dev
+make router-dev
+```
+
+## Testing
+
+```shell
+# python
+make python-server-tests
+make python-client-tests
+# or both server and client tests
+make python-tests
+# rust cargo tests
+make rust-tests
+# integration tests
+make integration-tests
+```
diff --git a/README.md b/README.md
index effab42e..e2bae362 100644
--- a/README.md
+++ b/README.md
@@ -1,208 +1,33 @@
-
-
-
-
-# Text Generation Inference
-
-
-
-
-
-
-
-
-
-
-
-
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power LLMs api-inference widgets.
-
-## Table of contents
-
-- [Features](#features)
-- [Optimized Architectures](#optimized-architectures)
-- [Get Started](#get-started)
- - [Docker](#docker)
- - [API Documentation](#api-documentation)
- - [Using a private or gated model](#using-a-private-or-gated-model)
- - [A note on Shared Memory](#a-note-on-shared-memory-shm)
- - [Distributed Tracing](#distributed-tracing)
- - [Local Install](#local-install)
- - [CUDA Kernels](#cuda-kernels)
-- [Run Falcon](#run-falcon)
- - [Run](#run)
- - [Quantization](#quantization)
-- [Develop](#develop)
-- [Testing](#testing)
-
-## Features
-
-- Serve the most popular Large Language Models with a simple launcher
-- Tensor Parallelism for faster inference on multiple GPUs
-- Token streaming using Server-Sent Events (SSE)
-- [Continuous batching of incoming requests](https://github.com/huggingface/text-generation-inference/tree/main/router) for increased total throughput
-- Optimized transformers code for inference using [flash-attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
-- [Safetensors](https://github.com/huggingface/safetensors) weight loading
-- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
-- Stop sequences
-- Log probabilities
-- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-
-## Optimized architectures
-
-- [BLOOM](https://huggingface.co/bigscience/bloom)
-- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
-- [Galactica](https://huggingface.co/facebook/galactica-120b)
-- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [Llama](https://github.com/facebookresearch/llama)
-- [OPT](https://huggingface.co/facebook/opt-66b)
-- [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [Starcoder](https://huggingface.co/bigcode/starcoder)
-- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
-- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
-- [MPT](https://huggingface.co/mosaicml/mpt-30b)
-- [Llama V2](https://huggingface.co/meta-llama)
-
-Other architectures are supported on a best effort basis using:
-
-`AutoModelForCausalLM.from_pretrained(, device_map="auto")`
-
-or
-
-`AutoModelForSeq2SeqLM.from_pretrained(, device_map="auto")`
-
-## Get started
-
-### Docker
-
-The easiest way of getting started is using the official Docker container:
-
-```shell
-model=tiiuae/falcon-7b-instruct
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9.4 --model-id $model
-```
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
-
-To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli:
-```
-text-generation-launcher --help
-```
-
-You can then query the model using either the `/generate` or `/generate_stream` routes:
-
-```shell
-curl 127.0.0.1:8080/generate \
- -X POST \
- -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
- -H 'Content-Type: application/json'
-```
-
-```shell
-curl 127.0.0.1:8080/generate_stream \
- -X POST \
- -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
- -H 'Content-Type: application/json'
-```
-
-or from Python:
-
-```shell
-pip install text-generation
-```
-
-```python
-from text_generation import Client
-
-client = Client("http://127.0.0.1:8080")
-print(client.generate("What is Deep Learning?", max_new_tokens=20).generated_text)
-
-text = ""
-for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
- if not response.token.special:
- text += response.token.text
-print(text)
-```
-
-### API documentation
-
-You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
-The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
-
-### Using a private or gated model
-
-You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by
-`text-generation-inference`. This allows you to gain access to protected resources.
-
-For example, if you want to serve the gated Llama V2 model variants:
-
-1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HUGGING_FACE_HUB_TOKEN=`
-
-or with Docker:
-
-```shell
-model=meta-llama/Llama-2-7b-chat-hf
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-token=
-
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9.3 --model-id $model
-```
-
-### A note on Shared Memory (shm)
-
-[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
-use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
-
-In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
-peer-to-peer using NVLink or PCI is not possible.
-
-To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.
-
-If you are running `text-generation-inference` inside `Kubernetes`. You can also add Shared Memory to the container by
-creating a volume with:
-
-```yaml
-- name: shm
- emptyDir:
- medium: Memory
- sizeLimit: 1Gi
-```
-
-and mounting it to `/dev/shm`.
-
-Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
-this will impact performance.
-
-### Distributed Tracing
-
-`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
-by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
-
### Local install
You can also opt to install `text-generation-inference` locally.
-First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
-Python 3.9, e.g. using `conda`:
+First [install Rust](https://rustup.rs/):
-```shell
+```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-
-conda create -n text-generation-inference python=3.9
-conda activate text-generation-inference
```
-You may also need to install Protoc.
+Install conda (the steps below assume a Debian-based Linux):
-On Linux:
+```shell
+curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg
+sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg
+gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806
+echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | sudo tee -a /etc/apt/sources.list.d/conda.list
+sudo apt update && sudo apt install conda -y
+source /opt/conda/etc/profile.d/conda.sh
+conda -V
+```
+Create a conda environment:
+
+```shell
+conda create -n dscb python=3.9
+conda activate dscb
+```
+
+Install Protoc:
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
@@ -211,65 +36,29 @@ sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
-On MacOS, using Homebrew:
-
-```shell
-brew install protobuf
-```
-
-Then run:
-
-```shell
-BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
-make run-falcon-7b-instruct
-```
-
-**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
-
+You might also need to install these system packages:
```shell
sudo apt-get install libssl-dev gcc -y
+sudo apt-get install pkg-config
```
-### CUDA Kernels
-
-The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
-the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
-
-Be aware that the official Docker image has them enabled by default.
-
-## Run Falcon
-
-### Run
-
+Install DeepSparse:
```shell
-make run-falcon-7b-instruct
+pip install deepsparse-nightly[transformers]
```
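+
+As a quick sanity check, you can recreate the pipeline used in `interaction.ipynb` (a sketch; the model is fetched from SparseZoo on first use):
+
+```python
+import deepsparse
+
+# Same text-generation pipeline settings as in interaction.ipynb.
+pipeline = deepsparse.Pipeline.create(
+    task="text-generation",
+    model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none",
+    max_generated_tokens=64,
+    sequence_length=128,
+)
+print(pipeline(sequences="fib(n):").sequences[0])
+```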
-### Quantization
-
-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
-
+Install the server:
```shell
-make run-falcon-7b-instruct-quantize
+make install-server
```
-## Develop
-
+Launch the server:
+```shell
+python3 server/text_generation_server/cli.py download-weights bigscience/bloom-560m
+python3 server/text_generation_server/cli.py serve bigscience/bloom-560m
+```
+
+Launch the router:
```shell
-make server-dev
make router-dev
-```
-
-## Testing
-
-```shell
-# python
-make python-server-tests
-make python-client-tests
-# or both server and client tests
-make python-tests
-# rust cargo tests
-make rust-tests
-# integration tests
-make integration-tests
-```
+```
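+
+Once the server and router are both running, you can query the router over HTTP (a sketch; the port is whatever the router logs at startup, assumed to be 8080 here):
+
+```python
+import requests
+
+# Sketch: one-off request against the locally running router.
+payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}
+resp = requests.post("http://127.0.0.1:8080/generate", json=payload)
+print(resp.json()["generated_text"])
+```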
\ No newline at end of file
diff --git a/interaction.ipynb b/interaction.ipynb
new file mode 100644
index 00000000..ea45c44d
--- /dev/null
+++ b/interaction.ipynb
@@ -0,0 +1,2154 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import deepsparse"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"WAND_OPT_FLAGS\"] = \"default,~pyramids\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-19 21:47:03 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fa1ef98341ae44e7a4787a05fad4c5e8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)se/model.onnx.tar.gz: 0%| | 0.00/789M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "aaf3c0d675ec4deca68fa242c48702da",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)ployment/config.json: 0%| | 0.00/999 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "85b08e3752d8464494c0afd949e0917d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)yment/tokenizer.json: 0%| | 0.00/2.02M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2563ccabcfc8417bb19414c612bc82d2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/240 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-19 21:48:07 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+ "2023-08-19 21:48:08 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n",
+ "2023-08-19 21:48:35 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline = deepsparse.Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=False,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-19 21:50:10 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+ "2023-08-19 21:50:12 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "2023-08-19 21:50:37 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline2 = deepsparse.Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=True,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "print(fib(int(input())))\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "print(fib(int(input())))\n"
+ ]
+ }
+ ],
+ "source": [
+ "output = pipeline(sequences=\"fib(n):\")\n",
+ "print(output.sequences[0])\n",
+ "\n",
+ "print(\"\\n\\n\\n\\n\")\n",
+ "\n",
+ "output = pipeline2(sequences=\"fib(n):\")\n",
+ "print(output.sequences[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code\n"
+ ]
+ }
+ ],
+ "source": [
+ "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\"\n",
+ "print(pipeline(sequences=[sequence]).sequences[0])\n",
+ "print(\"\\n\\n\\n\")\n",
+ "print(pipeline2(sequences=[sequence]).sequences[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Deconstructing the Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "engine = pipeline.engine\n",
+ "multitoken_engine = pipeline.multitoken_engine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_tokens = pipeline.tokenizer(\n",
+ " sequence,\n",
+ " return_tensors=\"np\",\n",
+ " max_length=pipeline.sequence_length,\n",
+ " padding=\"max_length\",\n",
+ " truncation=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256 50256\n",
+ " 50256 50256 48658 262 1708 2163 329 14492 257 12900 261 44456\n",
+ " 8379 25 220 628 12900 7 77 2599]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "input_ids = input_tokens[\"input_ids\"]\n",
+ "print(input_ids)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "attention_mask = input_tokens[\"attention_mask\"]\n",
+ "print(attention_mask)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10\n",
+ " 11 12 13 14 15 16 17 18]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "positions = attention_mask.cumsum(1) * attention_mask\n",
+ "print(positions)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]\n",
+ " [0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from deepsparse.transformers.utils.helpers import create_causal_mask\n",
+ "causal_mask = create_causal_mask(input_ids, attention_mask)\n",
+ "print(causal_mask[:,:,-20:,-20:])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['input_ids', 'attention_mask', 'positions', 'causal_mask']\n"
+ ]
+ }
+ ],
+ "source": [
+ "onnx_input_names = multitoken_engine.onnx_input_names_no_cache\n",
+ "assert(name in input_tokens for name in onnx_input_names)\n",
+ "print(onnx_input_names)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_tokens = dict(\n",
+ " **input_tokens, positions=positions, causal_mask=causal_mask\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "engine_input = [input_tokens[name] for name in onnx_input_names]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline._reset_engines_cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy \n",
+ "def engine_inputs_for_prefill(tokens):\n",
+ " num_batches = len(tokens) // pipeline.prompt_processing_sequence_length\n",
+ "\n",
+ " token_batches = [tokens[i * pipeline.prompt_processing_sequence_length : (i + 1) * pipeline.prompt_processing_sequence_length] for i in range(0, num_batches)]\n",
+ " \n",
+ " for idx, token_batch in enumerate(token_batches):\n",
+ " engine_inputs = []\n",
+ " num_cached_entries = multitoken_engine.num_non_blank_cache_entries\n",
+ " # print(num_cached_entries)\n",
+ " \n",
+ " for name in multitoken_engine.onnx_input_names_no_cache:\n",
+ " if name == \"input_ids\":\n",
+ " assert len(engine_inputs) == 0\n",
+ " engine_input = numpy.array([token_batch])\n",
+ " \n",
+ " elif name == \"attention_mask\":\n",
+ " assert len(engine_inputs) == 1\n",
+ " # create an empty attention mask\n",
+ " engine_input = numpy.zeros(\n",
+ " (1, pipeline.sequence_length), dtype=numpy.int64\n",
+ " )\n",
+ " # fill it out with 1s (from the right), so that the number\n",
+ " # of unmasked entries is equal to the sum of:\n",
+ " engine_input[\n",
+ " :,\n",
+ " -(\n",
+ " # ...the number of current input tokens...\n",
+ " pipeline.prompt_processing_sequence_length\n",
+ " # ...and the number of the previous cache entries\n",
+ " + num_cached_entries\n",
+ " ) :,\n",
+ " ] = 1\n",
+ " \n",
+ " elif name == \"causal_mask\":\n",
+ " continue\n",
+ " \n",
+ " elif name == \"positions\":\n",
+ " if pipeline.prompt_processing_sequence_length == 1:\n",
+ " # we need to treat `positions` as if we were in\n",
+ " # the autoregressive mode\n",
+ " engine_input = numpy.array([[idx]], dtype=numpy.int64)\n",
+ " else:\n",
+ " engine_input = (\n",
+ " numpy.arange(\n",
+ " num_cached_entries,\n",
+ " num_cached_entries\n",
+ " + pipeline.prompt_processing_sequence_length,\n",
+ " )\n",
+ " .reshape(1, -1)\n",
+ " .astype(numpy.int64)\n",
+ " )\n",
+ "\n",
+ " # print(f\"{name}:\")\n",
+ " # print(engine_input)\n",
+ " # print(engine_input.shape)\n",
+ " \n",
+ " engine_inputs.append(engine_input)\n",
+ "\n",
+ " assert \"causal_mask\" in multitoken_engine.onnx_input_names_no_cache\n",
+ " causal_mask = create_causal_mask(input_ids=engine_inputs[0], attention_mask=engine_inputs[1])\n",
+ " engine_inputs.append(causal_mask)\n",
+ "\n",
+ " # print(\"causal_mask:\")\n",
+ " # print(causal_mask)\n",
+ " # print(causal_mask.shape)\n",
+ "\n",
+ " yield engine_inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "INDEX = 32\n",
+ "cache_onnx_names = [\n",
+ " name\n",
+ " for name in multitoken_engine.engine.input_names\n",
+ " if name.startswith(\"past_key_values\")\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def decode(tokens):\n",
+ " input_ids = numpy.array([[tokens[-1]]])\n",
+ " \n",
+ " attention_mask = numpy.zeros((1, engine.sequence_length), dtype=numpy.int64)\n",
+ " num_tokens_processed = min(len(tokens), engine.sequence_length) # cap by seq len\n",
+ " attention_mask[:, -num_tokens_processed:] = 1 \n",
+ "\n",
+ " causal_mask = create_causal_mask(input_ids, attention_mask)\n",
+ " positions = numpy.array([[len(tokens) - 1]], dtype=numpy.int64)\n",
+ " \n",
+ " engine_inputs_map = dict(\n",
+ " input_ids=input_ids,\n",
+ " attention_mask=attention_mask,\n",
+ " causal_mask=causal_mask,\n",
+ " positions=positions\n",
+ " )\n",
+ "\n",
+ " engine_inputs = [\n",
+ " engine_inputs_map[name] for name in engine.onnx_input_names_no_cache\n",
+ " ]\n",
+ "\n",
+ " return call(engine, engine_inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prefill(tokens):\n",
+ " prompt_logits = []\n",
+ " \n",
+ " # loop over multitoken engine\n",
+ " for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+ " logits = call(multitoken_engine, engine_inputs)\n",
+ " prompt_logits.append(logits)\n",
+ " \n",
+ " # expand kv cache for new size 124 --> 127\n",
+ " engine.kv_cache._state = kv_cache_insert(engine.kv_cache._state, num_items=multitoken_engine.input_ids_length - engine.input_ids_length)\n",
+ "\n",
+ " # loop of singletoken engine for the rest\n",
+ " tokens_processed = engine.kv_cache.total_num_processed_tokens\n",
+ " while tokens_processed < len(tokens):\n",
+ " logits = decode(tokens[:tokens_processed + 1])\n",
+ " prompt_logits.append(logits)\n",
+ " tokens_processed += 1\n",
+ " \n",
+ " return prompt_logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sample_token(logits):\n",
+ " return numpy.argmax(logits)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def kv_cache_slice(kv_cache, slice_idx):\n",
+ " for key in kv_cache:\n",
+ " kv_cache[key] = numpy.ascontiguousarray(kv_cache[key][:,:,slice_idx:,:])\n",
+ " return kv_cache\n",
+ "\n",
+ "def kv_cache_insert(kv_cache, num_items = 1, padding_value = 0):\n",
+ " indices = [0] * num_items\n",
+ " for key, value in kv_cache.items():\n",
+ " dtype = value.dtype\n",
+ " padding_value = numpy.array(padding_value, dtype=dtype)\n",
+ " kv_cache[key] = numpy.insert(value, indices, padding_value, axis=2)\n",
+ "\n",
+ " return kv_cache\n",
+ "\n",
+ "def call(eng, inputs):\n",
+ " inp = eng.add_kv_cache_to_input(inputs)\n",
+ " \n",
+ " logits, *kvs = eng.engine.run(inp, True)\n",
+ " new_kv_cache_state = {name: arr for name, arr in zip(cache_onnx_names, kvs)}\n",
+ "\n",
+ " eng.kv_cache.total_num_processed_tokens += eng.input_ids_length\n",
+ " eng.kv_cache._state = kv_cache_slice(new_kv_cache_state, eng.input_ids_length)\n",
+ "\n",
+ " return logits"
+ ]
+ },
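+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch (not part of the original run): combine the prefill, decode and\n",
+ "# sample_token helpers above into a greedy generation loop over the prompt.\n",
+ "pipeline._reset_engines_cache()\n",
+ "tokens = pipeline.tokenizer(sequence).input_ids\n",
+ "\n",
+ "prompt_logits = prefill(tokens)\n",
+ "tokens.append(sample_token(prompt_logits[-1][:, -1, :]))\n",
+ "\n",
+ "for _ in range(64):  # fixed budget, like max_generated_tokens above\n",
+ "    logits = decode(tokens)\n",
+ "    tokens.append(sample_token(logits[:, -1, :]))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },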
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "128"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "engine.sequence_length"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 184,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198, 198, 37811]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198, 198, 37811, 198]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198, 198, 37811, 198, 50284]\n",
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, 198, 50284, 361, 299, 6624, 657, 25, 198, 50280, 7783, 657, 198, 50284, 417, 361, 299, 6624, 352, 25, 198, 50280, 7783, 352, 198, 50284, 17772, 25, 198, 50280, 7783, 12900, 7, 77, 12, 16, 8, 1343, 12900, 7, 77, 12, 17, 8, 198, 198, 2, 4889, 262, 2163, 13, 198, 4798, 7, 69, 571, 7, 20, 4008, 198, 198, 2, 770, 2438, 318, 8639, 416, 11271, 71, 346, 26105, 14403, 7, 17172, 89, 1347, 62, 25816, 8, 198, 50256, 2, 48443, 14629, 14, 8800, 14, 24330, 21015, 198, 2, 532, 9, 12, 19617, 25, 3384, 69, 12, 23, 532, 9, 12, 198, 198, 37811, 198, 50284, 27730]\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokens = engine_input[0][engine_input[1].nonzero()].tolist()\n",
+ "pipeline._reset_engines_cache()\n",
+ "\n",
+ "print(tokens)\n",
+ "logits = prefill(tokens)\n",
+ "tokens.append(sample_token(logits[-1][0,-1,:])) # assume always batch = 1, last token of last logit in array\n",
+ "\n",
+ "# first token from prefill was generated\n",
+ "while len(tokens) < engine.sequence_length:\n",
+ " print(tokens)\n",
+ " logits = decode(tokens)\n",
+ " tokens.append(sample_token(logits[0,-1,:])) # assume always batch = 1, last token of last logit in array\n",
+ "\n",
+ "# print(engine.kv_cache._state[\"past_key_values.0.key\"][0,0,-INDEX:,0])\n",
+ "# print(engine.kv_cache._state[\"past_key_values.0.value\"][0,0,-INDEX:,0])\n",
+ "# print(engine.kv_cache._state[\"past_key_values.19.key\"][0,0,-INDEX:,0])\n",
+ "# print(engine.kv_cache._state[\"past_key_values.19.value\"][0,0,-INDEX:,0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>#!/usr/bin/env python\n",
+ "# -*- coding: utf-8 -*-\n",
+ "\n",
+ "\"\"\"\n",
+ " Examples for\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[-2.16584 -1.4835675 -0.82798475 2.728714 2.230249 0.136684\n",
+ " -2.4744277 -3.7903032 -0.44804883 1.8597361 3.2892575 1.1238453\n",
+ " -0.535056 -2.4058022 -2.6181865 0.82309175 3.7468169 0.9127281\n",
+ " -0.08818069 -3.5193567 -2.554974 0.42606103 2.8396277 3.6084752\n",
+ " 0.720097 -3.140173 -2.3983316 -1.1198903 1.4021769 2.4038355\n",
+ " 1.416564 -1.1770982 ]\n",
+ "[-0.11449916 0.47542867 0.03680322 -0.4064121 0.34018266 0.11729242\n",
+ " -0.29119202 1.6026565 -0.60162723 1.6026565 0.03134436 0.005808\n",
+ " 0.03680322 0.34018266 1.0944448 0.89051616 1.9490317 -0.8315846\n",
+ " 0.8142212 -0.22642836 0.36906892 1.9490317 -0.15412031 1.0944448\n",
+ " 0.89051616 1.9490317 0.03680322 0.03680322 0.55909103 0.03680322\n",
+ " 0.76835376 0.07582108]\n",
+ "[-1.4705226 2.3718867 1.5622201 0.20804703 -0.930273 -5.223105\n",
+ " -2.31877 3.5253658 3.8794327 3.3048825 -2.4029026 -2.4765668\n",
+ " -0.68623084 1.3053839 6.9972997 4.6631894 -0.957654 -4.965276\n",
+ " -5.222634 0.77317643 5.6226482 6.351179 1.0147996 -5.322752\n",
+ " -5.885022 -1.1356002 0.9603227 2.44311 2.3220952 -1.8733013\n",
+ " -5.0550013 -2.9907336 ]\n",
+ "[-0.16411448 -0.05435281 -0.22059102 0.09352674 -0.05225876 -0.22478615\n",
+ " 0.4103162 -0.1921539 0.11564742 -0.38469723 -0.01235063 0.29627988\n",
+ " -0.06217921 0.3747058 0.1442022 0.31203395 0.669638 0.40900382\n",
+ " 0.34937513 0.07317603 0.49499115 -0.26419586 0.14836667 0.41960722\n",
+ " 0.53298324 0.6752395 0.5533317 0.20957318 0.25364277 0.08110742\n",
+ " -0.19118905 0.845217 ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline._reset_engines_cache()\n",
+ "\n",
+ "onnx_input_names = (\n",
+ " pipeline.multitoken_engine.onnx_input_names_no_cache\n",
+ " if pipeline.multitoken_engine\n",
+ " else pipeline.engine.onnx_input_names_no_cache\n",
+ ")\n",
+ "engine_input = pipeline.tokens_to_engine_input(input_tokens, onnx_input_names)\n",
+ "tokens_theirs, logits_theirs = pipeline.prompt_inference(engine_input)\n",
+ "\n",
+ "while len(tokens_theirs) < pipeline.sequence_length:\n",
+ " token, logits = pipeline.autoregressive_inference(tokens_theirs)\n",
+ " tokens_theirs.append(token)\n",
+ " \n",
+ "print(engine.kv_cache._state[\"past_key_values.0.key\"][0,0,-INDEX:,0])\n",
+ "print(engine.kv_cache._state[\"past_key_values.0.value\"][0,0,-INDEX:,0])\n",
+ "print(engine.kv_cache._state[\"past_key_values.19.key\"][0,0,-INDEX:,0])\n",
+ "print(engine.kv_cache._state[\"past_key_values.19.value\"][0,0,-INDEX:,0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>#!/usr/bin/env python\n",
+ "# -*- coding: utf-8 -*-\n",
+ "\n",
+ "\"\"\"\n",
+ " Examples for\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(pipeline.tokenizer.decode(tokens_theirs))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 350,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1, 16, 127, 64)"
+ ]
+ },
+ "execution_count": 350,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "engine.kv_cache._state[\"past_key_values.0.key\"].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 321,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(2, 4, 124, 16)\n"
+ ]
+ }
+ ],
+ "source": [
+ "arr = numpy.ones((2,4,124,16), dtype=numpy.uint8)\n",
+ "print(arr.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 334,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "indices = [0] * 3\n",
+ "updated = numpy.insert(arr, indices, 0, axis=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 335,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "True\n",
+ "True\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(arr.data.contiguous)\n",
+ "print(updated.data.contiguous)"
+ ]
+ },
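+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch (not from the original run): `numpy.insert` always allocates and returns a fresh array, so left-padding this way copies the whole cache on every call. Under the same shapes as above, the padded buffer can instead be allocated once and filled with a single slice assignment. `left_pad` is a hypothetical helper name."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def left_pad(arr, target_len):\n",
+ "    # hypothetical helper: zero-pad `arr` on the left along axis 2 up to `target_len`,\n",
+ "    # using one preallocated buffer and one slice assignment instead of numpy.insert\n",
+ "    out = numpy.zeros(arr.shape[:2] + (target_len,) + arr.shape[3:], dtype=arr.dtype)\n",
+ "    out[:, :, target_len - arr.shape[2]:, :] = arr\n",
+ "    return out\n",
+ "\n",
+ "padded = left_pad(arr, 127)\n",
+ "print(padded.shape)  # (2, 4, 127, 16)\n",
+ "print(numpy.array_equal(padded, updated))  # same contents as the numpy.insert result"
+ ]
+ },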
+ {
+ "cell_type": "code",
+ "execution_count": 302,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(engine.kv_cache)\n",
+ "print(multitoken_engine.kv_cache)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 305,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "128\n",
+ "1\n",
+ "128\n",
+ "4\n",
+ "124\n",
+ "124\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(engine.sequence_length)\n",
+ "print(engine.input_ids_length)\n",
+ "\n",
+ "print(multitoken_engine.sequence_length)\n",
+ "print(multitoken_engine.input_ids_length)\n",
+ "\n",
+ "print(engine.kv_cache.capacity)\n",
+ "print(multitoken_engine.kv_cache.capacity)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 297,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "engine.transfer_cache_state(cache=multitoken_engine.kv_cache)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 300,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "127\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(multitoken_engine.kv_cache.capacity)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 299,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "127\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(engine.kv_cache.capacity)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 128,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0\n",
+ "input_ids:\n",
+ "[[16594 257 2163 284]]\n",
+ "(1, 4)\n",
+ "attention_mask:\n",
+ "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1]]\n",
+ "(1, 128)\n",
+ "positions:\n",
+ "[[0 1 2 3]]\n",
+ "(1, 4)\n",
+ "causal_mask:\n",
+ "[[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1]]]]\n",
+ "(1, 1, 4, 128)\n",
+ "\n",
+ "\n",
+ "\n",
+ "4\n",
+ "input_ids:\n",
+ "[[24061 257 12900 261]]\n",
+ "(1, 4)\n",
+ "attention_mask:\n",
+ "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]]\n",
+ "(1, 128)\n",
+ "positions:\n",
+ "[[4 5 6 7]]\n",
+ "(1, 4)\n",
+ "causal_mask:\n",
+ "[[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0]\n",
+ " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]]]]\n",
+ "(1, 1, 4, 128)\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline._reset_engines_cache()\n",
+ "\n",
+ "for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+ " multitoken_engine(engine_inputs)\n",
+ " print(\"\\n\\n\")"
+ ]
+ },
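+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A toy-scale sketch (not from the original run) of the \"build from the right\" layout shown in the outputs above, with the window shrunk to 16 slots for readability: the last `num_processed + num_new` attention-mask positions are set, and positions count on from the tokens already processed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# toy sizes, assumed for readability only: 16-slot window, 4 tokens already processed, 4 new tokens\n",
+ "seq_len, n_done, n_new = 16, 4, 4\n",
+ "\n",
+ "attention_mask = numpy.zeros((1, seq_len), dtype=numpy.int64)\n",
+ "attention_mask[:, -(n_new + n_done):] = 1  # mask filled from the right\n",
+ "print(attention_mask)  # [[0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]]\n",
+ "\n",
+ "positions = numpy.arange(n_done, n_done + n_new).reshape(1, -1).astype(numpy.int64)\n",
+ "print(positions)  # [[4 5 6 7]]"
+ ]
+ },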
+ {
+ "cell_type": "code",
+ "execution_count": 272,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>\n"
+ ]
+ }
+ ],
+ "source": [
+ "from typing import Optional, List, Dict\n",
+ "from deepsparse import Context\n",
+ "from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine\n",
+ "from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "class DecoderEngine:\n",
+ " def __init__ (\n",
+ " self,\n",
+ " onnx_file_path: str, \n",
+ " sequence_length: int = 1024,\n",
+ " input_ids_length: int = 1,\n",
+ " engine_context: Optional[Context] = None,\n",
+ " ):\n",
+ "\n",
+ " onnx_file_path, _, data_type = overwrite_onnx_model_inputs(\n",
+ " onnx_file_path=onnx_file_path,\n",
+ " batch_size=1,\n",
+ " sequence_length=sequence_length,\n",
+ " input_ids_length=input_ids_length,\n",
+ " )\n",
+ "\n",
+ " self.past_key_value_dtype = data_type\n",
+ " self.engine = create_engine(\n",
+ " onnx_file_path=onnx_file_path,\n",
+ " engine_type=DEEPSPARSE_ENGINE,\n",
+ " engine_args={},\n",
+ " context=engine_context,\n",
+ " )\n",
+ " print(self.engine)\n",
+ "\n",
+ " self.onnx_inputs = self.engine.input_names\n",
+ " \n",
+ " self.past_onnx_inputs = [\n",
+ " name for name in self.engine.input_names\n",
+ " if name.startswith(\"past_key_values\")\n",
+ " ]\n",
+ "\n",
+ " self.non_past_onnx_inputs = [\n",
+ " name for name in self.engine.input_names\n",
+ " if not name.startswith(\"past_key_values\")\n",
+ " ]\n",
+ " \n",
+ " def __call__(\n",
+ " self,\n",
+ " inputs: Dict[str, numpy.ndarray],\n",
+ " past_key_values: Dict[str, numpy.ndarray],\n",
+ " val_inp: bool = True\n",
+ " ):\n",
+ " # format input\n",
+ " inp = [past_key_values[name] if name.startswith(\"past_key_values\") \n",
+ " else inputs[name] for name in self.engine.input_names]\n",
+ "\n",
+ " # run inference\n",
+ " logits, *kvs = self.engine.run(inp, True)\n",
+ " past_key_values = {name: arr for name, arr in zip(self.past_onnx_inputs, kvs)}\n",
+ " \n",
+ " return logits, past_key_values\n",
+ "\n",
+ "\n",
+ "class Model:\n",
+ " def __init__(\n",
+ " self,\n",
+ " onnx_file_path: str,\n",
+ " sequence_length: int = 1024,\n",
+ " multi_token_length: int = 16,\n",
+ " engine_context: Optional[Context] = None,\n",
+ " singletoken_engine = None,\n",
+ " multitoken_engine = None,\n",
+ " ):\n",
+ " self.sequence_length = sequence_length\n",
+ " self.multi_token_length = multi_token_length\n",
+ "\n",
+ " if singletoken_engine is not None and multitoken_engine is not None:\n",
+ " self.singletoken_engine = singletoken_engine\n",
+ " self.multitoken_engine = multitoken_engine\n",
+ " else:\n",
+ " self.singletoken_engine = DecoderEngine(\n",
+ " onnx_file_path=onnx_file_path,\n",
+ " engine_context=engine_context,\n",
+ " sequence_length=sequence_length,\n",
+ " input_ids_length=1,\n",
+ " )\n",
+ " \n",
+ " self.multitoken_engine = DecoderEngine(\n",
+ " onnx_file_path=onnx_file_path,\n",
+ " engine_context=engine_context,\n",
+ " sequence_length=sequence_length,\n",
+ " input_ids_length=self.multi_token_length,\n",
+ " )\n",
+ "\n",
+ " assert self.multitoken_engine.past_key_value_dtype == self.singletoken_engine.past_key_value_dtype\n",
+ " self.past_key_value_dtype = self.multitoken_engine.past_key_value_dtype\n",
+ " \n",
+ " assert len(self.singletoken_engine.non_past_onnx_inputs) == 4\n",
+ " assert \"input_ids\" in self.singletoken_engine.non_past_onnx_inputs\n",
+ " assert \"attention_mask\" in self.singletoken_engine.non_past_onnx_inputs\n",
+ " assert \"causal_mask\" in self.singletoken_engine.non_past_onnx_inputs\n",
+ " assert \"positions\" in self.singletoken_engine.non_past_onnx_inputs\n",
+ "\n",
+ " # create empty kv caches with the proper sizes based on onnx graph\n",
+ " def init_past_key_values(self):\n",
+ " past_key_values = {}\n",
+ " for idx, name in enumerate(self.multitoken_engine.onnx_inputs):\n",
+ " if name.startswith(\"past_key_values\"):\n",
+ " shape = self.multitoken_engine.engine.input_shapes[idx]\n",
+ " past_key_values[name] = numpy.zeros(shape, dtype=self.past_key_value_dtype)\n",
+ "\n",
+ " return past_key_values\n",
+ "\n",
+ " # insert into every K,V matrix in the list\n",
+ " # BAD [SLOW] --- A copy of arr with values inserted. Note that insert does not occur in-place: a new array is returned. If axis is None, out is a flattened array.\n",
+ " def insert_past_key_values(self, past_key_values, num_items=1, padding_value=0):\n",
+ " for name in past_key_values:\n",
+ " padding_value = numpy.array(padding_value, dtype=self.past_key_value_dtype)\n",
+ " past_key_values[name] = numpy.insert(past_key_values[name], [0]*num_items, padding_value, axis=2)\n",
+ " return past_key_values\n",
+ "\n",
+ " # slice every K,V matrix in the list\n",
+ " # BAD [SLOW] --- calls .ascontinugousarray\n",
+ " def slice_past_key_values(self, past_key_values, slice_idx):\n",
+ " for name in past_key_values:\n",
+ " past_key_values[name] = numpy.ascontiguousarray(past_key_values[name][:,:,slice_idx:,:])\n",
+ " return past_key_values\n",
+ " \n",
+ " # slice input tokens into groups, make inputs dict\n",
+ " def engine_inputs_for_prefill(self, tokens):\n",
+ " num_batches = len(tokens) // self.multi_token_length\n",
+ " token_batches = [tokens[i * self.multi_token_length : (i+1) * self.multi_token_length] for i in range(0, num_batches)]\n",
+ "\n",
+ " num_processed_tokens = 0\n",
+ " for idx, token_batch in enumerate(token_batches):\n",
+ " engine_inputs = {}\n",
+ " engine_inputs[\"input_ids\"] = numpy.array([token_batch])\n",
+ "\n",
+ " # make attention mask from the right\n",
+ " engine_inputs[\"attention_mask\"] = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)\n",
+ " engine_inputs[\"attention_mask\"][:, -(self.multi_token_length + num_processed_tokens):] = 1\n",
+ " \n",
+ " # make positions (building from the right)\n",
+ " assert self.multi_token_length > 1\n",
+ " engine_inputs[\"positions\"] = numpy.arange(\n",
+ " num_processed_tokens, num_processed_tokens + self.multi_token_length\n",
+ " ).reshape(1, -1).astype(numpy.int64)\n",
+ "\n",
+ " # make causal mask (building from the right)\n",
+ " engine_inputs[\"causal_mask\"] = create_causal_mask(\n",
+ " input_ids=engine_inputs[\"input_ids\"], \n",
+ " attention_mask=engine_inputs[\"attention_mask\"]\n",
+ " )\n",
+ "\n",
+ "\n",
+ " yield engine_inputs\n",
+ "\n",
+ " def engine_inputs_for_decode(self, tokens):\n",
+ " assert(len(tokens) < self.sequence_length)\n",
+ " \n",
+ " engine_inputs = {}\n",
+ " engine_inputs[\"input_ids\"] = numpy.array([[tokens[-1]]])\n",
+ " engine_inputs[\"attention_mask\"] = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)\n",
+ " engine_inputs[\"attention_mask\"][:, -len(tokens):] = 1\n",
+ " \n",
+ " engine_inputs[\"causal_mask\"] = create_causal_mask(\n",
+ " engine_inputs[\"input_ids\"], \n",
+ " engine_inputs[\"attention_mask\"]\n",
+ " )\n",
+ " engine_inputs[\"positions\"] = numpy.array([[len(tokens) - 1]], dtype=numpy.int64)\n",
+ " \n",
+ " return engine_inputs\n",
+ " \n",
+ " # run prefill inference\n",
+ " def prefill(self, tokens):\n",
+ " assert len(tokens) < self.sequence_length - 1\n",
+ " \n",
+ " # initialize state\n",
+ " past_key_values = self.init_past_key_values()\n",
+ " tokens_processed = 0\n",
+ " \n",
+ " # loop over multitoken engine\n",
+ " for inputs in self.engine_inputs_for_prefill(tokens): \n",
+ " logits, past_key_values = self.multitoken_engine(inputs, past_key_values)\n",
+ " tokens_processed += self.multi_token_length\n",
+ " \n",
+ " # (this is BAD - calls np.ascontiguous) - cleanup past_kv state \n",
+ " past_key_values = self.slice_past_key_values(past_key_values, self.multi_token_length)\n",
+ " \n",
+ " # (this is BAD - returns a copy) - expand kv cache for single token engine \n",
+ " past_key_values = self.insert_past_key_values(past_key_values, num_items=(self.multi_token_length-1))\n",
+ "\n",
+ " # loop of singletoken engine for anything left over\n",
+ " while tokens_processed < len(tokens):\n",
+ " logits, past_key_values = self.decode(\n",
+ " tokens=tokens[:tokens_processed+1],\n",
+ " past_key_values=past_key_values\n",
+ " )\n",
+ " tokens_processed += 1\n",
+ "\n",
+ " assert logits.shape[0] == 1 # assert batch 1 right now\n",
+ " return logits[:,:,:], past_key_values\n",
+ " \n",
+ " # run decode inference\n",
+ " def decode(self, tokens, past_key_values): \n",
+ " engine_inputs = self.engine_inputs_for_decode(tokens)\n",
+ "\n",
+ " logits, past_key_values = self.singletoken_engine(\n",
+ " inputs=engine_inputs,\n",
+ " past_key_values=past_key_values\n",
+ " )\n",
+ "\n",
+ " # cleanup state (this is BAD - calls np.ascontiguous)\n",
+ " past_key_values = self.slice_past_key_values(past_key_values, 1)\n",
+ "\n",
+ " assert logits.shape[0] == 1 # assert batch 1 right now\n",
+ " assert logits.shape[1] == 1 # assert only one element\n",
+ " return logits[:,:,:], past_key_values\n",
+ "\n",
+ "def sample_token(logits):\n",
+ " assert(logits.shape[0] == 1)\n",
+ " return numpy.argmax(logits[0,-1,:])\n",
+ "\n",
+ "model = Model(\n",
+ " onnx_file_path=onnx_path,\n",
+ " sequence_length=128,\n",
+ " multi_token_length=16,\n",
+ " singletoken_engine=model.singletoken_engine,\n",
+ " multitoken_engine=model.multitoken_engine\n",
+ ")\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
+ "tokenizer.padding_side = \"left\"\n",
+ "if not tokenizer.pad_token:\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "def generate(model, tokenizer, text):\n",
+ " input_tokens = tokenizer(text, return_tensors=\"np\", max_length=model.sequence_length, padding=\"longest\", truncation=False,)\n",
+ " tokens = input_tokens[\"input_ids\"][input_tokens[\"attention_mask\"].nonzero()].tolist()\n",
+ "\n",
+ " # prefill\n",
+ " logits, past_key_values = model.prefill(tokens)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ " # run decode\n",
+ " while len(tokens) < model.sequence_length and tokens[-1] != tokenizer.eos_token_id:\n",
+ " logits, past_key_values = model.decode(tokens, past_key_values)\n",
+ " tokens.append(sample_token(logits))\n",
+ " \n",
+ " return tokens\n",
+ " \n",
+ "tokens = generate(model, tokenizer, sequence)\n",
+ "print(tokenizer.decode(tokens))"
+ ]
+ },
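+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A rough micro-benchmark sketch (not from the original run) of the copy cost flagged as BAD [SLOW] in the class above: `numpy.insert` reallocates on every call, while writing into a buffer preallocated once is a single slice assignment. The (1, 16, 127, 64) shape is assumed from the cache shape printed earlier; no timings are claimed beyond what this cell would measure."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "k = numpy.zeros((1, 16, 112, 64), dtype=numpy.float32)  # one K matrix; 112 + 15 inserted slots = the 127 cache slots printed earlier\n",
+ "\n",
+ "t0 = time.perf_counter()\n",
+ "for _ in range(100):\n",
+ "    grown = numpy.insert(k, [0] * 15, 0, axis=2)  # allocates a new (1, 16, 127, 64) array every iteration\n",
+ "t1 = time.perf_counter()\n",
+ "\n",
+ "buf = numpy.zeros((1, 16, 127, 64), dtype=numpy.float32)  # preallocated once\n",
+ "t2 = time.perf_counter()\n",
+ "for _ in range(100):\n",
+ "    buf[:, :, 15:, :] = k  # one slice assignment into the reused buffer\n",
+ "t3 = time.perf_counter()\n",
+ "\n",
+ "print(f\"numpy.insert: {t1 - t0:.4f}s  slice assignment: {t3 - t2:.4f}s\")"
+ ]
+ },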
+ {
+ "cell_type": "code",
+ "execution_count": 263,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(type(tokenizer.eos_token_id))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 266,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261, 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599, 198, array([0, 0, 0, ..., 0, 0, 0])]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 220,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1, 51200)"
+ ]
+ },
+ "execution_count": 220,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "logits.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 205,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-20 01:52:24 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-20 01:52:54 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ }
+ ],
+ "source": [
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "model_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "\n",
+ "model = Model(\n",
+ " onnx_file_path=onnx_path,\n",
+ " sequence_length=128,\n",
+ " multi_token_length=16\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 196,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 197,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "18"
+ ]
+ },
+ "execution_count": 197,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 206,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "RuntimeError",
+ "evalue": "NM: error: Got invalid dimensions for input: causal_mask for the following indices\n index: 2 Got: 128 Expected: 1\n Please fix either the inputs or the model.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[206], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprefill\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[204], line 187\u001b[0m, in \u001b[0;36mModel.prefill\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# loop of singletoken engine for anything left over\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m tokens_processed \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mlen\u001b[39m(tokens):\n\u001b[0;32m--> 187\u001b[0m logits, past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43mtokens_processed\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 189\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\n\u001b[1;32m 190\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 191\u001b[0m tokens_processed \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m logits\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# assert batch 1 right now\u001b[39;00m\n",
+ "Cell \u001b[0;32mIn[204], line 200\u001b[0m, in \u001b[0;36mModel.decode\u001b[0;34m(self, tokens, past_key_values)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, tokens, past_key_values): \n\u001b[1;32m 198\u001b[0m engine_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mengine_inputs_for_decode(tokens)\n\u001b[0;32m--> 200\u001b[0m logits, past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msingletoken_engine\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mengine_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\n\u001b[1;32m 203\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;66;03m# cleanup state (this is BAD - calls np.ascontiguous)\u001b[39;00m\n\u001b[1;32m 206\u001b[0m past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mslice_past_key_values(past_key_values, \u001b[38;5;241m1\u001b[39m)\n",
+ "Cell \u001b[0;32mIn[204], line 54\u001b[0m, in \u001b[0;36mDecoderEngine.__call__\u001b[0;34m(self, inputs, past_key_values, val_inp)\u001b[0m\n\u001b[1;32m 50\u001b[0m inp \u001b[38;5;241m=\u001b[39m [past_key_values[name] \u001b[38;5;28;01mif\u001b[39;00m name\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpast_key_values\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m inputs[name] \u001b[38;5;28;01mfor\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mengine\u001b[38;5;241m.\u001b[39minput_names]\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# run inference\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m logits, \u001b[38;5;241m*\u001b[39mkvs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43minp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m past_key_values \u001b[38;5;241m=\u001b[39m {name: arr \u001b[38;5;28;01mfor\u001b[39;00m name, arr \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpast_onnx_inputs, kvs)}\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m logits, past_key_values\n",
+ "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/site-packages/deepsparse/engine.py:527\u001b[0m, in \u001b[0;36mEngine.run\u001b[0;34m(self, inp, val_inp)\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m val_inp:\n\u001b[1;32m 525\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inp)\n\u001b[0;32m--> 527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_eng_net\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute_list_out\u001b[49m\u001b[43m(\u001b[49m\u001b[43minp\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: NM: error: Got invalid dimensions for input: causal_mask for the following indices\n index: 2 Got: 128 Expected: 1\n Please fix either the inputs or the model."
+ ]
+ }
+ ],
+ "source": [
+ "model.prefill(tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 168,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': array([[[48658, 262, 1708, 2163, 329, 14492, 257, 12900, 261,\n",
+ " 44456, 8379, 25, 220, 628, 12900, 7, 77, 2599]]]), 'attention_mask': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]), 'causal_mask': array([[[[0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " ...,\n",
+ " [0, 0, 0, ..., 1, 0, 0],\n",
+ " [0, 0, 0, ..., 1, 1, 0],\n",
+ " [0, 0, 0, ..., 1, 1, 1]]]]), 'positions': array([[0]])}\n"
+ ]
+ },
+ {
+ "ename": "RuntimeError",
+ "evalue": "NM: error: Invalid rank for input: input_ids Got: 3 Expected: 2 Please fix either the inputs or the model.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[168], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprefill\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tokens\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[166], line 187\u001b[0m, in \u001b[0;36mModel.prefill\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# loop of singletoken engine for anything left over\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m tokens_processed \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mlen\u001b[39m(tokens):\n\u001b[0;32m--> 187\u001b[0m logits, past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43mtokens_processed\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 189\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\n\u001b[1;32m 190\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 191\u001b[0m tokens_processed \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m logits\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# assert batch 1 right now\u001b[39;00m\n",
+ "Cell \u001b[0;32mIn[166], line 201\u001b[0m, in \u001b[0;36mModel.decode\u001b[0;34m(self, tokens, past_key_values)\u001b[0m\n\u001b[1;32m 198\u001b[0m engine_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mengine_inputs_for_decode(tokens)\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mprint\u001b[39m(engine_inputs)\n\u001b[0;32m--> 201\u001b[0m logits, past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msingletoken_engine\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mengine_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 203\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\n\u001b[1;32m 204\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;66;03m# cleanup state (this is BAD - calls np.ascontiguous)\u001b[39;00m\n\u001b[1;32m 207\u001b[0m past_key_values \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mslice_past_key_values(past_key_values, \u001b[38;5;241m1\u001b[39m)\n",
+ "Cell \u001b[0;32mIn[146], line 54\u001b[0m, in \u001b[0;36mDecoderEngine.__call__\u001b[0;34m(self, inputs, past_key_values, val_inp)\u001b[0m\n\u001b[1;32m 50\u001b[0m inp \u001b[38;5;241m=\u001b[39m [past_key_values[name] \u001b[38;5;28;01mif\u001b[39;00m name\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpast_key_values\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m inputs[name] \u001b[38;5;28;01mfor\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mengine\u001b[38;5;241m.\u001b[39minput_names]\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# run inference\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m logits, \u001b[38;5;241m*\u001b[39mkvs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43minp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 55\u001b[0m past_key_values \u001b[38;5;241m=\u001b[39m {name: arr \u001b[38;5;28;01mfor\u001b[39;00m name, arr \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpast_onnx_names, kvs)}\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m logits, past_key_values\n",
+ "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/site-packages/deepsparse/engine.py:527\u001b[0m, in \u001b[0;36mEngine.run\u001b[0;34m(self, inp, val_inp)\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m val_inp:\n\u001b[1;32m 525\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inp)\n\u001b[0;32m--> 527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_eng_net\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute_list_out\u001b[49m\u001b[43m(\u001b[49m\u001b[43minp\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: NM: error: Invalid rank for input: input_ids Got: 3 Expected: 2 Please fix either the inputs or the model."
+ ]
+ }
+ ],
+ "source": []
+ },
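+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The two `RuntimeError`s above pin down the input contract of the single-token engine. A minimal sketch of the expected shapes, inferred from the tracebacks and from `engine_inputs_for_decode` further down (so treat it as a working assumption, not engine documentation):\n",
+    "\n",
+    "```python\n",
+    "# expected single-token engine inputs for batch=1, sequence_length=128\n",
+    "input_ids       # (1, 1)  rank 2 -- a 3-D array triggers \"Invalid rank\"\n",
+    "attention_mask  # (1, 128)\n",
+    "causal_mask     # (1, 1, 1, 128) -- index 2 must be 1 (one query token)\n",
+    "positions       # (1, 1)\n",
+    "```"
+   ]
+  },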
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Intenally Managed"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import deepsparse\n",
+ "from deepsparse.transformers.utils.helpers import create_causal_mask"
+ ]
+ },
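+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the mask layout concrete, a small sketch (assuming `create_causal_mask` keeps the `(input_ids, attention_mask)` signature used throughout this notebook) for a 3-token prompt packed right-aligned into a length-6 window:\n",
+    "\n",
+    "```python\n",
+    "import numpy\n",
+    "\n",
+    "input_ids = numpy.array([[11, 22, 33]])                  # (1, 3)\n",
+    "attention_mask = numpy.zeros((1, 6), dtype=numpy.int64)\n",
+    "attention_mask[:, -3:] = 1  # tokens occupy the right of the window\n",
+    "\n",
+    "mask = create_causal_mask(input_ids, attention_mask)\n",
+    "# expected: a lower-triangular block of 1s in the occupied right-hand\n",
+    "# corner, matching the causal_mask printed earlier in this notebook\n",
+    "```"
+   ]
+  },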
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-20 11:07:42 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n",
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-20 11:07:54 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+ "2023-08-20 11:07:56 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n",
+ "2023-08-20 11:08:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline = deepsparse.Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=True,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
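+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With `prompt_processing_sequence_length=4`, the auxiliary multi-token engine prefills the prompt in chunks of 4 tokens and any remainder falls back to single-token steps. As a worked example under that assumption: the fibonacci prompt used in this notebook tokenizes to 18 tokens (see the printed `input_ids` earlier), so prefill takes 4 multi-token passes plus 2 single-token steps:\n",
+    "\n",
+    "```python\n",
+    "# how an 18-token prompt splits under a multi-token length of 4\n",
+    "num_multitoken_passes = 18 // 4      # 4 passes covering 16 tokens\n",
+    "leftover_single_steps = 18 - 4 * 4   # 2 single-token decode steps\n",
+    "```"
+   ]
+  },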
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TextGenerationOutput(sequences=['\\n\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Call the function.\\nprint(fib(5))\\n\\n# This code'], logits=None, session_id=None)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline(sequences=sequence)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "singletoken_engine = pipeline.engine\n",
+ "multitoken_engine = pipeline.multitoken_engine\n",
+ "assert singletoken_engine.kv_cache == multitoken_engine.kv_cache\n",
+ "kv_cache = singletoken_engine.kv_cache"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 271,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "without maintaining\n",
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>\n",
+ "\n",
+ "\n",
+ "maintaining\n",
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy\n",
+ "\n",
+ "multitoken_length = pipeline.prompt_processing_sequence_length\n",
+ "sequence_length = pipeline.sequence_length\n",
+ "\n",
+ "def empty_past_key_values(engine):\n",
+ " past_key_values = {}\n",
+ " for idx, name in enumerate(engine.engine.input_names):\n",
+ " if name.startswith(\"past_key_values\"):\n",
+ " shape = engine.engine.input_shapes[idx]\n",
+ " past_key_values[name] = numpy.zeros(shape, dtype=engine.kv_cache_data_type)\n",
+ "\n",
+ " return past_key_values\n",
+ "\n",
+ "def engine_inputs_for_decode(tokens):\n",
+ " assert(len(tokens) < sequence_length)\n",
+ " \n",
+ " engine_inputs = {}\n",
+ " engine_inputs[\"input_ids\"] = numpy.array([[tokens[-1]]])\n",
+ " engine_inputs[\"attention_mask\"] = numpy.zeros((1, sequence_length), dtype=numpy.int64)\n",
+ " engine_inputs[\"attention_mask\"][:, -len(tokens):] = 1\n",
+ " \n",
+ " engine_inputs[\"causal_mask\"] = create_causal_mask(\n",
+ " engine_inputs[\"input_ids\"],\n",
+ " engine_inputs[\"attention_mask\"]\n",
+ " )\n",
+ " engine_inputs[\"positions\"] = numpy.array([[len(tokens) - 1]], dtype=numpy.int64)\n",
+ " \n",
+ " return engine_inputs\n",
+ "\n",
+ "def engine_inputs_for_prefill(tokens):\n",
+ " num_batches = len(tokens) // multitoken_length\n",
+ " token_batches = [tokens[i * multitoken_length : (i+1) * multitoken_length] for i in range(0, num_batches)]\n",
+ "\n",
+ " for idx, token_batch in enumerate(token_batches):\n",
+ " num_processed_tokens = multitoken_length * idx\n",
+ " \n",
+ " engine_inputs = {}\n",
+ " engine_inputs[\"input_ids\"] = numpy.array([token_batch])\n",
+ "\n",
+ " # make attention mask from the right\n",
+ " engine_inputs[\"attention_mask\"] = numpy.zeros((1, sequence_length), dtype=numpy.int64)\n",
+ " engine_inputs[\"attention_mask\"][:, -(num_processed_tokens + multitoken_length):] = 1\n",
+ "\n",
+ " # make positions (building from the right)\n",
+ " assert multitoken_length > 1\n",
+ " engine_inputs[\"positions\"] = numpy.arange(\n",
+ " num_processed_tokens, num_processed_tokens + multitoken_length\n",
+ " ).reshape(1, -1).astype(numpy.int64)\n",
+ "\n",
+ " # make causal mask (building from the right)\n",
+ " engine_inputs[\"causal_mask\"] = create_causal_mask(\n",
+ " input_ids=engine_inputs[\"input_ids\"], \n",
+ " attention_mask=engine_inputs[\"attention_mask\"]\n",
+ " )\n",
+ "\n",
+ " yield engine_inputs\n",
+ "\n",
+ "def call_engine(engine, engine_inputs, past_key_values):\n",
+ " # format inputs as list\n",
+ " inputs = [\n",
+ " past_key_values[name] if name.startswith(\"past_key_values\") \n",
+ " else engine_inputs[name] for name in engine.engine.input_names\n",
+ " ]\n",
+ "\n",
+ " # run inference\n",
+ " logits, *kvs = engine.engine._eng_net.execute_list_out(inputs, engine.kv_cache._kv_cache)\n",
+ "\n",
+ " # format output as dict\n",
+ " past_names = [name for name in engine.engine.input_names if name.startswith(\"past_key_values\")]\n",
+ " past_key_values = {name: arr for name, arr in zip(past_names, kvs)}\n",
+ " \n",
+ " return logits, past_key_values\n",
+ "\n",
+ "# bad -- returns a numpy.insert returns a full copy (update does NOT happen in place)\n",
+ "def insert_past_key_values(past_key_values, num_items=1, padding_value=0):\n",
+ " dtype = next(iter(past_key_values.values())).dtype\n",
+ "\n",
+ " for name in past_key_values:\n",
+ " padding_value = numpy.array(padding_value, dtype=dtype)\n",
+ " past_key_values[name] = numpy.insert(past_key_values[name], [0]*num_items, padding_value, axis=2)\n",
+ " return past_key_values\n",
+ "\n",
+ "# bad --- calls np.ascontiguous\n",
+ "def slice_past_key_values(past_key_values, slice_idx):\n",
+ " for name in past_key_values:\n",
+ " past_key_values[name] = numpy.ascontiguousarray(past_key_values[name][:,:,slice_idx:,:])\n",
+ " return past_key_values\n",
+ " \n",
+ "# maintians the kv cache state at pipeline level\n",
+ "def decode_maintain(tokens, past_key_values): \n",
+ " engine_inputs = engine_inputs_for_decode(tokens)\n",
+ "\n",
+ " logits, past_key_values = call_engine(\n",
+ " singletoken_engine,\n",
+ " engine_inputs=engine_inputs,\n",
+ " past_key_values=past_key_values\n",
+ " )\n",
+ "\n",
+ " # cleanup state (this is BAD - calls np.ascontiguous)\n",
+ " past_key_values = slice_past_key_values(past_key_values, 1)\n",
+ "\n",
+ " assert logits.shape[0] == 1 # assert batch 1 right now\n",
+ " assert logits.shape[1] == 1 # assert only one element\n",
+ " return logits, past_key_values\n",
+ "\n",
+ "# maintians the kv cache state at pipeline level\n",
+ "def prefill_maintain(tokens):\n",
+ " tokens_processed = 0\n",
+ " past_key_values = empty_past_key_values(multitoken_engine)\n",
+ "\n",
+ " for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+ " logits, past_key_values = call_engine(\n",
+ " multitoken_engine, \n",
+ " engine_inputs=engine_inputs, \n",
+ " past_key_values=past_key_values\n",
+ " )\n",
+ " tokens_processed += multitoken_length\n",
+ "\n",
+ " # BAD - calls np.ascontgious - cleans up that engine returns past with prior_seq_len + input_ids_len\n",
+ " past_key_values = slice_past_key_values(past_key_values, multitoken_length)\n",
+ " \n",
+ " # (this is BAD - returns a copy) - expand kv cache for single token engine \n",
+ " past_key_values = insert_past_key_values(past_key_values, num_items=(multitoken_length-1))\n",
+ "\n",
+ " # loop of singletoken engine for anything left over\n",
+ " while tokens_processed < len(tokens):\n",
+ " logits, past_key_values = decode_maintain(\n",
+ " tokens=tokens[:tokens_processed+1],\n",
+ " past_key_values=past_key_values\n",
+ " )\n",
+ " tokens_processed += 1\n",
+ " \n",
+ " return logits, past_key_values\n",
+ "\n",
+ "empty_past_key_values_multi = empty_past_key_values(multitoken_engine)\n",
+ "empty_past_key_values_single = empty_past_key_values(singletoken_engine)\n",
+ "\n",
+ "# does not maintian kv cache state at pipeline level\n",
+ "def decode(tokens):\n",
+ " engine_inputs = engine_inputs_for_decode(tokens)\n",
+ "\n",
+ " logits, past_key_values = call_engine(\n",
+ " singletoken_engine,\n",
+ " engine_inputs=engine_inputs,\n",
+ " past_key_values=empty_past_key_values_single\n",
+ " )\n",
+ "\n",
+ " return logits\n",
+ "\n",
+ "# does not maintain the state at pipeline level\n",
+ "def prefill(tokens):\n",
+ " tokens_processed = 0\n",
+ " \n",
+ " for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+ " logits, _ = call_engine(\n",
+ " multitoken_engine, \n",
+ " engine_inputs=engine_inputs, \n",
+ " past_key_values=empty_past_key_values_multi\n",
+ " )\n",
+ " tokens_processed += multitoken_length\n",
+ "\n",
+ " # loop of singletoken engine for anything left over\n",
+ " while tokens_processed < len(tokens):\n",
+ " logits = decode(tokens[:tokens_processed+1])\n",
+ " tokens_processed += 1\n",
+ " \n",
+ " return logits\n",
+ "\n",
+ "def sample_token(logits):\n",
+ " assert(logits.shape[0] == 1)\n",
+ " return numpy.argmax(logits[0,-1,:])\n",
+ "\n",
+ "eos_token = pipeline.tokenizer.eos_token_id\n",
+ "\n",
+ "print(\"without maintaining\")\n",
+ "pipeline._reset_engines_cache()\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "logits = prefill(tokens)\n",
+ "tokens.append(sample_token(logits))\n",
+ "while len(tokens) < sequence_length and tokens[-1] != eos_token:\n",
+ " logits = decode(tokens)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))\n",
+ "\n",
+ "print(\"\\n\\nmaintaining\")\n",
+ "pipeline._reset_engines_cache()\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "logits, past_key_values = prefill_maintain(tokens)\n",
+ "tokens.append(sample_token(logits))\n",
+ "while len(tokens) < sequence_length and tokens[-1] != eos_token:\n",
+ " logits, past_key_values = decode_maintain(tokens, past_key_values)\n",
+ " tokens.append(sample_token(logits))\n",
+ " \n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
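+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `slice_past_key_values` / `insert_past_key_values` round-trips are flagged as BAD above because `numpy.ascontiguousarray` and `numpy.insert` both materialize a full copy of the cache on every step. One way to sidestep the copies at the pipeline level would be a preallocated buffer; a minimal, hypothetical sketch (not deepsparse API):\n",
+    "\n",
+    "```python\n",
+    "import numpy\n",
+    "\n",
+    "# hypothetical fixed-size KV buffer: appends write into a preallocated\n",
+    "# array instead of copying the whole cache each step\n",
+    "class PreallocatedKVCache:\n",
+    "    def __init__(self, shape, dtype):\n",
+    "        self.buffer = numpy.zeros(shape, dtype=dtype)  # (b, heads, seq, dim)\n",
+    "        self.filled = 0\n",
+    "\n",
+    "    def append(self, new_kv):\n",
+    "        n = new_kv.shape[2]\n",
+    "        assert self.filled + n <= self.buffer.shape[2], \"window full\"\n",
+    "        self.buffer[:, :, self.filled:self.filled + n, :] = new_kv\n",
+    "        self.filled += n\n",
+    "```\n",
+    "\n",
+    "Note that the engines here keep the cache right-aligned in the window (padding is inserted on the left, stale entries sliced off the left), so bookkeeping like this really has to live inside the engine - which is what `use_deepsparse_cache=True` delegates."
+   ]
+  },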
+ {
+ "cell_type": "code",
+ "execution_count": 278,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "<|endoftext|>\n"
+ ]
+ }
+ ],
+ "source": [
+ "def prefill_pipeline(pipeline, tokens):\n",
+ " num_tokens_processed = 0\n",
+ " for engine_inputs in pipeline.engine_inputs_for_prefill(tokens):\n",
+ " _, logits = pipeline.multitoken_engine(engine_inputs)\n",
+ " num_tokens_processed += multitoken_length\n",
+ "\n",
+ " if num_tokens_processed > 0:\n",
+ " pipeline.engine.transfer_cache_state(cache=pipeline.multitoken_engine.kv_cache)\n",
+ "\n",
+ " run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]\n",
+ " for token in tokens[num_tokens_processed:]:\n",
+ " run_tokens.append(token)\n",
+ " new_token, logits = pipeline.autoregressive_inference(run_tokens)\n",
+ " return logits\n",
+ " \n",
+ "pipeline._reset_engines_cache()\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "logits = prefill_pipeline(pipeline, tokens)\n",
+ "tokens.append(sample_token(logits))\n",
+ "\n",
+ "while len(tokens) < pipeline.sequence_length and tokens[-1] != eos_token:\n",
+ " _, logits = pipeline.autoregressive_inference(tokens)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 277,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"{sequence}{pipeline(sequences=sequence).sequences[0]}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 279,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-20 13:44:57 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+ "2023-08-20 13:44:58 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "2023-08-20 13:45:23 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline2 = deepsparse.Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=False,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.17"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/server-dev.ipynb b/server-dev.ipynb
new file mode 100644
index 00000000..80a5e83b
--- /dev/null
+++ b/server-dev.ipynb
@@ -0,0 +1,1274 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "2ab30adb-ca8a-4ca3-9cbc-dcae6e244754",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%reload_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
+ "metadata": {},
+ "source": [
+ "## Example Interacting With The Router"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
+ ]
+ }
+ ],
+ "source": [
+ "from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n",
+ "from server.deepsparse.deepsparse_service import DeepSparseService\n",
+ "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "78acf813-3688-483d-9148-5c0df5d6b8e3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "\n",
+ "model = DeepSparseCausalLM(\n",
+ " tokenizer_path=tokenizer_path,\n",
+ " model_path=onnx_path\n",
+ ")\n",
+ "\n",
+ "service = DeepSparseService(model=model)\n",
+ "router = DeepSparseRouter(service=service)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "e93bac63-8924-4cf4-8683-81ce9333a2f1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ "def fib(n):\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Driver function to test above function\n",
+ "n = int(input(\"Enter the number: \"))\n",
+ "print(fib(n))\n",
+ "\n",
+ "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+ "\n",
+ "\n",
+ "\n",
+ "Write a function for filtering a list of integers to include only positive numbers:\n",
+ "\n",
+ "def filter(lst):\n",
+ " return [x for x in lst if x > 0]\n",
+ "\n",
+ "# Test\n",
+ "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
+ "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
+ "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n",
+ "print(filter([1,\n",
+ "\n",
+ "\n",
+ "Write a function for checking if a word if a palindrome:\n",
+ "\n",
+ "def is_palindrome(word):\n",
+ " return word == word[::-1]\n",
+ "\n",
+ "# Test\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(is_palindrome(\"racecar\"))\n",
+ "print(\n",
+ "\n",
+ "\n",
+ "Write a function for reversing a string:\n",
+ "\n",
+ "def reverse_string(s):\n",
+ " return s[::-1]\n",
+ "\n",
+ "# Test\n",
+ "print(reverse_string(\"hello\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"a\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\"))\n",
+ "print(reverse_string(\"\n",
+ "\n",
+ "\n",
+ "Write a function for sorting an array of integers:\n",
+ "\n",
+ "def merge_sort(arr):\n",
+ " if len(arr) <= 1:\n",
+ " return arr\n",
+ " mid = len(arr) // 2\n",
+ " left = arr[:mid]\n",
+ " right = arr[mid:]\n",
+ " left = merge_sort(left)\n",
+ " right = merge_sort(right)\n",
+ " return merge(left, right)\n",
+ "\n",
+ "def merge(left, right):\n",
+ " result = []\n",
+ " while len(left) > 0 and len(right) > 0:\n",
+ " if left[0]\n",
+ "\n",
+ "\n",
+ "stop\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from threading import Thread\n",
+ "import time\n",
+ "\n",
+ "batching_thread = Thread(target=batching_task, args=[router])\n",
+ "batching_thread.start()\n",
+ "\n",
+ "prompts = [\n",
+ " \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
+ " \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n",
+ " \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n",
+ " \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n",
+ " \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n",
+ "]\n",
+ "\n",
+ "def generate_task(prompt):\n",
+ " result = router.generate(prompt=prompt)\n",
+ " print(result)\n",
+ " print(\"\\n\")\n",
+ "\n",
+ "generate_threads = [\n",
+ " Thread(target=generate_task, args=[prompt]) for prompt in prompts\n",
+ "]\n",
+ "\n",
+ "# print(len(generate_threads))\n",
+ "\n",
+ "for gt in generate_threads:\n",
+ " gt.start()\n",
+ " time.sleep(0.5)\n",
+ "\n",
+ "for gt in generate_threads:\n",
+ " gt.join()\n",
+ "\n",
+ "\n",
+ "generate_task(\"stop\")\n",
+ "batching_thread.join()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7d43c041-2c79-4276-9104-2f224b2f8af6",
+ "metadata": {},
+ "source": [
+ "## Example Interacting With The Service"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "631e94eb-cca0-438e-8936-6e8a87166d63",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-22 14:26:39 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
+ ]
+ }
+ ],
+ "source": [
+ "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLMBatch, DeepSparseCausalLM\n",
+ "from server.deepsparse.deepsparse_service import DeepSparseService\n",
+ "from server.deepsparse.deepsparse_requests import (\n",
+ " PrefillRequest, DecodeRequest, FilterBatchRequest, Request\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "c9c39557-2898-443f-aae8-443ef1171123",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-22 14:26:56 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-22 14:27:21 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "\n",
+ "model = DeepSparseCausalLM(\n",
+ " tokenizer_path=tokenizer_path,\n",
+ " model_path=onnx_path\n",
+ ")\n",
+ "\n",
+ "service = DeepSparseService(model=model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "85ce9aab-1a56-4b6f-a82b-4e91d52290b7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ " \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\",\n",
+ " \"Write a function for filtering a list of integers to include only positive numbers:\\n\\nfilter(lst):\",\n",
+ " \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n",
+ " \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n",
+ " \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n",
+ "]\n",
+ "\n",
+ "def make_batch(id, prompt):\n",
+ " return Batch(\n",
+ " id=id,\n",
+ " requests=[Request(id=id, prompt=prompt)]\n",
+ " )\n",
+ "\n",
+ "class PrefillQueue:\n",
+ " def __init__(self, prompts):\n",
+ " self.queue = {\n",
+ " idx: PrefillRequest(batch=make_batch(id=idx, prompt=prompt))\n",
+ " for idx, prompt in enumerate(prompts)\n",
+ " }\n",
+ "\n",
+ " def pop(self):\n",
+ " keys = list(self.queue.keys())\n",
+ " if len(keys) == 0:\n",
+ " return None\n",
+ " else:\n",
+ " return self.queue.pop(keys[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d2441753-fe2a-45c0-ad80-135b6207947d",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'Batch' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m service\u001b[38;5;241m.\u001b[39mClearCache()\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# prefill queue\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m prefill_queue \u001b[38;5;241m=\u001b[39m \u001b[43mPrefillQueue\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# cached batches\u001b[39;00m\n\u001b[1;32m 7\u001b[0m cached_batches \u001b[38;5;241m=\u001b[39m []\n",
+ "Cell \u001b[0;32mIn[4], line 17\u001b[0m, in \u001b[0;36mPrefillQueue.__init__\u001b[0;34m(self, prompts)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39mmake_batch(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39midx, prompt\u001b[38;5;241m=\u001b[39mprompt))\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n",
+ "Cell \u001b[0;32mIn[4], line 18\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m---> 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39m\u001b[43mmake_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n",
+ "Cell \u001b[0;32mIn[4], line 10\u001b[0m, in \u001b[0;36mmake_batch\u001b[0;34m(id, prompt)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_batch\u001b[39m(\u001b[38;5;28mid\u001b[39m, prompt):\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mBatch\u001b[49m(\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m,\n\u001b[1;32m 12\u001b[0m requests\u001b[38;5;241m=\u001b[39m[Request(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m, prompt\u001b[38;5;241m=\u001b[39mprompt)]\n\u001b[1;32m 13\u001b[0m )\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'Batch' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "service.ClearCache()\n",
+ "\n",
+ "# prefill queue\n",
+ "prefill_queue = PrefillQueue(prompts)\n",
+ "\n",
+ "# cached batches\n",
+ "cached_batches = []\n",
+ "\n",
+ "# generated\n",
+ "generated_text = {}\n",
+ "\n",
+ "def prefill(request):\n",
+ " generation, cached_batch = service.Prefill(request)\n",
+ " \n",
+ " assert request.batch.requests[0].id == generation.request_id\n",
+ " assert generation.request_id not in generated_text.keys()\n",
+ " \n",
+ " generated_text[generation.request_id] = request.batch.requests[0].prompt + generation.generated_text\n",
+ "\n",
+ " return cached_batch\n",
+ "\n",
+ "def decode(request):\n",
+ " for cached_batch in request.batches:\n",
+ " for request_id in cached_batch.request_ids:\n",
+ " assert request_id in generated_text.keys()\n",
+ "\n",
+ " generations, cached_batch = service.Decode(request)\n",
+ " if cached_batch is None:\n",
+ " print(\"All requests done!\\n\\n\")\n",
+ " return None\n",
+ " \n",
+ " active_request_ids = []\n",
+ " stopped_request_ids = []\n",
+ " \n",
+ " for generation in generations:\n",
+ " assert generation.request_id in generated_text.keys()\n",
+ "\n",
+ " # if text is None, we stopped\n",
+ " if generation.generated_text is None:\n",
+ " print(f\"Request {generation.request_id} is done!\")\n",
+ " stopped_request_ids.append(generation.request_id)\n",
+ " \n",
+ " else:\n",
+ " generated_text[generation.request_id] += generation.generated_text\n",
+ " active_request_ids.append(generation.request_id)\n",
+ " \n",
+ " # if any stopped, return this\n",
+ " if len(stopped_request_ids) > 0:\n",
+ " cached_batch = service.FilterBatch(FilterBatchRequest(\n",
+ " batch_id=cached_batch.batch_id,\n",
+ " request_ids=active_request_ids,\n",
+ " ))\n",
+ " \n",
+ " return cached_batch\n",
+ "\n",
+ "# run a prefille\n",
+ "queue_not_empty = True\n",
+ "while queue_not_empty:\n",
+ " prefill_request = prefill_queue.pop()\n",
+ " if prefill_request is not None:\n",
+ " cached_batch = prefill(prefill_request)\n",
+ " cached_batches.append(cached_batch)\n",
+ " else:\n",
+ " queue_not_empty = False\n",
+ " \n",
+ " # run a few decodes\n",
+ " for _ in range(5):\n",
+ " cached_batches = [decode(DecodeRequest(cached_batches))]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dd6bcc43-63ef-4f92-a960-74e33b86dc97",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# run a few decodes\n",
+ "for _ in range(100):\n",
+ " cached_batch = decode(DecodeRequest(cached_batches))\n",
+ " if cached_batch is None:\n",
+ " break\n",
+ " cached_batches = [cached_batch]\n",
+ " \n",
+ "for idx, value in generated_text.items():\n",
+ " print(f\"INDEX = {idx}:\")\n",
+ " print(value)\n",
+ " print(\"\\n\")\n",
+ "\n",
+ "print(cached_batches)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f9198565-a7e3-4ba4-8f46-b21adc4d87ac",
+ "metadata": {},
+ "source": [
+ "## Example DeepSparseCausalLMBatch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5bf269cd-3d85-46c4-b80c-7d3d7199756a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-22 01:33:22 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n"
+ ]
+ }
+ ],
+ "source": [
+ "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLMBatch, DeepSparseCausalLM\n",
+ "from server.deepsparse.deepsparse_requests import Request, Batch\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "fc4c3d6a-d90d-46d2-943d-4d12297599eb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-22 01:33:25 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-22 01:33:49 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ }
+ ],
+ "source": [
+ "ds_model = DeepSparseCausalLM(\n",
+ " tokenizer_path=tokenizer_path,\n",
+ " model_path=onnx_path\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "442c3dfd-c03e-4791-a1ae-212a2820857b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code\n",
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code\n"
+ ]
+ }
+ ],
+ "source": [
+ "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\"\n",
+ "\n",
+ "def make_n_requests(n=1):\n",
+ " requests = []\n",
+ " for i in range(n):\n",
+ " request = Request(\n",
+ " id=i,\n",
+ " prompt=sequence,\n",
+ " )\n",
+ " requests.append(request)\n",
+ " return requests\n",
+ "\n",
+ "batch_size = 2\n",
+ "batch = Batch(\n",
+ " id=0,\n",
+ " requests = make_n_requests(n=batch_size),\n",
+ ")\n",
+ "\n",
+ "ds_batch = DeepSparseCausalLMBatch.from_batch(\n",
+ " batch=batch,\n",
+ " tokenizer=tokenizer, \n",
+ ")\n",
+ "\n",
+ "next_batch = ds_batch\n",
+ "for _ in range(64):\n",
+ " # print(tokenizer.batch_decode(next_batch.input_ids_list[0]))\n",
+ " generation, next_batch = ds_model.generate_token(next_batch)\n",
+ "\n",
+ "for input_ids in next_batch.input_ids_list:\n",
+ " print(tokenizer.batch_decode(input_ids)[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a45ba351-0e14-4440-9962-bb692599ae2a",
+ "metadata": {},
+ "source": [
+ "## Compare to DeepSparse Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 134,
+ "id": "fc45233a-9a34-42bb-b6b0-7b19dd5763e9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Finish the following function for computing a fibonacci sequence: \n",
+ "\n",
+ " fib(n):\n",
+ "\n",
+ " if n == 0:\n",
+ " return 0\n",
+ " elif n == 1:\n",
+ " return 1\n",
+ " else:\n",
+ " return fib(n-1) + fib(n-2)\n",
+ "\n",
+ "# Call the function.\n",
+ "print(fib(5))\n",
+ "\n",
+ "# This code is\n"
+ ]
+ }
+ ],
+ "source": [
+ "multitoken_length = 4\n",
+ "\n",
+ "def sample_token(logits):\n",
+ " assert(logits.shape[0] == 1) # assert b=1 for now\n",
+ " return np.argmax(logits[0,-1,:]) \n",
+ " \n",
+ "def prefill_pipeline(pipeline, tokens):\n",
+ " num_tokens_processed = 0\n",
+ " for engine_inputs in pipeline.engine_inputs_for_prefill(tokens):\n",
+ " _, logits = pipeline.multitoken_engine(engine_inputs)\n",
+ " num_tokens_processed += multitoken_length\n",
+ " \n",
+ " if num_tokens_processed > 0:\n",
+ " pipeline.engine.transfer_cache_state(cache=pipeline.multitoken_engine.kv_cache)\n",
+ "\n",
+ " run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]\n",
+ " for token in tokens[num_tokens_processed:]:\n",
+ " run_tokens.append(token)\n",
+ " new_token, logits = pipeline.autoregressive_inference(run_tokens)\n",
+ " return logits\n",
+ " \n",
+ "pipeline._reset_engines_cache()\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "logits = prefill_pipeline(pipeline, tokens)\n",
+ "# print(logits)\n",
+ "tokens.append(sample_token(logits))\n",
+ "\n",
+ "for _ in range(64):\n",
+ " _, logits = pipeline.autoregressive_inference(tokens)\n",
+ " # print(logits)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6ac484d6-093d-411f-909a-2ac143b26cec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from deepsparse import Pipeline\n",
+ "pipeline = Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=False,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "id": "9574f0f7-c882-499a-ba8a-c107df0655ad",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1, 18)"
+ ]
+ },
+ "execution_count": 101,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "next_batch.input_ids_list[0].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "id": "eeb1449f-82f2-4bad-9265-5ddbf0944a4d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "numpy.ndarray"
+ ]
+ },
+ "execution_count": 97,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(next_batch.input_ids_list[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "id": "9a0104a8-3412-41a4-acd0-0dbbdf0fd9da",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "argument 'ids': 'list' object cannot be interpreted as an integer",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[98], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnext_batch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_ids_list\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/site-packages/transformers/models/codegen/tokenization_codegen_fast.py:219\u001b[0m, in \u001b[0;36mCodeGenTokenizerFast.decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, truncate_before_pattern, **kwargs)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 188\u001b[0m token_ids: Union[\u001b[38;5;28mint\u001b[39m, List[\u001b[38;5;28mint\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnp.ndarray\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtf.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 193\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 194\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;124;03m Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;124;03m tokens and clean up tokenization spaces.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124;03m `str`: The decoded sentence.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 219\u001b[0m decoded_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 221\u001b[0m \u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 222\u001b[0m \u001b[43m \u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m truncate_before_pattern \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(truncate_before_pattern) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 227\u001b[0m decoded_text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtruncate(decoded_text, truncate_before_pattern)\n",
+ "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:3496\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m 3493\u001b[0m \u001b[38;5;66;03m# Convert inputs to python lists\u001b[39;00m\n\u001b[1;32m 3494\u001b[0m token_ids \u001b[38;5;241m=\u001b[39m to_py_obj(token_ids)\n\u001b[0;32m-> 3496\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_decode\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3497\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3498\u001b[0m \u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3499\u001b[0m \u001b[43m \u001b[49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclean_up_tokenization_spaces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3500\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3501\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/site-packages/transformers/tokenization_utils_fast.py:549\u001b[0m, in \u001b[0;36mPreTrainedTokenizerFast._decode\u001b[0;34m(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)\u001b[0m\n\u001b[1;32m 547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(token_ids, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 548\u001b[0m token_ids \u001b[38;5;241m=\u001b[39m [token_ids]\n\u001b[0;32m--> 549\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_tokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_special_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 551\u001b[0m clean_up_tokenization_spaces \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 552\u001b[0m clean_up_tokenization_spaces\n\u001b[1;32m 553\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclean_up_tokenization_spaces\n\u001b[1;32m 555\u001b[0m )\n\u001b[1;32m 556\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m clean_up_tokenization_spaces:\n",
+ "\u001b[0;31mTypeError\u001b[0m: argument 'ids': 'list' object cannot be interpreted as an integer"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer.decode(next_batch.input_ids_list[0])"
+ ]
+ },
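+  {
+   "cell_type": "markdown",
+   "id": "decode-shape-note",
+   "metadata": {},
+   "source": [
+    "The `TypeError` above comes from passing a 2-D array to `tokenizer.decode`, which expects a flat sequence of ids; each entry of `input_ids_list` has shape `(1, seq_len)` (see the `(1, 18)` check above). Either index out the row or use `batch_decode`, as the working cell earlier does:\n",
+    "\n",
+    "```python\n",
+    "tokenizer.decode(next_batch.input_ids_list[0][0])         # flatten to 1-D\n",
+    "tokenizer.batch_decode(next_batch.input_ids_list[0])[0]   # or decode as a batch\n",
+    "```"
+   ]
+  },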
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "id": "ce285999-6394-42b5-9c6b-d8e1743d068b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1, 20)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(next_batch.input_ids_list[1].shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "46d64cbf-e67d-4f24-b672-5365153a4781",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n",
+ "2023-08-21 18:14:09 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-21 18:14:33 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deepsparse.engine.Engine:\n",
+ "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+ "\tbatch_size: 1\n",
+ "\tnum_cores: 8\n",
+ "\tnum_streams: 1\n",
+ "\tscheduler: Scheduler.default\n",
+ "\tfraction_of_supported_ops: 1.0\n",
+ "\tcpu_avx_type: avx2\n",
+ "\tcpu_vnni: False\n"
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "dbb071c7-076a-469e-9cfe-a9b9e4108c2d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[9]]\n",
+ "(1, 10)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "a = np.array([np.arange(10)]*2)\n",
+ "b = np.array([np.arange(10)]*1)\n",
+ "\n",
+ "print(b[:,-1:])\n",
+ "print(b.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "59f6d438-ecd4-44a5-acd1-334c408a891e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import deepsparse\n",
+ "import torch\n",
+ "from transformers import AutoTokenizer\n",
+ "from server.text_generation_server.models.deepsparse_causal_lm import DeepSparseCausalLMBatch\n",
+ "from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling, StopSequenceCriteria\n",
+ "\n",
+ "from server.text_generation_server.pb.generate_pb2 import (\n",
+ " Batch, \n",
+ " Request, \n",
+ " NextTokenChooserParameters, \n",
+ " StoppingCriteriaParameters\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "06b86098-120f-4fff-9952-06a217494b31",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "629ffcf7-a648-4a2c-a8b5-1eedc97ffa21",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "next_token_chooser = NextTokenChooser(\n",
+ " watermark=False,\n",
+ " temperature=1.0,\n",
+ " repetition_penalty=0.0,\n",
+ " top_k=None,\n",
+ " top_p=None,\n",
+ " typical_p=None,\n",
+ " do_sample=False,\n",
+ " seed=0,\n",
+ " device=\"cpu\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "765bc684-d0cd-4c0d-bf52-33a90def89ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "stopping_criteria = StoppingCriteria(\n",
+ " eos_token_id=tokenizer.eos_token_id,\n",
+ " stop_sequence_criterias=[],\n",
+ " max_new_tokens=20,\n",
+ " ignore_eos_token=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15489a78-44a0-412a-8a73-13b8552e6ca6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "08d015d8-d9fc-45a7-9d4a-c674c994084a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "request_idx = 0\n",
+ "\n",
+ "max_new_tokens = 64\n",
+ "\n",
+ "parameters = NextTokenChooserParameters(\n",
+ " watermark=False,\n",
+ " temperature=1.0,\n",
+ " repetition_penalty=0.0,\n",
+ " do_sample=False,\n",
+ " typical_p=1.0,\n",
+ " top_k = 0,\n",
+ " top_p = 1.0,\n",
+ ")\n",
+ "\n",
+ "stopping_parameters = StoppingCriteriaParameters(\n",
+ " max_new_tokens=max_new_tokens\n",
+ ")\n",
+ "\n",
+ "def make_n_requests(n=1):\n",
+ " requests = []\n",
+ " for i in range(n):\n",
+ " request = Request(\n",
+    "            id=request_idx + i,\n",
+ " inputs=sequence,\n",
+ " truncate=False,\n",
+ " parameters=parameters,\n",
+ " stopping_parameters=stopping_parameters,\n",
+ " prefill_logprobs=False\n",
+ " )\n",
+ " requests.append(request)\n",
+ " return requests\n",
+ "\n",
+ "batch_size = 2\n",
+ "requests = make_n_requests(n=batch_size)\n",
+ "\n",
+ "batch = Batch(\n",
+ " id = 0,\n",
+ " requests = requests,\n",
+ " size=len(requests),\n",
+ ")\n",
+ "\n",
+ "ds_batch = DeepSparseCausalLMBatch.from_pb(\n",
+ " pb=batch, \n",
+ " tokenizer=tokenizer, \n",
+ " dtype=torch.float32,\n",
+ " device=\"cpu\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5873e4a-3c60-4764-9a78-85003bf4516f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"True\"\n",
+ "os.environ[\"WAND_OPT_FLAGS\"] = \"default,~pyramids\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4160e9fa-875b-4cb5-9284-d98fbda1c53f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from server.text_generation_server.models.deepsparse_model import DeepSparseDecoderModel, DeepSparsePastKeyValues\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "model_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n",
+ "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "95f56f49-8dd9-4281-a37c-74011b4fdfd9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds_decoder_model = DeepSparseDecoderModel(\n",
+ " onnx_file_path = onnx_path,\n",
+ " sequence_length = 128,\n",
+ " multitoken_length = 4,\n",
+ " # singletoken_engine = ds_decoder_model.singletoken_engine,\n",
+ " # multitoken_engine = ds_decoder_model.multitoken_engine\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f780b506-7a92-4b52-83a9-424d4337b0dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from deepsparse import Pipeline\n",
+ "pipeline = Pipeline.create(\n",
+ " task=\"text-generation\", \n",
+ " model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+ " use_deepsparse_cache=False,\n",
+ " prompt_processing_sequence_length=4,\n",
+ " max_generated_tokens=64,\n",
+ " sequence_length=128\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4abe7e2-98e4-4b5b-b2af-8c6037e71ba4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n fib(n):\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ff677bb4-e3dc-4201-bcb7-6b28da1cbf9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "def sample_token(logits):\n",
+ " assert(logits.shape[0] == 1)\n",
+ " return np.argmax(logits[0,-1,:])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92ead309-995b-4d96-9974-012be3fc46bc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"testing DeepSparseDecoderModel:\\n\")\n",
+ "\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "past_key_values = DeepSparsePastKeyValues()\n",
+ "logits, past_key_values = ds_decoder_model.prefill(tokens, past_key_values)\n",
+ "tokens.append(sample_token(logits))\n",
+ "\n",
+ "while len(tokens) < 64:\n",
+ " logits, past_key_values = ds_decoder_model.decode(tokens, past_key_values)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c12819c0-0d74-4e68-9620-43f4ca9a69ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "multitoken_length = 4\n",
+ "\n",
+ "def prefill_pipeline(pipeline, tokens):\n",
+ " num_tokens_processed = 0\n",
+ " for engine_inputs in pipeline.engine_inputs_for_prefill(tokens):\n",
+ " _, logits = pipeline.multitoken_engine(engine_inputs)\n",
+ " num_tokens_processed += multitoken_length\n",
+ "\n",
+ " if num_tokens_processed > 0:\n",
+ " pipeline.engine.transfer_cache_state(cache=pipeline.multitoken_engine.kv_cache)\n",
+ "\n",
+ " run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]\n",
+ " for token in tokens[num_tokens_processed:]:\n",
+ " run_tokens.append(token)\n",
+ " new_token, logits = pipeline.autoregressive_inference(run_tokens)\n",
+ " return logits\n",
+ " \n",
+ "pipeline._reset_engines_cache()\n",
+ "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+ "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+ "\n",
+ "logits = prefill_pipeline(pipeline, tokens)\n",
+ "tokens.append(sample_token(logits))\n",
+ "\n",
+ "while len(tokens) < 64:\n",
+ " _, logits = pipeline.autoregressive_inference(tokens)\n",
+ " tokens.append(sample_token(logits))\n",
+ "\n",
+ "print(pipeline.tokenizer.decode(tokens))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3098f6f5-e745-4b08-be11-eaf8aa03f858",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"{sequence}{pipeline(sequences=sequence).sequences[0]}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.17"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/server/deepsparse/deepsparse_causal_lm.py b/server/deepsparse/deepsparse_causal_lm.py
new file mode 100644
index 00000000..bf182395
--- /dev/null
+++ b/server/deepsparse/deepsparse_causal_lm.py
@@ -0,0 +1,251 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import List, Dict, Optional, Tuple
+
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from server.deepsparse.deepsparse_model import (
+ DeepSparsePastKeyValues, DeepSparseDecoderModel
+)
+from server.deepsparse.deepsparse_requests import (
+ Request, Batch, CachedBatch, Generation
+)
+
+DEEPSPARSE_SEQUENCE_LENGTH = 128
+DEEPSPARSE_MULTITOKEN_LENGTH = 4
+
+@dataclass
+class DeepSparseCausalLMBatch:
+ batch_id: int
+ requests: List[Request]
+ requests_idx_mapping: Dict[int,int]
+ input_ids_list: List[np.ndarray]
+ past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
+
+ @classmethod
+ def from_batch(
+ cls,
+ batch: Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> "DeepSparseCausalLMBatch":
+
+ # parse batch
+ requests_idx_mapping = {}
+ input_ids_list = []
+
+ # setup tokenizer for deepsparse left padding
+ tokenizer.padding_side = "left"
+ if not tokenizer.pad_token:
+ tokenizer.pad_token = tokenizer.eos_token
+ padding, truncation = "longest", False
+
+ # loop through items in the batch
+ for idx, r in enumerate(batch.requests):
+ requests_idx_mapping[r.id] = idx
+
+ # setup inputs_ids, past_key_values
+ tokenized_inputs = tokenizer(
+ r.prompt,
+ return_tensors="np",
+ padding=padding,
+ truncation=truncation,
+ return_token_type_ids=False,
+ max_length=DEEPSPARSE_SEQUENCE_LENGTH
+ )
+ input_ids_list.append(tokenized_inputs["input_ids"])
+
+ return cls(
+ batch_id=batch.id,
+ requests=batch.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids_list=input_ids_list,
+ past_key_values_list=[None] * len(batch.requests),
+ )
+
+ def to_batch(self) -> CachedBatch:
+ return CachedBatch(
+ batch_id = self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ )
+
+ # length of the batch
+ def __len__(self):
+ return len(self.requests)
+
+ # pass list of request ids, returns batch with only those request ids
+ def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
+ assert(len(request_ids) > 0)
+
+ requests_idx_mapping = {}
+ requests = []
+ input_ids_list = []
+ past_key_values_list = []
+
+ # loop through requests, keep ones that should remain
+ for new_idx, request_id in enumerate(request_ids):
+            assert request_id in self.requests_idx_mapping, "all request ids must be in the batch"
+
+ requests_idx_mapping[request_id] = new_idx
+
+ old_idx = self.requests_idx_mapping[request_id]
+ requests.append(self.requests[old_idx])
+ input_ids_list.append(self.input_ids_list[old_idx])
+ past_key_values_list.append(self.past_key_values_list[old_idx])
+
+ # update batch state
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids_list = input_ids_list
+ self.past_key_values_list = past_key_values_list
+
+ return self
+
+ # combine two batches into one
+ @classmethod
+ def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
+ assert len(batches) > 1, "must have more than 1 batch to concatenate"
+
+ requests_idx_mapping = {}
+ requests = []
+ input_ids_list = []
+ past_key_values_list = []
+
+ start_index = 0
+ for i, batch in enumerate(batches):
+ assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
+
+ # concatenate request, input_ids, and past_key_values lists
+ requests.extend(batch.requests)
+ input_ids_list.extend(batch.input_ids_list)
+ past_key_values_list.extend(batch.past_key_values_list)
+
+ # merge the request_id to index mapping
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ start_index += len(batch)
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids_list=input_ids_list,
+ past_key_values_list=past_key_values_list
+ )
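+
+    # Example (hypothetical ids): concatenating a 2-request batch with a
+    # 1-request batch keeps the first batch's mapping unchanged and shifts
+    # the second batch's indices by start_index, so its request maps to
+    # index 2 in the merged lists.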
+
+class DeepSparseCausalLM:
+ def __init__(
+ self,
+ model_path: str,
+ tokenizer_path: str,
+ ):
+ # setup tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.tokenizer.padding_side = "left"
+ if not self.tokenizer.pad_token:
+ assert self.tokenizer.eos_token
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ # setup model
+ self.model = DeepSparseDecoderModel(
+ onnx_file_path = model_path,
+ sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
+ multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
+ )
+
+ # TODO (@rsnm2): switch to NextTokenChooser
+ def sample_token(
+ self,
+ logits: np.ndarray
+ ):
+ assert(logits.shape[0] == 1) # assert b=1 for now
+ return np.argmax(logits[0,-1,:]) # grab logits for the last item in the sequence
+
+ # TODO (@rsnm2): switch to StoppingCriteria
+ def should_stop(
+ self,
+ num_tokens_processed: int,
+ generated_token_id: int
+ ):
+ if num_tokens_processed >= self.model.sequence_length:
+ return True
+ if generated_token_id == self.tokenizer.eos_token_id:
+ return True
+ return False
+
+ def generate_token(
+ self,
+ batch: DeepSparseCausalLMBatch,
+    ) -> Tuple[List[Generation], Optional[DeepSparseCausalLMBatch]]:
+
+ generations: List[Generation] = []
+ all_stopped = True
+
+ # if we supported continuous batching, we would do batched inference here
+ # logits, past_key_values = self.model(batch)
+
+ # for each member of the batch:
+ # a) run inference
+ # b) sample and check stopping criteria
+ # c) create generation + update batch
+ for i, (
+ request,
+ input_ids,
+ past_key_values,
+ ) in enumerate(zip(
+ batch.requests,
+ batch.input_ids_list,
+ batch.past_key_values_list
+ )):
+
+ # run inference
+ logits, past_key_values = self.model(input_ids, past_key_values)
+
+ # sample token
+ # todo: simple for now --- should use NextTokenChooser
+ generated_token_id = self.sample_token(logits)
+
+ # check stopping criteria
+ # todo: simple for now --- should use StoppingCriteria
+ assert len(input_ids.shape) == 2
+ assert input_ids.shape[0] == 1
+
+ stop = self.should_stop(
+ num_tokens_processed=input_ids.shape[1] + 1,
+ generated_token_id = generated_token_id
+ )
+
+ # if not stopped, convert token id to text
+ generated_text = None
+ if not stop:
+ all_stopped = False
+ generated_text = self.tokenizer.decode(
+ generated_token_id,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False
+ )
+ generations.append(Generation(
+ request_id=request.id,
+ generated_text=generated_text
+ ))
+
+            # update the batch entries: np.append copies rather than
+            # appending in place, so we reassign input_ids_list[i]
+ assert len(batch.input_ids_list[i].shape) == 2
+ assert batch.input_ids_list[i].shape[0] == 1
+ batch.input_ids_list[i] = np.append(
+ batch.input_ids_list[i],
+ np.array([[generated_token_id]]),
+ axis=1
+ )
+ batch.past_key_values_list[i] = past_key_values
+
+ # if all elements of the batch are done, return generation + null for batch
+ if all_stopped:
+ return generations, None
+
+ # return generation + updated batch
+ return generations, batch
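+
+# A minimal, hypothetical driver for local smoke testing (the paths below are
+# placeholders, not part of this PR): tokenize one prompt into a batch, then
+# call generate_token() until every request hits a stopping criterion
+# (should_stop caps generation at DEEPSPARSE_SEQUENCE_LENGTH or eos).
+if __name__ == "__main__":
+    model = DeepSparseCausalLM(
+        model_path="/path/to/model.onnx",      # assumption: local ONNX export
+        tokenizer_path="/path/to/deployment",  # assumption: tokenizer directory
+    )
+    batch = Batch(
+        id=0,
+        requests=[Request(id=0, prompt="def fib(n):", max_generated_tokens=64)]
+    )
+    ds_batch = DeepSparseCausalLMBatch.from_batch(batch, model.tokenizer)
+
+    generated = ""
+    while ds_batch is not None:
+        generations, ds_batch = model.generate_token(ds_batch)
+        for generation in generations:
+            if generation.generated_text is not None:
+                generated += generation.generated_text
+    print(generated)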
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_model.py b/server/deepsparse/deepsparse_model.py
new file mode 100644
index 00000000..9b0082bc
--- /dev/null
+++ b/server/deepsparse/deepsparse_model.py
@@ -0,0 +1,241 @@
+import os
+os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
+
+import numpy as np
+from typing import Optional, List, Dict, Tuple
+
+from deepsparse import Context
+from deepsparse.engine import LIB
+from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
+from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask
+
+PAST_KEY_VALUES_NAME = "past_key_values"
+
+class DeepSparsePastKeyValues:
+ def __init__(self):
+ prev_num_tokens = 0
+ num_frozen_tokens = 1
+ self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
+
+class DeepSparseDecoderEngine:
+ def __init__ (
+ self,
+ onnx_file_path: str,
+ sequence_length: int = 1024,
+ input_ids_length: int = 1,
+ engine_context: Optional[Context] = None,
+ ):
+
+ # setup ONNX graph
+ onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
+ onnx_file_path=onnx_file_path,
+ batch_size=1,
+ sequence_length=sequence_length,
+ input_ids_length=input_ids_length,
+ )
+
+ # compile engine
+ self.engine = create_engine(
+ onnx_file_path=onnx_file_path,
+ engine_type=DEEPSPARSE_ENGINE,
+ engine_args={"cached_outputs": cached_outputs},
+ context=engine_context,
+ )
+ print(self.engine)
+
+        # save utilities
+ self.past_key_value_dtype = data_type
+ self.onnx_inputs = self.engine.input_names
+ self.empty_past_key_values = self.make_empty_past_key_values()
+
+ # forward function
+ def __call__(
+ self,
+ engine_inputs: Dict[str, np.ndarray],
+ past_key_values: DeepSparsePastKeyValues,
+ val_inputs: bool = True
+ ):
+        # format inputs into a list; placeholder zeros stand in for the
+        # cached past_key_values inputs
+ inputs = [
+ self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
+ else engine_inputs[name] for name in self.engine.input_names
+ ]
+
+ # validate inputs formatted correctly
+ if val_inputs:
+ self.engine._validate_inputs(inputs)
+
+ # run inference, updates past_key_values internally
+ output = self.engine._eng_net.execute_list_out(
+ inputs,
+ past_key_values.internal_past_key_values
+ )
+ logits = output[0]
+ return logits, past_key_values
+
+ # empty past kvs (dummy values to be passed around)
+ def make_empty_past_key_values(self):
+ past_key_values = {}
+ for idx, name in enumerate(self.onnx_inputs):
+ if name.startswith(PAST_KEY_VALUES_NAME):
+ past_key_values[name] = np.zeros(
+ self.engine.input_shapes[idx],
+ dtype=self.past_key_value_dtype
+ )
+
+ return past_key_values
+
+class DeepSparseDecoderModel:
+ def __init__(
+ self,
+ onnx_file_path: str,
+ sequence_length: int = 1024,
+ multitoken_length: int = 16,
+ engine_context: Optional[Context] = None,
+ ):
+ self.sequence_length = sequence_length
+ self.multitoken_length = multitoken_length
+
+ # compile decode engine
+ self.singletoken_engine = DeepSparseDecoderEngine(
+ onnx_file_path=onnx_file_path,
+ engine_context=engine_context,
+ sequence_length=sequence_length,
+ input_ids_length=1,
+ )
+
+ # compile prefill engine
+ self.multitoken_engine = DeepSparseDecoderEngine(
+ onnx_file_path=onnx_file_path,
+ engine_context=engine_context,
+ sequence_length=sequence_length,
+ input_ids_length=self.multitoken_length,
+ )
+
+ assert "input_ids" in self.singletoken_engine.onnx_inputs
+ assert "attention_mask" in self.singletoken_engine.onnx_inputs
+ assert "causal_mask" in self.singletoken_engine.onnx_inputs
+ assert "positions" in self.singletoken_engine.onnx_inputs
+
+ def engine_inputs_for_prefill(
+ self,
+ input_ids: np.ndarray,
+ ):
+ # split batch into N token_batches
+ num_batches = input_ids.shape[1] // self.multitoken_length
+ token_batches = [
+ input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
+ for i in range(0, num_batches)
+ ]
+
+ # format inputs for each of the N token_batches
+ for idx, token_batch in enumerate(token_batches):
+ num_processed_tokens = self.multitoken_length * idx
+
+ engine_inputs = {}
+ engine_inputs["input_ids"] = token_batch
+
+ # make attention mask from the right
+ engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
+ engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
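+            # e.g. with sequence_length=8 and multitoken_length=4 (hypothetical
+            # sizes), the idx=0 chunk sees attention_mask [[0,0,0,0,1,1,1,1]];
+            # each later chunk extends the run of ones further to the left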
+
+ # make positions (building from the right)
+            # TODO: handle the case where multitoken_length == 1
+ assert self.multitoken_length > 1
+ engine_inputs["positions"] = np.arange(
+ num_processed_tokens, num_processed_tokens + self.multitoken_length
+ ).reshape(1, -1).astype(np.int64)
+
+ # make causal mask (building from the right)
+ engine_inputs["causal_mask"] = create_causal_mask(
+ input_ids=engine_inputs["input_ids"],
+ attention_mask=engine_inputs["attention_mask"]
+ )
+ yield engine_inputs
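+
+    # Worked example (hypothetical sizes): with multitoken_length=4, a 10-token
+    # prompt yields two 4-token chunks from the generator above (tokens 0-3 and
+    # 4-7); prefill() pushes the remaining 2 tokens through the single-token
+    # engine one at a time.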
+
+ def engine_inputs_for_decode(
+ self,
+ input_ids: np.ndarray,
+ ):
+ engine_inputs = {}
+ engine_inputs["input_ids"] = input_ids[:,-1:]
+ engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
+ engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
+ engine_inputs["causal_mask"] = create_causal_mask(
+ engine_inputs["input_ids"],
+ engine_inputs["attention_mask"]
+ )
+ engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
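+        # e.g. for input_ids of shape (1, 6) (hypothetical): only the 6th token
+        # id is passed to the engine (positions == [[5]]), with six ones at the
+        # right of the mask; earlier tokens live in the engine's internal KV cache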
+
+ return engine_inputs
+
+ def decode(
+ self,
+ input_ids: np.ndarray,
+ past_key_values: DeepSparsePastKeyValues
+    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
+
+ # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+ assert len(input_ids.shape) == 2
+ assert input_ids.shape[0] == 1
+ assert input_ids.shape[1] < self.sequence_length
+
+ engine_inputs = self.engine_inputs_for_decode(input_ids)
+ logits, past_key_values = self.singletoken_engine(
+ engine_inputs,
+ past_key_values
+ )
+
+ return logits, past_key_values
+
+ def prefill(
+ self,
+ input_ids: np.ndarray,
+    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
+
+ # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+ assert len(input_ids.shape) == 2
+ assert input_ids.shape[0] == 1
+ assert input_ids.shape[1] < self.sequence_length
+
+ tokens_processed = 0
+
+ # setup empty past key values
+ past_key_values = DeepSparsePastKeyValues()
+
+ # loop through chunks, run inference w/ multitoken engine
+ for engine_inputs in self.engine_inputs_for_prefill(input_ids):
+ logits, past_key_values = self.multitoken_engine(
+ engine_inputs,
+ past_key_values
+ )
+ tokens_processed += self.multitoken_length
+
+ # if anything left over, run inference w/ singletoken engine
+ while tokens_processed < input_ids.shape[1]:
+ logits, past_key_values = self.decode(
+ input_ids=input_ids[:,:tokens_processed+1],
+ past_key_values=past_key_values
+ )
+ tokens_processed += 1
+
+ return logits, past_key_values
+
+ def forward(
+ self,
+ input_ids: np.ndarray,
+ past_key_values: Optional[DeepSparsePastKeyValues] = None,
+ ):
+ if past_key_values is None:
+ return self.prefill(input_ids)
+ else:
+ return self.decode(input_ids, past_key_values)
+
+ def __call__(
+ self,
+ input_ids: np.ndarray,
+ past_key_values: Optional[DeepSparsePastKeyValues] = None,
+ ):
+ return self.forward(input_ids, past_key_values)
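+
+# A short, hypothetical smoke test (path and sizes are placeholders): prefill
+# the prompt once, then greedily decode a few tokens. The argmax over the
+# final position mirrors the sampling used by DeepSparseCausalLM.
+if __name__ == "__main__":
+    model = DeepSparseDecoderModel(
+        onnx_file_path="/path/to/model.onnx",  # assumption: local ONNX export
+        sequence_length=128,
+        multitoken_length=4,
+    )
+    input_ids = np.array([[1, 2, 3, 4, 5]])     # stand-in token ids
+    logits, past_key_values = model(input_ids)  # prefill path (no kv passed)
+    while input_ids.shape[1] < 16:
+        next_id = int(np.argmax(logits[0, -1, :]))
+        input_ids = np.append(input_ids, [[next_id]], axis=1)
+        logits, past_key_values = model(input_ids, past_key_values)  # decode path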
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py
new file mode 100644
index 00000000..438d33f7
--- /dev/null
+++ b/server/deepsparse/deepsparse_queue.py
@@ -0,0 +1,58 @@
+from typing import Deque, Optional, Tuple, Dict
+from collections import deque
+from threading import Condition
+from server.deepsparse.deepsparse_requests import Batch, Request
+
+class GenerateRequest:
+ def __init__(
+ self,
+ prompt: str,
+ max_generated_tokens: int
+ ):
+ self.prompt = prompt
+ self.generation = prompt
+ self.max_generated_tokens = max_generated_tokens
+ self.cv = Condition()
+ self.is_stopped = False
+
+# todo: implement logic for maximum memory usage
+class DeepSparseQueue:
+ def __init__(self):
+ self.next_request_id: int = 0
+ self.next_batch_id: int = 0
+ self.queue: Deque[GenerateRequest] = deque()
+
+ def append(self, generate_request: GenerateRequest):
+ self.queue.append(generate_request)
+
+ def is_empty(self):
+ return len(self.queue) == 0
+
+ # (todo): enable multiple prefill requests in a batch
+ def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+ if self.is_empty():
+ return None
+
+ # pop first generate_request in the queue
+ generate_request = self.queue.popleft()
+ generate_requests = {
+ self.next_request_id: generate_request
+ }
+
+ # format into request
+ request = Request(
+ id=self.next_request_id,
+ prompt=generate_request.prompt,
+ max_generated_tokens=generate_request.max_generated_tokens
+ )
+ self.next_request_id += 1
+
+ # format into batch
+ batch = Batch(
+ id = self.next_batch_id,
+ requests=[request]
+ )
+ self.next_batch_id += 1
+
+ # return batch, generate_requests
+ return (batch, generate_requests)
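+
+# Minimal usage sketch (the prompt is illustrative): one queued request becomes
+# a single-request Batch plus the request-id -> GenerateRequest map that the
+# router tracks for notification.
+if __name__ == "__main__":
+    queue = DeepSparseQueue()
+    queue.append(GenerateRequest(prompt="def fib(n):", max_generated_tokens=64))
+    batch, generate_requests = queue.next_batch()
+    assert batch.requests[0].prompt == "def fib(n):"
+    assert queue.is_empty()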
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_requests.py b/server/deepsparse/deepsparse_requests.py
new file mode 100644
index 00000000..430f3473
--- /dev/null
+++ b/server/deepsparse/deepsparse_requests.py
@@ -0,0 +1,39 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+@dataclass
+class Request:
+ id: int
+ prompt: str
+ max_generated_tokens: int
+
+@dataclass
+class Batch:
+ id: int
+ requests: List[Request]
+
+@dataclass
+class CachedBatch:
+ batch_id: int
+ request_ids: List[int]
+
+ def __len__(self):
+ return len(self.request_ids)
+
+@dataclass
+class Generation:
+ request_id: int
+ generated_text: Optional[str]
+
+@dataclass
+class PrefillRequest:
+ batch: Batch
+
+@dataclass
+class DecodeRequest:
+ batches: List[CachedBatch]
+
+@dataclass
+class FilterBatchRequest:
+ batch_id: int
+ request_ids: List[int]
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_router.py b/server/deepsparse/deepsparse_router.py
new file mode 100644
index 00000000..647d7f3f
--- /dev/null
+++ b/server/deepsparse/deepsparse_router.py
@@ -0,0 +1,184 @@
+from threading import Condition
+from typing import List, Dict, Optional
+
+from server.deepsparse.deepsparse_service import DeepSparseService
+from server.deepsparse.deepsparse_requests import (
+ CachedBatch, Batch, Generation,
+ PrefillRequest, DecodeRequest, FilterBatchRequest,
+)
+from server.deepsparse.deepsparse_queue import (
+ DeepSparseQueue, GenerateRequest
+)
+
+class DeepSparseRouter:
+ def __init__(self, service: DeepSparseService):
+ self.service: DeepSparseService = service
+ self.queue: DeepSparseQueue = DeepSparseQueue()
+ self.cv: Condition = Condition()
+
+ def generate(self, prompt:str) -> str:
+ generate_request = GenerateRequest(
+ prompt=prompt,
+ max_generated_tokens=100
+ )
+
+ with self.cv:
+ # print("router: acquired cv")
+ self.queue.append(generate_request)
+ self.cv.notify()
+
+ if prompt == "stop":
+ return "stop"
+
+ with generate_request.cv:
+ # print("generate_request: acquired cv")
+ if not generate_request.is_stopped:
+ # print("generate_request: waiting")
+ generate_request.cv.wait()
+
+ # print("generate_request: done waiting")
+
+ return generate_request.generation
+
+ def prefill(
+ self,
+ batch: Batch,
+ generate_requests: Dict[int,GenerateRequest]
+ ) -> Optional[CachedBatch]:
+ # print("prefill")
+ generation, next_batch = self.service.Prefill(
+ PrefillRequest(batch=batch)
+ )
+
+ self.filter_notify_update([generation], generate_requests)
+
+ return self.filter_batch(
+ batch=next_batch,
+ generate_requests=generate_requests
+ )
+
+ def decode(
+ self,
+ batches: List[CachedBatch],
+ generate_requests: Dict[int,GenerateRequest]
+ ) -> Optional[CachedBatch]:
+ # print("decode")
+ generations, next_batch = self.service.Decode(
+ DecodeRequest(batches=batches)
+ )
+
+ self.filter_notify_update(generations, generate_requests)
+
+ return self.filter_batch(
+ batch=next_batch,
+ generate_requests=generate_requests
+ )
+
+ def filter_notify_update(
+ self,
+ generations: List[Generation],
+ generate_requests: Dict[int, GenerateRequest]
+ ):
+ # print("filter_notify_update")
+ for generation in generations:
+ request_id = generation.request_id
+
+ # if we hit a stopping criteria
+ if generation.generated_text is None:
+ # remove from active requests and notify
+ stopped_generate_request = generate_requests.pop(request_id)
+ with stopped_generate_request.cv:
+ stopped_generate_request.is_stopped = True
+ stopped_generate_request.cv.notify()
+
+ # otherwise, update generation
+ else:
+ generate_requests[request_id].generation += generation.generated_text
+
+ def filter_batch(
+ self,
+ batch: Optional[CachedBatch],
+ generate_requests: Dict[int, GenerateRequest]
+ ) -> Optional[CachedBatch]:
+ # print("filter_batch")
+
+ # batch is already done
+ if batch is None:
+ return batch
+
+ # no need to filter
+ if len(batch) == len(generate_requests):
+ return batch
+
+ # retain only requests that are still in active generation requests
+ batch.request_ids = [id for id in batch.request_ids if id in generate_requests]
+
+ # if all requests complete, clear cache and return None
+ if len(batch) == 0:
+ self.service.ClearCache()
+ return None
+
+ # otherwise call the filter batch service
+ return self.service.FilterBatch(
+ FilterBatchRequest(
+ batch_id=batch.batch_id,
+ request_ids=batch.request_ids,
+ )
+ )
+
+def batching_task(
+ router: DeepSparseRouter
+) -> bool:
+    # infinite loop
+ while True:
+ # block while the queue is empty
+ # print("batching_task: about to acquire cv")
+ with router.cv:
+ while router.queue.is_empty():
+ # print(f"batching_task cv: waiting")
+ router.cv.wait()
+ # print(f"batching_task: done waiting")
+
+ # loop until all batches in the queue are processed
+ next_batch = router.queue.next_batch()
+ while next_batch is not None:
+ batch, generate_requests = next_batch
+
+ # hack to break out of the cycle
+ if batch.requests[0].prompt == "stop":
+ assert router.queue.is_empty()
+ assert len(router.service.cache) == 0
+ return True
+
+ cached_batch = router.prefill(
+ batch=batch,
+ generate_requests=generate_requests
+ )
+
+            # loop until we do not receive any cached batch from the service
+            # (i.e. until all requests have met their stopping criteria)
+ while cached_batch is not None:
+ # print(f"batch_size = {len(cached_batch)}")
+ batches = [cached_batch]
+
+ # try to get a new batch and run prefill on this batch
+ next_batch = router.queue.next_batch()
+ if next_batch is not None:
+ new_batch, new_generate_requests = next_batch
+ new_cached_batch = router.prefill(
+ batch=new_batch,
+ generate_requests=new_generate_requests
+ )
+
+ if new_cached_batch is not None:
+ batches.append(new_cached_batch)
+ assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
+ generate_requests.update(new_generate_requests)
+
+ # run decode
+ cached_batch = router.decode(
+ batches=batches,
+ generate_requests=generate_requests
+ )
+
+ next_batch = router.queue.next_batch()
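+
+# Hypothetical end-to-end wiring (paths are placeholders): batching_task runs
+# on a worker thread while generate() blocks the caller; the sentinel "stop"
+# prompt shuts the worker down.
+if __name__ == "__main__":
+    from threading import Thread
+    from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM
+    from server.deepsparse.deepsparse_service import DeepSparseService
+
+    model = DeepSparseCausalLM(
+        model_path="/path/to/model.onnx",      # assumption: local ONNX export
+        tokenizer_path="/path/to/deployment",  # assumption: tokenizer directory
+    )
+    router = DeepSparseRouter(service=DeepSparseService(model))
+    worker = Thread(target=batching_task, args=[router])
+    worker.start()
+
+    print(router.generate("def fib(n):"))
+    router.generate("stop")  # sentinel: batching_task returns and the thread exits
+    worker.join()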
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_service.py b/server/deepsparse/deepsparse_service.py
new file mode 100644
index 00000000..f4eae070
--- /dev/null
+++ b/server/deepsparse/deepsparse_service.py
@@ -0,0 +1,93 @@
+from typing import Optional, Dict, List, Tuple
+from server.deepsparse.deepsparse_causal_lm import (
+ DeepSparseCausalLM, DeepSparseCausalLMBatch
+)
+from server.deepsparse.deepsparse_requests import (
+ PrefillRequest, DecodeRequest, FilterBatchRequest,
+ Generation, CachedBatch
+)
+
+class Cache:
+ def __init__(self):
+ self.cache: Dict[int, DeepSparseCausalLMBatch] = {}
+
+ def pop(self, batch_id: int) -> Optional[DeepSparseCausalLMBatch]:
+ return self.cache.pop(batch_id, None)
+
+ def set(self, entry: DeepSparseCausalLMBatch):
+ if entry is not None:
+ self.cache[entry.batch_id] = entry
+
+ def delete(self, batch_id: int):
+ batch = self.pop(batch_id)
+ if batch is not None:
+ del batch
+
+ def clear(self):
+ keys = list(self.cache.keys())
+ for k in keys:
+ self.delete(k)
+
+ def __len__(self):
+ return len(self.cache.keys())
+
+class DeepSparseService:
+ def __init__(
+ self,
+ model: DeepSparseCausalLM
+ ):
+ self.model = model
+ self.cache = Cache()
+
+ def ClearCache(self):
+ self.cache.clear()
+
+ def FilterBatch(
+ self,
+ request: FilterBatchRequest
+ ) -> CachedBatch:
+
+ ds_batch = self.cache.pop(request.batch_id)
+        assert ds_batch is not None, f"Batch ID {request.batch_id} not found in cache."
+ filtered_batch = ds_batch.filter(request.request_ids)
+ self.cache.set(filtered_batch)
+
+ return filtered_batch.to_batch()
+
+ def Prefill(
+ self,
+ request: PrefillRequest
+    ) -> Tuple[Generation, Optional[CachedBatch]]:
+
+ ds_batch = DeepSparseCausalLMBatch.from_batch(
+ batch=request.batch,
+ tokenizer=self.model.tokenizer
+ )
+
+ generations, next_ds_batch = self.model.generate_token(ds_batch)
+ assert len(generations) == 1
+ self.cache.set(next_ds_batch)
+
+        return generations[0], (next_ds_batch.to_batch() if next_ds_batch else None)
+
+ def Decode(
+ self,
+ request: DecodeRequest
+    ) -> Tuple[List[Generation], Optional[CachedBatch]]:
+ assert len(request.batches) != 0, "Must provide at least one batch"
+
+ ds_batches = []
+ for batch in request.batches:
+ ds_batch = self.cache.pop(batch.batch_id)
+            assert ds_batch is not None, f"Batch ID {batch.batch_id} not found in cache."
+ ds_batches.append(ds_batch)
+
+ if len(ds_batches) > 1:
+ ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches)
+ else:
+ ds_batch = ds_batches[0]
+
+ generations, next_ds_batch = self.model.generate_token(ds_batch)
+ self.cache.set(next_ds_batch)
+
+ return generations, (next_ds_batch.to_batch() if next_ds_batch else None)
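+
+# Sketch of the expected call pattern (ids are illustrative): the router calls
+# Prefill once per new batch, Decode repeatedly over all live cached batches,
+# and FilterBatch / ClearCache as requests finish:
+#
+#   generation, cached = service.Prefill(PrefillRequest(batch=batch))
+#   generations, cached = service.Decode(DecodeRequest(batches=[cached]))
+#   cached = service.FilterBatch(FilterBatchRequest(
+#       batch_id=cached.batch_id, request_ids=[...]))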
\ No newline at end of file