From 06fc85f93cb9326b1fd84bd87663535d6d1d228f Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 24 Aug 2023 17:49:28 +0000 Subject: [PATCH] refactored router code to use queues (which can be treated like a stream for IPC --- server-dev.ipynb | 386 +++++++++++++++++++++- server/deepsparse/deepsparse_causal_lm.py | 251 -------------- server/deepsparse/deepsparse_model.py | 241 -------------- server/deepsparse/deepsparse_queue.py | 58 ---- server/deepsparse/deepsparse_requests.py | 39 --- server/deepsparse/deepsparse_router.py | 184 ----------- server/deepsparse/deepsparse_service.py | 93 ------ 7 files changed, 375 insertions(+), 877 deletions(-) delete mode 100644 server/deepsparse/deepsparse_causal_lm.py delete mode 100644 server/deepsparse/deepsparse_model.py delete mode 100644 server/deepsparse/deepsparse_queue.py delete mode 100644 server/deepsparse/deepsparse_requests.py delete mode 100644 server/deepsparse/deepsparse_router.py delete mode 100644 server/deepsparse/deepsparse_service.py diff --git a/server-dev.ipynb b/server-dev.ipynb index 80a5e83b..24b11e9d 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d", "metadata": {}, "outputs": [ @@ -29,19 +29,19 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n" + "2023-08-24 17:46:45 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n" ] } ], "source": [ - "from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n", - "from server.deepsparse.deepsparse_service import DeepSparseService\n", - "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM" + "from server.deepsparse.router import DeepSparseRouter, batching_task\n", + "from server.deepsparse.service.service import DeepSparseService\n", + "from server.deepsparse.service.causal_lm import DeepSparseCausalLM" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "78acf813-3688-483d-9148-5c0df5d6b8e3", "metadata": {}, "outputs": [ @@ -50,7 +50,7 @@ "output_type": "stream", "text": [ "Using pad_token, but it is not set yet.\n", - "2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "2023-08-24 17:47:02 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. 
version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n" ] }, @@ -73,7 +73,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" + "2023-08-24 17:47:27 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" ] }, { @@ -99,12 +99,376 @@ "model = DeepSparseCausalLM(\n", " tokenizer_path=tokenizer_path,\n", " model_path=onnx_path\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c011f167-f150-4a46-b24d-3074e8564151", + "metadata": {}, + "outputs": [], + "source": [ + "from server.deepsparse.utils import GenerateRequest\n", + "from queue import Queue \n", + "from threading import Thread\n", + "import time\n", "\n", "service = DeepSparseService(model=model)\n", - "router = DeepSparseRouter(service=service)" + "router = DeepSparseRouter(service=service)\n", + "batching_thread = Thread(target=batching_task, args=[router])\n", + "batching_thread.start()" ] }, + { + "cell_type": "code", + "execution_count": 8, + "id": "538ba71e-9562-4325-899b-c67a4cf74075", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "starting request\n", + "starting request\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n-1) + fib(n-2)\n", + "\n", + "# Driver function to test above function\n", + "n = int(input(\"Enter the number: \"))\n", + "print(fib(n))\n", + "\n", + "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", + "\n", + "Write a function for filtering a list of integers to include only positive numbers:\n", + "\n", + "def filter(lst):\n", + " return [x for x in lst if x > 0]\n", + "\n", + "# Test\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1,\n" + ] + } + ], + "source": [ + "prompts = [\n", + " \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n", + " \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n", + " \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n", + " \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n", + " \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n", + "]\n", + "\n", + "def generate_request(prompt):\n", + " print(\"starting request\")\n", + " response_stream = Queue()\n", + " \n", + " g_request = GenerateRequest(\n", + " prompt=prompt,\n", + " max_generated_tokens=100,\n", + " response_stream=response_stream\n", + " )\n", + "\n", + " router.generate(g_request)\n", + "\n", + " str = prompt\n", + " generation = response_stream.get()\n", + " while not generation.stopped:\n", + " str += generation.token\n", + " generation = response_stream.get()\n", + "\n", + " print(str)\n", + "\n", + "generate_threads = [\n", + " 
Thread(target=generate_request, args=[prompt]) for prompt in prompts[:2]\n", + "]\n", + "\n", + "# print(len(generate_threads))\n", + "\n", + "for gt in generate_threads:\n", + " gt.start()\n", + " time.sleep(1)\n", + "\n", + "for gt in generate_threads:\n", + " gt.join()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d9a1c1cc-e1d1-4915-8c3a-c1c9f2aa5147", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m service \u001b[38;5;241m=\u001b[39m DeepSparseService(model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2\u001b[0m router \u001b[38;5;241m=\u001b[39m DeepSparseRouter(service\u001b[38;5;241m=\u001b[39mservice)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mbatching_thread\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/threading.py:1060\u001b[0m, in \u001b[0;36mThread.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot join current thread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1059\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1060\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_tstate_lock\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1062\u001b[0m \u001b[38;5;66;03m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[1;32m 1063\u001b[0m \u001b[38;5;66;03m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_tstate_lock(timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mmax\u001b[39m(timeout, \u001b[38;5;241m0\u001b[39m))\n", + "File \u001b[0;32m~/.conda/envs/dscb/lib/python3.9/threading.py:1080\u001b[0m, in \u001b[0;36mThread._wait_for_tstate_lock\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 1077\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1079\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1080\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43mblock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 1081\u001b[0m lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 1082\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stop()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "service = DeepSparseService(model=model)\n", + "router = DeepSparseRouter(service=service)\n", + "batching_thread.join()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a762fbbb-9711-4bf4-aeb9-3083a1f5b1f8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"a9a556e2-d626-4733-84eb-e7e40f244981", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aff1a7dd-de34-443a-bf37-6ef039bd25c8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "2c50cb8f-bc1b-487d-94e2-90cee1cabdf4", + "metadata": {}, + "outputs": [], + "source": [ + "batch, generate_requests = router.queue.next_batch(block=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "8c048933-7e50-4f6a-b360-41a3ae7ae8c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CachedBatch(batch_id=0, request_ids=[0])\n" + ] + } + ], + "source": [ + "cached_batch = router.prefill(batch, generate_requests)\n", + "print(cached_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "43363249-5256-4259-b465-b99fc4d7f42c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "next_batch = router.queue.next_batch(block=False)\n", + "print(next_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "1ebb411e-c9ff-471a-9cd0-77fcdb2299df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n", + "CachedBatch(batch_id=0, request_ids=[0])\n" + ] + } + ], + "source": [ + "for i in range(20):\n", + " cached_batch = router.decode(batches=[cached_batch], generate_requests=generate_requests)\n", + " print(cached_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "1fb10db6-54bb-4fcb-8662-e12a65fc0fa8", + "metadata": {}, + "outputs": [], + "source": [ + "router.generate(g_request1)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "7bf59798-3325-4171-9a4b-fab283c8194b", + "metadata": {}, + "outputs": [], + "source": [ + "next_batch = router.queue.next_batch(block=False)\n", + "if next_batch is not None:\n", + " new_batch, new_generate_requests = next_batch\n", + "\n", + "new_cached_batch = router.prefill(\n", + " batch=new_batch,\n", + " generate_requests=new_generate_requests\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "9e9e5654-beb0-4fba-8bf6-49824d335925", + "metadata": {}, + "outputs": [], + "source": [ + "batches = [cached_batch]\n", + "batches.append(new_cached_batch)\n", + "generate_requests.update(new_generate_requests)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "764448bb-7841-4910-ab6b-56c3c7a2cde8", + "metadata": {}, + "outputs": [], + 
"source": [ + "for i in range(20):\n", + " cached_batch = router.decode(batches=batches, generate_requests=generate_requests)\n", + " batches = [cached_batch]" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "1536911c-8504-41b3-aa96-cc33024446f6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: GenerateRequest(prompt='Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):', max_generated_tokens=100, response_stream=),\n", + " 1: GenerateRequest(prompt='Write a function for reversing a string:\\n\\ndef reverse_string(s):', max_generated_tokens=100, response_stream=)}" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "generate_requests" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "d1dfb9ed-8bf0-4a38-a7f6-9cd22212e663", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n", + "Generation(request_id=1, token=' ', token_id=50284, stopped=False)\n", + "Generation(request_id=1, token='return', token_id=7783, stopped=False)\n", + "Generation(request_id=1, token=' s', token_id=264, stopped=False)\n", + "Generation(request_id=1, token='[', token_id=58, stopped=False)\n", + "Generation(request_id=1, token='::', token_id=3712, stopped=False)\n", + "Generation(request_id=1, token='-', token_id=12, stopped=False)\n", + "Generation(request_id=1, token='1', token_id=16, stopped=False)\n", + "Generation(request_id=1, token=']', token_id=60, stopped=False)\n", + "Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n", + "Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n", + "Generation(request_id=1, token='#', token_id=2, stopped=False)\n", + "Generation(request_id=1, token=' Test', token_id=6208, stopped=False)\n", + "Generation(request_id=1, token='\\n', token_id=198, stopped=False)\n", + "Generation(request_id=1, token='print', token_id=4798, stopped=False)\n", + "Generation(request_id=1, token='(', token_id=7, stopped=False)\n", + "Generation(request_id=1, token='reverse', token_id=50188, stopped=False)\n", + "Generation(request_id=1, token='_', token_id=62, stopped=False)\n", + "Generation(request_id=1, token='string', token_id=8841, stopped=False)\n", + "Generation(request_id=1, token='(\"', token_id=7203, stopped=False)\n" + ] + } + ], + "source": [ + "for i in range(20):\n", + " generation = response_stream1.get(block=False)\n", + " print(generation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41ead58c-e59b-4a9c-8b35-a1a658f22a7f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b57985b2-75f8-4460-9b77-e531a12bd12c", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 5, @@ -239,7 +603,7 @@ " gt.join()\n", "\n", "\n", - "generate_task(\"stop\")\n", + "generate(\"stop\")\n", "batching_thread.join()" ] }, diff --git a/server/deepsparse/deepsparse_causal_lm.py b/server/deepsparse/deepsparse_causal_lm.py deleted file mode 100644 index bf182395..00000000 --- a/server/deepsparse/deepsparse_causal_lm.py +++ /dev/null @@ -1,251 +0,0 @@ -import numpy as np -from dataclasses import dataclass -from typing import List, Dict, Optional - -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from server.deepsparse.deepsparse_model import ( - 
DeepSparsePastKeyValues, DeepSparseDecoderModel -) -from server.deepsparse.deepsparse_requests import ( - Request, Batch, CachedBatch, Generation -) - -DEEPSPARSE_SEQUENCE_LENGTH = 128 -DEEPSPARSE_MULTITOKEN_LENGTH = 4 - -@dataclass -class DeepSparseCausalLMBatch: - batch_id: int - requests: List[Request] - requests_idx_mapping: Dict[int,int] - input_ids_list: List[np.ndarray] - past_key_values_list: List[Optional[DeepSparsePastKeyValues]] - - @classmethod - def from_batch( - cls, - batch: Batch, - tokenizer: PreTrainedTokenizerBase, - ) -> "DeepSparseCausalLMBatch": - - # parse batch - requests_idx_mapping = {} - input_ids_list = [] - - # setup tokenizer for deepsparse left padding - tokenizer.padding_side = "left" - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - padding, truncation = "longest", False - - # loop through items in the batch - for idx, r in enumerate(batch.requests): - requests_idx_mapping[r.id] = idx - - # setup inputs_ids, past_key_values - tokenized_inputs = tokenizer( - r.prompt, - return_tensors="np", - padding=padding, - truncation=truncation, - return_token_type_ids=False, - max_length=DEEPSPARSE_SEQUENCE_LENGTH - ) - input_ids_list.append(tokenized_inputs["input_ids"]) - - return cls( - batch_id=batch.id, - requests=batch.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids_list=input_ids_list, - past_key_values_list=[None] * len(batch.requests), - ) - - def to_batch(self) -> CachedBatch: - return CachedBatch( - batch_id = self.batch_id, - request_ids=[r.id for r in self.requests], - ) - - # length of the batch - def __len__(self): - return len(self.requests) - - # pass list of request ids, returns batch with only those request ids - def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]: - assert(len(request_ids) > 0) - - requests_idx_mapping = {} - requests = [] - input_ids_list = [] - past_key_values_list = [] - - # loop through requests, keep ones that should remain - for new_idx, request_id in enumerate(request_ids): - assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch" - - requests_idx_mapping[request_id] = new_idx - - old_idx = self.requests_idx_mapping[request_id] - requests.append(self.requests[old_idx]) - input_ids_list.append(self.input_ids_list[old_idx]) - past_key_values_list.append(self.past_key_values_list[old_idx]) - - # update batch state - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids_list = input_ids_list - self.past_key_values_list = past_key_values_list - - return self - - # combine two batches into one - @classmethod - def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch": - assert len(batches) > 1, "must have more than 1 batch to concatenate" - - requests_idx_mapping = {} - requests = [] - input_ids_list = [] - past_key_values_list = [] - - start_index = 0 - for i, batch in enumerate(batches): - assert batch.past_key_values_list is not None, "only concatenate prefilled batches" - - # concatenate request, input_ids, and past_key_values lists - requests.extend(batch.requests) - input_ids_list.extend(batch.input_ids_list) - past_key_values_list.extend(batch.past_key_values_list) - - # merge the request_id to index mapping - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - start_index += len(batch) - - return cls( - 
batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids_list=input_ids_list, - past_key_values_list=past_key_values_list - ) - -class DeepSparseCausalLM: - def __init__( - self, - model_path: str, - tokenizer_path: str, - ): - # setup tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.tokenizer.padding_side = "left" - if not self.tokenizer.pad_token: - assert self.tokenizer.eos_token - self.tokenizer.pad_token = self.tokenizer.eos_token - - # setup model - self.model = DeepSparseDecoderModel( - onnx_file_path = model_path, - sequence_length = DEEPSPARSE_SEQUENCE_LENGTH, - multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH, - ) - - # TODO (@rsnm2): switch to NextTokenChooser - def sample_token( - self, - logits: np.ndarray - ): - assert(logits.shape[0] == 1) # assert b=1 for now - return np.argmax(logits[0,-1,:]) # grab logits for the last item in the sequence - - # TODO (@rsnm2): switch to StoppingCriteria - def should_stop( - self, - num_tokens_processed: int, - generated_token_id: int - ): - if num_tokens_processed >= self.model.sequence_length: - return True - if generated_token_id == self.tokenizer.eos_token_id: - return True - return False - - def generate_token( - self, - batch: DeepSparseCausalLMBatch, - ) -> (List[Generation], Optional[DeepSparseCausalLMBatch]): - - generations: List[Generation] = [] - all_stopped = True - - # if we supported continuous batching, we would do batched inference here - # logits, past_key_values = self.model(batch) - - # for each member of the batch: - # a) run inference - # b) sample and check stopping criteria - # c) create generation + update batch - for i, ( - request, - input_ids, - past_key_values, - ) in enumerate(zip( - batch.requests, - batch.input_ids_list, - batch.past_key_values_list - )): - - # run inference - logits, past_key_values = self.model(input_ids, past_key_values) - - # sample token - # todo: simple for now --- should use NextTokenChooser - generated_token_id = self.sample_token(logits) - - # check stopping criteria - # todo: simple for now --- should use StoppingCriteria - assert len(input_ids.shape) == 2 - assert input_ids.shape[0] == 1 - - stop = self.should_stop( - num_tokens_processed=input_ids.shape[1] + 1, - generated_token_id = generated_token_id - ) - - # if not stopped, convert token id to text - generated_text = None - if not stop: - all_stopped = False - generated_text = self.tokenizer.decode( - generated_token_id, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - ) - generations.append(Generation( - request_id=request.id, - generated_text=generated_text - )) - - # update values in the batch - # bad --- this does not occur in place - assert len(batch.input_ids_list[i].shape) == 2 - assert batch.input_ids_list[i].shape[0] == 1 - batch.input_ids_list[i] = np.append( - batch.input_ids_list[i], - np.array([[generated_token_id]]), - axis=1 - ) - batch.past_key_values_list[i] = past_key_values - - # if all elements of the batch are done, return generation + null for batch - if all_stopped: - return generations, None - - # return generation + updated batch - return generations, batch \ No newline at end of file diff --git a/server/deepsparse/deepsparse_model.py b/server/deepsparse/deepsparse_model.py deleted file mode 100644 index 9b0082bc..00000000 --- a/server/deepsparse/deepsparse_model.py +++ /dev/null @@ -1,241 +0,0 @@ -import os -os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" - -import numpy as np -from typing import 
Optional, List, Dict - -from deepsparse import Context -from deepsparse.engine import LIB -from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine -from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask - -PAST_KEY_VALUES_NAME = "past_key_values" - -class DeepSparsePastKeyValues: - def __init__(self): - prev_num_tokens = 0 - num_frozen_tokens = 1 - self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens) - -class DeepSparseDecoderEngine: - def __init__ ( - self, - onnx_file_path: str, - sequence_length: int = 1024, - input_ids_length: int = 1, - engine_context: Optional[Context] = None, - ): - - # setup ONNX graph - onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs( - onnx_file_path=onnx_file_path, - batch_size=1, - sequence_length=sequence_length, - input_ids_length=input_ids_length, - ) - - # compile engine - self.engine = create_engine( - onnx_file_path=onnx_file_path, - engine_type=DEEPSPARSE_ENGINE, - engine_args={"cached_outputs": cached_outputs}, - context=engine_context, - ) - print(self.engine) - - # save utilties - self.past_key_value_dtype = data_type - self.onnx_inputs = self.engine.input_names - self.empty_past_key_values = self.make_empty_past_key_values() - - # forward function - def __call__( - self, - engine_inputs: Dict[str, np.ndarray], - past_key_values: DeepSparsePastKeyValues, - val_inputs: bool = True - ): - # format input into lists (we pass empty past key values) - inputs = [ - self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME) - else engine_inputs[name] for name in self.engine.input_names - ] - - # validate inputs formatted correctly - if val_inputs: - self.engine._validate_inputs(inputs) - - # run inference, updates past_key_values internally - output = self.engine._eng_net.execute_list_out( - inputs, - past_key_values.internal_past_key_values - ) - logits = output[0] - return logits, past_key_values - - # empty past kvs (dummy values to be passed around) - def make_empty_past_key_values(self): - past_key_values = {} - for idx, name in enumerate(self.onnx_inputs): - if name.startswith(PAST_KEY_VALUES_NAME): - past_key_values[name] = np.zeros( - self.engine.input_shapes[idx], - dtype=self.past_key_value_dtype - ) - - return past_key_values - -class DeepSparseDecoderModel: - def __init__( - self, - onnx_file_path: str, - sequence_length: int = 1024, - multitoken_length: int = 16, - engine_context: Optional[Context] = None, - ): - self.sequence_length = sequence_length - self.multitoken_length = multitoken_length - - # compile decode engine - self.singletoken_engine = DeepSparseDecoderEngine( - onnx_file_path=onnx_file_path, - engine_context=engine_context, - sequence_length=sequence_length, - input_ids_length=1, - ) - - # compile prefill engine - self.multitoken_engine = DeepSparseDecoderEngine( - onnx_file_path=onnx_file_path, - engine_context=engine_context, - sequence_length=sequence_length, - input_ids_length=self.multitoken_length, - ) - - assert "input_ids" in self.singletoken_engine.onnx_inputs - assert "attention_mask" in self.singletoken_engine.onnx_inputs - assert "causal_mask" in self.singletoken_engine.onnx_inputs - assert "positions" in self.singletoken_engine.onnx_inputs - - def engine_inputs_for_prefill( - self, - input_ids: np.ndarray, - ): - # split batch into N token_batches - num_batches = input_ids.shape[1] // self.multitoken_length - token_batches = [ - input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length] - 
for i in range(0, num_batches) - ] - - # format inputs for each of the N token_batches - for idx, token_batch in enumerate(token_batches): - num_processed_tokens = self.multitoken_length * idx - - engine_inputs = {} - engine_inputs["input_ids"] = token_batch - - # make attention mask from the right - engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64) - engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1 - - # make positions (building from the right) - # TODO: handle case when multitoken engine is 1 - assert self.multitoken_length > 1 - engine_inputs["positions"] = np.arange( - num_processed_tokens, num_processed_tokens + self.multitoken_length - ).reshape(1, -1).astype(np.int64) - - # make causal mask (building from the right) - engine_inputs["causal_mask"] = create_causal_mask( - input_ids=engine_inputs["input_ids"], - attention_mask=engine_inputs["attention_mask"] - ) - yield engine_inputs - - def engine_inputs_for_decode( - self, - input_ids: np.ndarray, - ): - engine_inputs = {} - engine_inputs["input_ids"] = input_ids[:,-1:] - engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64) - engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1 - - engine_inputs["causal_mask"] = create_causal_mask( - engine_inputs["input_ids"], - engine_inputs["attention_mask"] - ) - engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64) - - return engine_inputs - - def decode( - self, - input_ids: np.ndarray, - past_key_values: DeepSparsePastKeyValues - ) -> (np.ndarray, DeepSparsePastKeyValues): - - # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len - assert len(input_ids.shape) == 2 - assert input_ids.shape[0] == 1 - assert input_ids.shape[1] < self.sequence_length - - engine_inputs = self.engine_inputs_for_decode(input_ids) - logits, past_key_values = self.singletoken_engine( - engine_inputs, - past_key_values - ) - - return logits, past_key_values - - def prefill( - self, - input_ids: np.ndarray, - ) -> (np.ndarray, DeepSparsePastKeyValues): - - # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len - assert len(input_ids.shape) == 2 - assert input_ids.shape[0] == 1 - assert input_ids.shape[1] < self.sequence_length - - tokens_processed = 0 - - # setup empty past key values - past_key_values = DeepSparsePastKeyValues() - - # loop through chunks, run inference w/ multitoken engine - for engine_inputs in self.engine_inputs_for_prefill(input_ids): - logits, past_key_values = self.multitoken_engine( - engine_inputs, - past_key_values - ) - tokens_processed += self.multitoken_length - - # if anything left over, run inference w/ singletoken engine - while tokens_processed < input_ids.shape[1]: - logits, past_key_values = self.decode( - input_ids=input_ids[:,:tokens_processed+1], - past_key_values=past_key_values - ) - tokens_processed += 1 - # print(logits[:,-1:,:]) - - return logits, past_key_values - - def forward( - self, - input_ids: np.ndarray, - past_key_values: Optional[DeepSparsePastKeyValues] = None, - ): - if past_key_values is None: - return self.prefill(input_ids) - else: - return self.decode(input_ids, past_key_values) - - def __call__( - self, - input_ids: np.ndarray, - past_key_values: Optional[DeepSparsePastKeyValues] = None, - ): - return self.forward(input_ids, past_key_values) \ No newline at end of file diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py deleted file mode 100644 index 
438d33f7..00000000 --- a/server/deepsparse/deepsparse_queue.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Deque, Optional, Tuple, Dict -from collections import deque -from threading import Condition -from server.deepsparse.deepsparse_requests import Batch, Request - -class GenerateRequest: - def __init__( - self, - prompt: str, - max_generated_tokens: int - ): - self.prompt = prompt - self.generation = prompt - self.max_generated_tokens = max_generated_tokens - self.cv = Condition() - self.is_stopped = False - -# todo: implement logic for maximum memory usage -class DeepSparseQueue: - def __init__(self): - self.next_request_id: int = 0 - self.next_batch_id: int = 0 - self.queue: Deque[GenerateRequest] = deque() - - def append(self, generate_request: GenerateRequest): - self.queue.append(generate_request) - - def is_empty(self): - return len(self.queue) == 0 - - # (todo): enable multiple prefill requests in a batch - def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]: - if self.is_empty(): - return None - - # pop first generate_request in the queue - generate_request = self.queue.popleft() - generate_requests = { - self.next_request_id: generate_request - } - - # format into request - request = Request( - id=self.next_request_id, - prompt=generate_request.prompt, - max_generated_tokens=generate_request.max_generated_tokens - ) - self.next_request_id += 1 - - # format into batch - batch = Batch( - id = self.next_batch_id, - requests=[request] - ) - self.next_batch_id += 1 - - # return batch, generate_requests - return (batch, generate_requests) \ No newline at end of file diff --git a/server/deepsparse/deepsparse_requests.py b/server/deepsparse/deepsparse_requests.py deleted file mode 100644 index 430f3473..00000000 --- a/server/deepsparse/deepsparse_requests.py +++ /dev/null @@ -1,39 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -@dataclass -class Request: - id: int - prompt: str - max_generated_tokens: int - -@dataclass -class Batch: - id: int - requests: List[Request] - -@dataclass -class CachedBatch: - batch_id: int - request_ids: List[int] - - def __len__(self): - return len(self.request_ids) - -@dataclass -class Generation: - request_id: int - generated_text: Optional[str] - -@dataclass -class PrefillRequest: - batch: Batch - -@dataclass -class DecodeRequest: - batches: List[CachedBatch] - -@dataclass -class FilterBatchRequest: - batch_id: int - request_ids: List[int] \ No newline at end of file diff --git a/server/deepsparse/deepsparse_router.py b/server/deepsparse/deepsparse_router.py deleted file mode 100644 index 647d7f3f..00000000 --- a/server/deepsparse/deepsparse_router.py +++ /dev/null @@ -1,184 +0,0 @@ -from threading import Condition -from typing import List, Dict, Optional - -from server.deepsparse.deepsparse_service import DeepSparseService -from server.deepsparse.deepsparse_requests import ( - CachedBatch, Batch, Generation, - PrefillRequest, DecodeRequest, FilterBatchRequest, -) -from server.deepsparse.deepsparse_queue import ( - DeepSparseQueue, GenerateRequest -) - -class DeepSparseRouter: - def __init__(self, service: DeepSparseService): - self.service: DeepSparseService = service - self.queue: DeepSparseQueue = DeepSparseQueue() - self.cv: Condition = Condition() - - def generate(self, prompt:str) -> str: - generate_request = GenerateRequest( - prompt=prompt, - max_generated_tokens=100 - ) - - with self.cv: - # print("router: acquired cv") - self.queue.append(generate_request) - self.cv.notify() - - if 
prompt == "stop": - return "stop" - - with generate_request.cv: - # print("generate_request: acquired cv") - if not generate_request.is_stopped: - # print("generate_request: waiting") - generate_request.cv.wait() - - # print("generate_request: done waiting") - - return generate_request.generation - - def prefill( - self, - batch: Batch, - generate_requests: Dict[int,GenerateRequest] - ) -> Optional[CachedBatch]: - # print("prefill") - generation, next_batch = self.service.Prefill( - PrefillRequest(batch=batch) - ) - - self.filter_notify_update([generation], generate_requests) - - return self.filter_batch( - batch=next_batch, - generate_requests=generate_requests - ) - - def decode( - self, - batches: List[CachedBatch], - generate_requests: Dict[int,GenerateRequest] - ) -> Optional[CachedBatch]: - # print("decode") - generations, next_batch = self.service.Decode( - DecodeRequest(batches=batches) - ) - - self.filter_notify_update(generations, generate_requests) - - return self.filter_batch( - batch=next_batch, - generate_requests=generate_requests - ) - - def filter_notify_update( - self, - generations: List[Generation], - generate_requests: Dict[int, GenerateRequest] - ): - # print("filter_notify_update") - for generation in generations: - request_id = generation.request_id - - # if we hit a stopping criteria - if generation.generated_text is None: - # remove from active requests and notify - stopped_generate_request = generate_requests.pop(request_id) - with stopped_generate_request.cv: - stopped_generate_request.is_stopped = True - stopped_generate_request.cv.notify() - - # otherwise, update generation - else: - generate_requests[request_id].generation += generation.generated_text - - def filter_batch( - self, - batch: Optional[CachedBatch], - generate_requests: Dict[int, GenerateRequest] - ) -> Optional[CachedBatch]: - # print("filter_batch") - - # batch is already done - if batch is None: - return batch - - # no need to filter - if len(batch) == len(generate_requests): - return batch - - # retain only requests that are still in active generation requests - batch.request_ids = [id for id in batch.request_ids if id in generate_requests] - - # if all requests complete, clear cache and return None - if len(batch) == 0: - self.service.ClearCache() - return None - - # otherwise call the filter batch service - return self.service.FilterBatch( - FilterBatchRequest( - batch_id=batch.batch_id, - request_ids=batch.request_ids, - ) - ) - -def batching_task( - router: DeepSparseRouter -) -> bool: - # infinite_loop - while True: - # block while the queue is empty - # print("batching_task: about to acquire cv") - with router.cv: - while router.queue.is_empty(): - # print(f"batching_task cv: waiting") - router.cv.wait() - # print(f"batching_task: done waiting") - - # loop until all batches in the queue are processed - next_batch = router.queue.next_batch() - while next_batch is not None: - batch, generate_requests = next_batch - - # hack to break out of the cycle - if batch.requests[0].prompt == "stop": - assert router.queue.is_empty() - assert len(router.service.cache) == 0 - return True - - cached_batch = router.prefill( - batch=batch, - generate_requests=generate_requests - ) - - # loop until we do not reiceve any cached batch from the service (== until - # all requests have met their stopping criteria - while cached_batch is not None: - # print(f"batch_size = {len(cached_batch)}") - batches = [cached_batch] - - # try to get a new batch and run prefill on this batch - next_batch = 
router.queue.next_batch() - if next_batch is not None: - new_batch, new_generate_requests = next_batch - new_cached_batch = router.prefill( - batch=new_batch, - generate_requests=new_generate_requests - ) - - if new_cached_batch is not None: - batches.append(new_cached_batch) - assert len(generate_requests.keys() & new_generate_requests.keys()) == 0 - generate_requests.update(new_generate_requests) - - # run decode - cached_batch = router.decode( - batches=batches, - generate_requests=generate_requests - ) - - next_batch = router.queue.next_batch() \ No newline at end of file diff --git a/server/deepsparse/deepsparse_service.py b/server/deepsparse/deepsparse_service.py deleted file mode 100644 index f4eae070..00000000 --- a/server/deepsparse/deepsparse_service.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Optional, Dict, List -from server.deepsparse.deepsparse_causal_lm import ( - DeepSparseCausalLM, DeepSparseCausalLMBatch -) -from server.deepsparse.deepsparse_requests import ( - PrefillRequest, DecodeRequest, FilterBatchRequest, - Generation, CachedBatch -) - -class Cache: - def __init__(self): - self.cache: Dict[int, DeepSparseCausalLMBatch] = {} - - def pop(self, batch_id: int) -> Optional[DeepSparseCausalLMBatch]: - return self.cache.pop(batch_id, None) - - def set(self, entry: DeepSparseCausalLMBatch): - if entry is not None: - self.cache[entry.batch_id] = entry - - def delete(self, batch_id: int): - batch = self.pop(batch_id) - if batch is not None: - del batch - - def clear(self): - keys = list(self.cache.keys()) - for k in keys: - self.delete(k) - - def __len__(self): - return len(self.cache.keys()) - -class DeepSparseService: - def __init__( - self, - model: DeepSparseCausalLM - ): - self.model = model - self.cache = Cache() - - def ClearCache(self): - self.cache.clear() - - def FilterBatch( - self, - request: FilterBatchRequest - ) -> CachedBatch: - - ds_batch = self.cache.pop(request.batch_id) - assert ds_batch is not None, "Batch ID {request.batch_id} not found in cache." - filtered_batch = ds_batch.filter(request.request_ids) - self.cache.set(filtered_batch) - - return filtered_batch.to_batch() - - def Prefill( - self, - request: PrefillRequest - ) -> [Generation, CachedBatch]: - - ds_batch = DeepSparseCausalLMBatch.from_batch( - batch=request.batch, - tokenizer=self.model.tokenizer - ) - - generations, next_ds_batch = self.model.generate_token(ds_batch) - assert len(generations) == 1 - self.cache.set(next_ds_batch) - - return generations[0], next_ds_batch.to_batch() - - def Decode( - self, - request: DecodeRequest - ) -> [List[Generation], CachedBatch]: - assert len(request.batches) != 0, "Must provide at least one batch" - - ds_batches = [] - for batch in request.batches: - ds_batch = self.cache.pop(batch.batch_id) - assert batch is not None, "Batch ID {batch.id} not found in cache." - ds_batches.append(ds_batch) - - if len(ds_batches) > 1: - ds_batch = DeepSparseCausalLMBatch.concatenate(ds_batches) - else: - ds_batch = ds_batches[0] - - generations, next_ds_batch = self.model.generate_token(ds_batch) - self.cache.set(next_ds_batch) - - return generations, (next_ds_batch.to_batch() if next_ds_batch else None) \ No newline at end of file
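
Usage sketch (not part of the patch): the updated notebook above exercises the refactored API by giving each request its own response queue, which the router fills with Generation objects while batching_task runs on a background thread. The import paths, class names, and field names below are taken from the notebook cells in this diff; the model and tokenizer paths are placeholders, and the constructor signatures are assumptions inferred from the notebook rather than a definitive API reference.

from queue import Queue
from threading import Thread

from server.deepsparse.router import DeepSparseRouter, batching_task
from server.deepsparse.service.service import DeepSparseService
from server.deepsparse.service.causal_lm import DeepSparseCausalLM
from server.deepsparse.utils import GenerateRequest

# Placeholder paths: the notebook resolves these from a SparseZoo CodeGen model.
model = DeepSparseCausalLM(
    tokenizer_path="<path-to-tokenizer-dir>",
    model_path="<path-to-model.onnx>",
)
router = DeepSparseRouter(service=DeepSparseService(model=model))

# The continuous-batching loop pulls work from the router's queue on its own thread.
batching_thread = Thread(target=batching_task, args=[router])
batching_thread.start()

def generate(prompt: str, max_new_tokens: int = 100) -> str:
    # Each request owns a Queue; the router streams Generation objects into it
    # until one arrives with stopped=True.
    response_stream = Queue()
    router.generate(GenerateRequest(
        prompt=prompt,
        max_generated_tokens=max_new_tokens,
        response_stream=response_stream,
    ))
    text = prompt
    generation = response_stream.get()
    while not generation.stopped:
        text += generation.token
        generation = response_stream.get()
    return text

print(generate("def fib(n):"))

The notebook shuts the batching loop down by submitting a literal "stop" prompt and then joining the batching thread; whether that stop request also emits a terminal Generation on its response stream is not visible in this diff, so the sketch leaves shutdown out.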