From e7f3eac8d552d042c10e329fee4df18728bfb3b7 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Tue, 22 Aug 2023 18:12:03 +0000 Subject: [PATCH 1/3] intermediate results from server --- server-dev.ipynb | 170 +++++++---------------- server/deepsparse/deepsparse_requests.py | 4 + 2 files changed, 57 insertions(+), 117 deletions(-) diff --git a/server-dev.ipynb b/server-dev.ipynb index 782920c8..b8f684a8 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -21,10 +21,39 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 12, + "id": "f23bc085-94db-44b6-af42-fc8a05f2cf6a", + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (260114089.py, line 2)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m Cell \u001b[0;32mIn[12], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m b = (a = 5) < 5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "a = None\n", + "b = (a = 5) < 5\n", + "print(b)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "id": "631e94eb-cca0-438e-8936-6e8a87166d63", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-22 14:26:39 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n" + ] + } + ], "source": [ "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLMBatch, DeepSparseCausalLM\n", "from server.deepsparse.deepsparse_service import DeepSparseService\n", @@ -44,7 +73,7 @@ "output_type": "stream", "text": [ "Using pad_token, but it is not set yet.\n", - "2023-08-22 03:09:19 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "2023-08-22 14:26:56 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. 
version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n" ] }, @@ -67,7 +96,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-08-22 03:09:45 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" + "2023-08-22 14:27:21 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" ] }, { @@ -100,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 4, "id": "85ce9aab-1a56-4b6f-a82b-4e91d52290b7", "metadata": {}, "outputs": [], @@ -136,10 +165,25 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 5, "id": "d2441753-fe2a-45c0-ad80-135b6207947d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Batch' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m service\u001b[38;5;241m.\u001b[39mClearCache()\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# prefill queue\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m prefill_queue \u001b[38;5;241m=\u001b[39m \u001b[43mPrefillQueue\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# cached batches\u001b[39;00m\n\u001b[1;32m 7\u001b[0m cached_batches \u001b[38;5;241m=\u001b[39m []\n", + "Cell \u001b[0;32mIn[4], line 17\u001b[0m, in \u001b[0;36mPrefillQueue.__init__\u001b[0;34m(self, prompts)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39mmake_batch(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39midx, prompt\u001b[38;5;241m=\u001b[39mprompt))\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n", + "Cell \u001b[0;32mIn[4], line 18\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, prompts):\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqueue \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m---> 18\u001b[0m idx: PrefillRequest(batch\u001b[38;5;241m=\u001b[39m\u001b[43mmake_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, prompt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(prompts)\n\u001b[1;32m 20\u001b[0m }\n", + "Cell \u001b[0;32mIn[4], line 10\u001b[0m, in 
\u001b[0;36mmake_batch\u001b[0;34m(id, prompt)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_batch\u001b[39m(\u001b[38;5;28mid\u001b[39m, prompt):\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mBatch\u001b[49m(\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m,\n\u001b[1;32m 12\u001b[0m requests\u001b[38;5;241m=\u001b[39m[Request(\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mid\u001b[39m, prompt\u001b[38;5;241m=\u001b[39mprompt)]\n\u001b[1;32m 13\u001b[0m )\n", + "\u001b[0;31mNameError\u001b[0m: name 'Batch' is not defined" + ] + } + ], "source": [ "service.ClearCache()\n", "\n", @@ -213,118 +257,10 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "id": "dd6bcc43-63ef-4f92-a960-74e33b86dc97", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Request 0 is done!\n", - "Request 1 is done!\n", - "Request 3 is done!\n", - "Request 2 is done!\n", - "All Requests Done!\n", - "\n", - "\n", - "INDEX = 0:\n", - "Finish the following function for computing a fibonacci sequence: \n", - "\n", - " fib(n):\n", - "\n", - " if n == 0:\n", - " return 0\n", - " elif n == 1:\n", - " return 1\n", - " else:\n", - " return fib(n-1) + fib(n-2)\n", - "\n", - "# Call the function.\n", - "print(fib(5))\n", - "\n", - "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", - "\n", - "\n", - "\n", - "INDEX = 1:\n", - "Write a function for filtering a list of integers to include only positive numbers:\n", - "\n", - "filter(lst):\n", - "\n", - "lst = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n", - "\n", - "def filter_positive(lst):\n", - " return [num for num in lst if num > 0]\n", - "\n", - "print(filter_positive(lst))\n", - "\n", - "# filter_positive([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n", - "\n", - "# filter_positive([1, 2, 3, 4, 5\n", - "\n", - "\n", - "INDEX = 2:\n", - "Write a function for reversing a string:\n", - "\n", - "def reverse_string(s):\n", - " return s[::-1]\n", - "\n", - "# Test\n", - "print(reverse_string(\"hello\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"a\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\"))\n", - "print(reverse_string(\"\n", - "\n", - "\n", - "INDEX = 3:\n", - "Write a function for checking if a word if a palindrome:\n", - "\n", - "def is_palindrome(word):\n", - " return word == word[::-1]\n", - "\n", - "# Test\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(is_palindrome(\"racecar\"))\n", - "print(\n", - "\n", - "\n", - "INDEX = 4:\n", - "Write a function for sorting an array of integers:\n", - "\n", - "def merge_sort(arr):\n", - " if len(arr) <= 1:\n", - " return arr\n", - " mid = len(arr) // 2\n", - " left = arr[:mid]\n", - " right = arr[mid:]\n", - " left = merge_sort(left)\n", - " right = merge_sort(right)\n", - " return merge(left, right)\n", - "\n", - "def merge(left, right):\n", - " result = []\n", - " while len(left) > 0 and len(right) > 0:\n", - " if left[0]\n", - "\n", - "\n", - 
"[CachedBatch(batch_id=0, request_ids=[4])]\n" - ] - } - ], + "outputs": [], "source": [ "# run a few decodes\n", "for _ in range(100):\n", diff --git a/server/deepsparse/deepsparse_requests.py b/server/deepsparse/deepsparse_requests.py index 524793b0..430f3473 100644 --- a/server/deepsparse/deepsparse_requests.py +++ b/server/deepsparse/deepsparse_requests.py @@ -5,6 +5,7 @@ from typing import List, Optional class Request: id: int prompt: str + max_generated_tokens: int @dataclass class Batch: @@ -16,6 +17,9 @@ class CachedBatch: batch_id: int request_ids: List[int] + def __len__(self): + return len(self.request_ids) + @dataclass class Generation: request_id: int From e7ec2ff282a4d943c473427adc7dbdd50b8ec98f Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Tue, 22 Aug 2023 18:12:24 +0000 Subject: [PATCH 2/3] added all the files --- server/deepsparse/deepsparse_queue.py | 56 ++++++++++++++ server/deepsparse/deepsparse_router.py | 100 +++++++++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 server/deepsparse/deepsparse_queue.py create mode 100644 server/deepsparse/deepsparse_router.py diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py new file mode 100644 index 00000000..6f0194b6 --- /dev/null +++ b/server/deepsparse/deepsparse_queue.py @@ -0,0 +1,56 @@ +from typing import Deque, Optional, Tuple, Dict +from collections import deque +from threading import Condition +from server.deepsparse.deepsparse_requests import Batch, Request + +class GenerateRequest: + def __init__( + self, + prompt: str, + max_generated_tokens: int + ): + self.prompt = prompt + self.generation = prompt + self.max_generated_tokens = max_generated_tokens + self.cv = Condition() + +class DeepSparseQueue: + def __init__(self): + self.next_request_id: int = 0 + self.next_batch_id: int = 0 + self.queue: Deque[GenerateRequest] = deque() + + def append(self, generate_request: GenerateRequest): + self.queue.append(generate_request) + + def is_empty(self): + return len(self.queue) == 0 + + # (todo): enable multiple prefill requests in a batch + def next_batch(self) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]: + if self.is_empty(): + return None + + # pop first generate_request in the queue + generate_request = self.queue.popleft() + generate_requests = { + self.next_request_id: generate_request + } + + # format into request + request = Request( + id=self.next_request_id, + prompt=generate_request.prompt, + max_generated_tokens=generate_request.max_generated_tokens + ) + self.next_request_id += 1 + + # format into batch + batch = Batch( + id = self.next_batch_id, + requests=[request] + ) + self.next_batch_id += 1 + + # return batch, generate_requests + return (batch, generate_requests) \ No newline at end of file diff --git a/server/deepsparse/deepsparse_router.py b/server/deepsparse/deepsparse_router.py new file mode 100644 index 00000000..a09b374e --- /dev/null +++ b/server/deepsparse/deepsparse_router.py @@ -0,0 +1,100 @@ +from threading import Condition +from typing import List, Dict, Optional + +from server.deepsparse.deepsparse_service import DeepSparseService +from server.deepsparse.deepsparse_requests import ( + CachedBatch, Batch, Generation, + PrefillRequest, DecodeRequest, FilterBatchRequest, +) +from server.deepsparse.deepsparse_queue import ( + DeepSparseQueue, GenerateRequest +) + +class DeepSparseRouter: + def __init__(self, service: DeepSparseService): + self.service: DeepSparseService = service + self.queue: DeepSparseQueue = DeepSparseQueue() + 
self.cv: Condition = Condition() + + def generate(self): + pass + + def prefill( + self, + batch: Batch, + generation_requests: Dict[int,GenerateRequest] + ) -> Optional[CachedBatch]: + + generation, next_batch = self.service.Prefill( + PrefillRequest(batch=batch) + ) + + self.filter_notify_update([generation], generation_requests) + + return self.filter_batch( + batch=next_batch, + generation_requests=generation_requests + ) + + def decode(self): + pass + + def filter_notify_update( + self, + generations: List[Generation], + generation_requests: Dict[int, GenerateRequest] + ): + for generation in generations: + request_id = generation.request_id + + # if we hit a stopping criteria + if generation.generated_text is None: + # remove from active requests and notify + stopped_generation_request = generation_requests.pop() + stopped_generation_request[request_id].cv.notify() + + # otherwise, update generation + else: + generation_requests[request_id].generation += generation.generated_text + + def filter_batch( + self, + batch: CachedBatch, + generation_requests: Dict[int, GenerateRequest] + ) -> Optional[CachedBatch]: + + # no need to filter + if len(batch) == len(generation_requests): + return batch + + # retain only requests that are still in active generation requests + batch.request_ids = [id for id in batch.request_ids if id in generation_requests] + + # if all requests complete, clear cache and return None + if len(batch) == 0: + self.service.ClearCache() + return None + + # otherwise call the filter batch service + return self.service.FilterBatch( + FilterBatchRequest( + batch_id=batch.batch_id, + request_ids=batch.request_ids, + ) + ) + + def batching_task(self): + while True: + with self.cv: + while self.queue.is_empty(): + self.cv.wait() + + # loop until the queue is empty + next_batch = self.queue.next_batch() + while next_batch is not None: + cached_batch = self.prefill(*next_batch) + + + + next_batch = self.queue.next_batch() + \ No newline at end of file From 1f87c7762fa96d2ecb020ba285a1ff93c87962c6 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Wed, 23 Aug 2023 19:54:31 +0000 Subject: [PATCH 3/3] implemented a basic naive router --- server-dev.ipynb | 235 ++++++++++++++++++++-- server/deepsparse/deepsparse_causal_lm.py | 4 +- server/deepsparse/deepsparse_queue.py | 2 + server/deepsparse/deepsparse_router.py | 142 ++++++++++--- server/deepsparse/deepsparse_service.py | 8 +- 5 files changed, 345 insertions(+), 46 deletions(-) diff --git a/server-dev.ipynb b/server-dev.ipynb index b8f684a8..80a5e83b 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -13,31 +13,242 @@ }, { "cell_type": "markdown", - "id": "7d43c041-2c79-4276-9104-2f224b2f8af6", + "id": "a19786b8-e72c-43c1-964f-45d92fd171e9", "metadata": {}, "source": [ - "## Example Interacting With The Service" + "## Example Interacting With The Router" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "f23bc085-94db-44b6-af42-fc8a05f2cf6a", + "execution_count": 2, + "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d", "metadata": {}, "outputs": [ { - "ename": "SyntaxError", - "evalue": "invalid syntax (260114089.py, line 2)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m Cell \u001b[0;32mIn[12], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m b = (a = 5) < 5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not 
be installed. It can be installed via `pip install nm_transformers`\n" ] } ], "source": [ - "a = None\n", - "b = (a = 5) < 5\n", - "print(b)\n" + "from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n", + "from server.deepsparse.deepsparse_service import DeepSparseService\n", + "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78acf813-3688-483d-9148-5c0df5d6b8e3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n", + "2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deepsparse.engine.Engine:\n", + "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "\tbatch_size: 1\n", + "\tnum_cores: 8\n", + "\tnum_streams: 1\n", + "\tscheduler: Scheduler.default\n", + "\tfraction_of_supported_ops: 1.0\n", + "\tcpu_avx_type: avx2\n", + "\tcpu_vnni: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deepsparse.engine.Engine:\n", + "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "\tbatch_size: 1\n", + "\tnum_cores: 8\n", + "\tnum_streams: 1\n", + "\tscheduler: Scheduler.default\n", + "\tfraction_of_supported_ops: 1.0\n", + "\tcpu_avx_type: avx2\n", + "\tcpu_vnni: False\n" + ] + } + ], + "source": [ + "tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n", + "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n", + "\n", + "model = DeepSparseCausalLM(\n", + " tokenizer_path=tokenizer_path,\n", + " model_path=onnx_path\n", + ")\n", + "\n", + "service = DeepSparseService(model=model)\n", + "router = DeepSparseRouter(service=service)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e93bac63-8924-4cf4-8683-81ce9333a2f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n-1) + fib(n-2)\n", + "\n", + "# Driver function to test above function\n", + "n = int(input(\"Enter the number: \"))\n", + "print(fib(n))\n", + "\n", + "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", + "\n", + "\n", + "\n", + "Write a function for filtering a list of integers to include only positive numbers:\n", + "\n", + "def 
filter(lst):\n", + " return [x for x in lst if x > 0]\n", + "\n", + "# Test\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1,\n", + "\n", + "\n", + "Write a function for checking if a word if a palindrome:\n", + "\n", + "def is_palindrome(word):\n", + " return word == word[::-1]\n", + "\n", + "# Test\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(\n", + "\n", + "\n", + "Write a function for reversing a string:\n", + "\n", + "def reverse_string(s):\n", + " return s[::-1]\n", + "\n", + "# Test\n", + "print(reverse_string(\"hello\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"a\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\n", + "\n", + "\n", + "Write a function for sorting an array of integers:\n", + "\n", + "def merge_sort(arr):\n", + " if len(arr) <= 1:\n", + " return arr\n", + " mid = len(arr) // 2\n", + " left = arr[:mid]\n", + " right = arr[mid:]\n", + " left = merge_sort(left)\n", + " right = merge_sort(right)\n", + " return merge(left, right)\n", + "\n", + "def merge(left, right):\n", + " result = []\n", + " while len(left) > 0 and len(right) > 0:\n", + " if left[0]\n", + "\n", + "\n", + "stop\n", + "\n", + "\n" + ] + } + ], + "source": [ + "from threading import Thread\n", + "import time\n", + "\n", + "batching_thread = Thread(target=batching_task, args=[router])\n", + "batching_thread.start()\n", + "\n", + "prompts = [\n", + " \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n", + " \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n", + " \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n", + " \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n", + " \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n", + "]\n", + "\n", + "def generate_task(prompt):\n", + " result = router.generate(prompt=prompt)\n", + " print(result)\n", + " print(\"\\n\")\n", + "\n", + "generate_threads = [\n", + " Thread(target=generate_task, args=[prompt]) for prompt in prompts\n", + "]\n", + "\n", + "# print(len(generate_threads))\n", + "\n", + "for gt in generate_threads:\n", + " gt.start()\n", + " time.sleep(0.5)\n", + "\n", + "for gt in generate_threads:\n", + " gt.join()\n", + "\n", + "\n", + "generate_task(\"stop\")\n", + "batching_thread.join()" + ] + }, + { + "cell_type": "markdown", + "id": "7d43c041-2c79-4276-9104-2f224b2f8af6", + "metadata": {}, + "source": [ + "## Example Interacting With The Service" ] }, { diff --git a/server/deepsparse/deepsparse_causal_lm.py b/server/deepsparse/deepsparse_causal_lm.py index 2afeb028..bf182395 100644 --- a/server/deepsparse/deepsparse_causal_lm.py +++ b/server/deepsparse/deepsparse_causal_lm.py @@ -205,11 +205,11 @@ class DeepSparseCausalLM: logits, past_key_values = self.model(input_ids, 
past_key_values) # sample token - # simple for now --- should use NextTokenChooser + # todo: simple for now --- should use NextTokenChooser generated_token_id = self.sample_token(logits) # check stopping criteria - # simple for now --- should use StoppingCriteria + # todo: simple for now --- should use StoppingCriteria assert len(input_ids.shape) == 2 assert input_ids.shape[0] == 1 diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py index 6f0194b6..438d33f7 100644 --- a/server/deepsparse/deepsparse_queue.py +++ b/server/deepsparse/deepsparse_queue.py @@ -13,7 +13,9 @@ class GenerateRequest: self.generation = prompt self.max_generated_tokens = max_generated_tokens self.cv = Condition() + self.is_stopped = False +# todo: implement logic for maximum memory usage class DeepSparseQueue: def __init__(self): self.next_request_id: int = 0 diff --git a/server/deepsparse/deepsparse_router.py b/server/deepsparse/deepsparse_router.py index a09b374e..647d7f3f 100644 --- a/server/deepsparse/deepsparse_router.py +++ b/server/deepsparse/deepsparse_router.py @@ -16,59 +16,102 @@ class DeepSparseRouter: self.queue: DeepSparseQueue = DeepSparseQueue() self.cv: Condition = Condition() - def generate(self): - pass + def generate(self, prompt:str) -> str: + generate_request = GenerateRequest( + prompt=prompt, + max_generated_tokens=100 + ) + + with self.cv: + # print("router: acquired cv") + self.queue.append(generate_request) + self.cv.notify() + + if prompt == "stop": + return "stop" + + with generate_request.cv: + # print("generate_request: acquired cv") + if not generate_request.is_stopped: + # print("generate_request: waiting") + generate_request.cv.wait() + + # print("generate_request: done waiting") + + return generate_request.generation def prefill( self, batch: Batch, - generation_requests: Dict[int,GenerateRequest] + generate_requests: Dict[int,GenerateRequest] ) -> Optional[CachedBatch]: - + # print("prefill") generation, next_batch = self.service.Prefill( PrefillRequest(batch=batch) ) - self.filter_notify_update([generation], generation_requests) + self.filter_notify_update([generation], generate_requests) return self.filter_batch( batch=next_batch, - generation_requests=generation_requests + generate_requests=generate_requests ) - def decode(self): - pass + def decode( + self, + batches: List[CachedBatch], + generate_requests: Dict[int,GenerateRequest] + ) -> Optional[CachedBatch]: + # print("decode") + generations, next_batch = self.service.Decode( + DecodeRequest(batches=batches) + ) + + self.filter_notify_update(generations, generate_requests) + + return self.filter_batch( + batch=next_batch, + generate_requests=generate_requests + ) def filter_notify_update( self, generations: List[Generation], - generation_requests: Dict[int, GenerateRequest] + generate_requests: Dict[int, GenerateRequest] ): + # print("filter_notify_update") for generation in generations: request_id = generation.request_id # if we hit a stopping criteria if generation.generated_text is None: # remove from active requests and notify - stopped_generation_request = generation_requests.pop() - stopped_generation_request[request_id].cv.notify() + stopped_generate_request = generate_requests.pop(request_id) + with stopped_generate_request.cv: + stopped_generate_request.is_stopped = True + stopped_generate_request.cv.notify() # otherwise, update generation else: - generation_requests[request_id].generation += generation.generated_text + generate_requests[request_id].generation += 
generation.generated_text def filter_batch( self, - batch: CachedBatch, - generation_requests: Dict[int, GenerateRequest] + batch: Optional[CachedBatch], + generate_requests: Dict[int, GenerateRequest] ) -> Optional[CachedBatch]: + # print("filter_batch") + + # batch is already done + if batch is None: + return batch # no need to filter - if len(batch) == len(generation_requests): + if len(batch) == len(generate_requests): return batch # retain only requests that are still in active generation requests - batch.request_ids = [id for id in batch.request_ids if id in generation_requests] + batch.request_ids = [id for id in batch.request_ids if id in generate_requests] # if all requests complete, clear cache and return None if len(batch) == 0: @@ -83,18 +126,59 @@ class DeepSparseRouter: ) ) - def batching_task(self): - while True: - with self.cv: - while self.queue.is_empty(): - self.cv.wait() +def batching_task( + router: DeepSparseRouter +) -> bool: + # infinite_loop + while True: + # block while the queue is empty + # print("batching_task: about to acquire cv") + with router.cv: + while router.queue.is_empty(): + # print(f"batching_task cv: waiting") + router.cv.wait() + # print(f"batching_task: done waiting") + + # loop until all batches in the queue are processed + next_batch = router.queue.next_batch() + while next_batch is not None: + batch, generate_requests = next_batch - # loop until the queue is empty - next_batch = self.queue.next_batch() - while next_batch is not None: - cached_batch = self.prefill(*next_batch) + # hack to break out of the cycle + if batch.requests[0].prompt == "stop": + assert router.queue.is_empty() + assert len(router.service.cache) == 0 + return True + + cached_batch = router.prefill( + batch=batch, + generate_requests=generate_requests + ) + + # loop until we do not reiceve any cached batch from the service (== until + # all requests have met their stopping criteria + while cached_batch is not None: + # print(f"batch_size = {len(cached_batch)}") + batches = [cached_batch] - - - next_batch = self.queue.next_batch() - \ No newline at end of file + # try to get a new batch and run prefill on this batch + next_batch = router.queue.next_batch() + if next_batch is not None: + new_batch, new_generate_requests = next_batch + new_cached_batch = router.prefill( + batch=new_batch, + generate_requests=new_generate_requests + ) + + if new_cached_batch is not None: + batches.append(new_cached_batch) + assert len(generate_requests.keys() & new_generate_requests.keys()) == 0 + generate_requests.update(new_generate_requests) + + # run decode + cached_batch = router.decode( + batches=batches, + generate_requests=generate_requests + ) + + next_batch = router.queue.next_batch() \ No newline at end of file diff --git a/server/deepsparse/deepsparse_service.py b/server/deepsparse/deepsparse_service.py index 9458a8a6..f4eae070 100644 --- a/server/deepsparse/deepsparse_service.py +++ b/server/deepsparse/deepsparse_service.py @@ -7,7 +7,7 @@ from server.deepsparse.deepsparse_requests import ( Generation, CachedBatch ) -class BatchCache: +class Cache: def __init__(self): self.cache: Dict[int, DeepSparseCausalLMBatch] = {} @@ -37,7 +37,7 @@ class DeepSparseService: model: DeepSparseCausalLM ): self.model = model - self.cache = BatchCache() + self.cache = Cache() def ClearCache(self): self.cache.clear() @@ -46,6 +46,7 @@ class DeepSparseService: self, request: FilterBatchRequest ) -> CachedBatch: + ds_batch = self.cache.pop(request.batch_id) assert ds_batch is not None, "Batch ID 
{request.batch_id} not found in cache." filtered_batch = ds_batch.filter(request.request_ids) @@ -57,6 +58,7 @@ class DeepSparseService: self, request: PrefillRequest ) -> [Generation, CachedBatch]: + ds_batch = DeepSparseCausalLMBatch.from_batch( batch=request.batch, tokenizer=self.model.tokenizer @@ -88,4 +90,4 @@ class DeepSparseService: generations, next_ds_batch = self.model.generate_token(ds_batch) self.cache.set(next_ds_batch) - return generations, next_ds_batch.to_batch() if next_ds_batch else None \ No newline at end of file + return generations, (next_ds_batch.to_batch() if next_ds_batch else None) \ No newline at end of file
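
Note on the concurrency pattern introduced in PATCH 3: the router hands work from callers to the batching thread through two levels of threading.Condition, a router-level condition guarding the queue and a per-request condition that the caller blocks on until the batching loop marks the request stopped, with the literal prompt "stop" used as a shutdown sentinel. The following is a minimal, self-contained sketch of that handoff only; the names ToyRequest, ToyRouter, and the placeholder "...generated text..." string are illustrative stand-ins and not part of the patch, which does the real work through DeepSparseService.Prefill/Decode.

    from collections import deque
    from threading import Condition, Thread

    class ToyRequest:
        # mirrors GenerateRequest: prompt, growing generation, a per-request
        # condition variable, and a stopped flag guarded by that condition
        def __init__(self, prompt: str):
            self.prompt = prompt
            self.generation = prompt
            self.cv = Condition()
            self.is_stopped = False

    class ToyRouter:
        # mirrors DeepSparseRouter: a queue guarded by a router-level Condition
        def __init__(self):
            self.queue = deque()
            self.cv = Condition()

        def generate(self, prompt: str) -> str:
            request = ToyRequest(prompt)
            with self.cv:                  # enqueue under the router lock and wake the batching thread
                self.queue.append(request)
                self.cv.notify()
            with request.cv:               # block until the batching thread marks this request stopped
                if not request.is_stopped:
                    request.cv.wait()
            return request.generation

    def batching_task(router: ToyRouter):
        while True:
            with router.cv:                # sleep until generate() enqueues work
                while not router.queue:
                    router.cv.wait()
                request = router.queue.popleft()

            if request.prompt == "stop":   # shutdown sentinel, as in the notebook example
                with request.cv:
                    request.is_stopped = True
                    request.cv.notify()
                return

            request.generation += " ...generated text..."  # stand-in for the Prefill/Decode loop
            with request.cv:               # wake the caller waiting in generate()
                request.is_stopped = True
                request.cv.notify()

    # usage: start the batching thread, issue a request, then send the stop sentinel
    router = ToyRouter()
    thread = Thread(target=batching_task, args=[router])
    thread.start()
    print(router.generate("Write a function for reversing a string:"))
    router.generate("stop")
    thread.join()

Checking is_stopped before waiting avoids a missed wakeup if the batching thread finishes the request before the caller reacquires the request's condition, which is the same guard DeepSparseRouter.generate uses. Giving each request its own Condition lets filter_notify_update wake exactly the caller whose request hit its stopping criterion instead of broadcasting to all waiters.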