From 1f87c7762fa96d2ecb020ba285a1ff93c87962c6 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Wed, 23 Aug 2023 19:54:31 +0000 Subject: [PATCH] implemented a basic naive router --- server-dev.ipynb | 235 ++++++++++++++++++++-- server/deepsparse/deepsparse_causal_lm.py | 4 +- server/deepsparse/deepsparse_queue.py | 2 + server/deepsparse/deepsparse_router.py | 142 ++++++++++--- server/deepsparse/deepsparse_service.py | 8 +- 5 files changed, 345 insertions(+), 46 deletions(-) diff --git a/server-dev.ipynb b/server-dev.ipynb index b8f684a8..80a5e83b 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -13,31 +13,242 @@ }, { "cell_type": "markdown", - "id": "7d43c041-2c79-4276-9104-2f224b2f8af6", + "id": "a19786b8-e72c-43c1-964f-45d92fd171e9", "metadata": {}, "source": [ - "## Example Interacting With The Service" + "## Example Interacting With The Router" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "f23bc085-94db-44b6-af42-fc8a05f2cf6a", + "execution_count": 2, + "id": "0b2c83cd-92ea-40d7-bc7e-f737b87d9b8d", "metadata": {}, "outputs": [ { - "ename": "SyntaxError", - "evalue": "invalid syntax (260114089.py, line 2)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m Cell \u001b[0;32mIn[12], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m b = (a = 5) < 5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-23 19:52:18 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n" ] } ], "source": [ - "a = None\n", - "b = (a = 5) < 5\n", - "print(b)\n" + "from server.deepsparse.deepsparse_router import DeepSparseRouter, batching_task\n", + "from server.deepsparse.deepsparse_service import DeepSparseService\n", + "from server.deepsparse.deepsparse_causal_lm import DeepSparseCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78acf813-3688-483d-9148-5c0df5d6b8e3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n", + "2023-08-23 19:52:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. 
version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deepsparse.engine.Engine:\n", + "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "\tbatch_size: 1\n", + "\tnum_cores: 8\n", + "\tnum_streams: 1\n", + "\tscheduler: Scheduler.default\n", + "\tfraction_of_supported_ops: 1.0\n", + "\tcpu_avx_type: avx2\n", + "\tcpu_vnni: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-23 19:52:44 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deepsparse.engine.Engine:\n", + "\tonnx_file_path: /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n", + "\tbatch_size: 1\n", + "\tnum_cores: 8\n", + "\tnum_streams: 1\n", + "\tscheduler: Scheduler.default\n", + "\tfraction_of_supported_ops: 1.0\n", + "\tcpu_avx_type: avx2\n", + "\tcpu_vnni: False\n" + ] + } + ], + "source": [ + "tokenizer_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment\"\n", + "onnx_path = \"/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\"\n", + "\n", + "model = DeepSparseCausalLM(\n", + " tokenizer_path=tokenizer_path,\n", + " model_path=onnx_path\n", + ")\n", + "\n", + "service = DeepSparseService(model=model)\n", + "router = DeepSparseRouter(service=service)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e93bac63-8924-4cf4-8683-81ce9333a2f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n-1) + fib(n-2)\n", + "\n", + "# Driver function to test above function\n", + "n = int(input(\"Enter the number: \"))\n", + "print(fib(n))\n", + "\n", + "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", + "\n", + "\n", + "\n", + "Write a function for filtering a list of integers to include only positive numbers:\n", + "\n", + "def filter(lst):\n", + " return [x for x in lst if x > 0]\n", + "\n", + "# Test\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))\n", + "print(filter([1,\n", + "\n", + "\n", + "Write a function for checking if a word if a palindrome:\n", + "\n", + "def is_palindrome(word):\n", + " return word == word[::-1]\n", + "\n", + "# Test\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(is_palindrome(\"racecar\"))\n", + "print(\n", + "\n", + "\n", + "Write a function for reversing a string:\n", + "\n", + "def reverse_string(s):\n", + " return s[::-1]\n", + "\n", + "# Test\n", + "print(reverse_string(\"hello\"))\n", + 
"print(reverse_string(\"\"))\n", + "print(reverse_string(\"a\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\"))\n", + "print(reverse_string(\"\n", + "\n", + "\n", + "Write a function for sorting an array of integers:\n", + "\n", + "def merge_sort(arr):\n", + " if len(arr) <= 1:\n", + " return arr\n", + " mid = len(arr) // 2\n", + " left = arr[:mid]\n", + " right = arr[mid:]\n", + " left = merge_sort(left)\n", + " right = merge_sort(right)\n", + " return merge(left, right)\n", + "\n", + "def merge(left, right):\n", + " result = []\n", + " while len(left) > 0 and len(right) > 0:\n", + " if left[0]\n", + "\n", + "\n", + "stop\n", + "\n", + "\n" + ] + } + ], + "source": [ + "from threading import Thread\n", + "import time\n", + "\n", + "batching_thread = Thread(target=batching_task, args=[router])\n", + "batching_thread.start()\n", + "\n", + "prompts = [\n", + " \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n", + " \"Write a function for filtering a list of integers to include only positive numbers:\\n\\ndef filter(lst):\",\n", + " \"Write a function for reversing a string:\\n\\ndef reverse_string(s):\",\n", + " \"Write a function for checking if a word if a palindrome:\\n\\ndef is_palindrome(word):\",\n", + " \"Write a function for sorting an array of integers:\\n\\ndef merge_sort(arr):\",\n", + "]\n", + "\n", + "def generate_task(prompt):\n", + " result = router.generate(prompt=prompt)\n", + " print(result)\n", + " print(\"\\n\")\n", + "\n", + "generate_threads = [\n", + " Thread(target=generate_task, args=[prompt]) for prompt in prompts\n", + "]\n", + "\n", + "# print(len(generate_threads))\n", + "\n", + "for gt in generate_threads:\n", + " gt.start()\n", + " time.sleep(0.5)\n", + "\n", + "for gt in generate_threads:\n", + " gt.join()\n", + "\n", + "\n", + "generate_task(\"stop\")\n", + "batching_thread.join()" + ] + }, + { + "cell_type": "markdown", + "id": "7d43c041-2c79-4276-9104-2f224b2f8af6", + "metadata": {}, + "source": [ + "## Example Interacting With The Service" ] }, { diff --git a/server/deepsparse/deepsparse_causal_lm.py b/server/deepsparse/deepsparse_causal_lm.py index 2afeb028..bf182395 100644 --- a/server/deepsparse/deepsparse_causal_lm.py +++ b/server/deepsparse/deepsparse_causal_lm.py @@ -205,11 +205,11 @@ class DeepSparseCausalLM: logits, past_key_values = self.model(input_ids, past_key_values) # sample token - # simple for now --- should use NextTokenChooser + # todo: simple for now --- should use NextTokenChooser generated_token_id = self.sample_token(logits) # check stopping criteria - # simple for now --- should use StoppingCriteria + # todo: simple for now --- should use StoppingCriteria assert len(input_ids.shape) == 2 assert input_ids.shape[0] == 1 diff --git a/server/deepsparse/deepsparse_queue.py b/server/deepsparse/deepsparse_queue.py index 6f0194b6..438d33f7 100644 --- a/server/deepsparse/deepsparse_queue.py +++ b/server/deepsparse/deepsparse_queue.py @@ -13,7 +13,9 @@ class GenerateRequest: self.generation = prompt self.max_generated_tokens = max_generated_tokens self.cv = Condition() + self.is_stopped = False +# todo: implement logic for maximum memory usage class DeepSparseQueue: def __init__(self): self.next_request_id: int = 0 diff --git a/server/deepsparse/deepsparse_router.py 
b/server/deepsparse/deepsparse_router.py index a09b374e..647d7f3f 100644 --- a/server/deepsparse/deepsparse_router.py +++ b/server/deepsparse/deepsparse_router.py @@ -16,59 +16,102 @@ class DeepSparseRouter: self.queue: DeepSparseQueue = DeepSparseQueue() self.cv: Condition = Condition() - def generate(self): - pass + def generate(self, prompt:str) -> str: + generate_request = GenerateRequest( + prompt=prompt, + max_generated_tokens=100 + ) + + with self.cv: + # print("router: acquired cv") + self.queue.append(generate_request) + self.cv.notify() + + if prompt == "stop": + return "stop" + + with generate_request.cv: + # print("generate_request: acquired cv") + if not generate_request.is_stopped: + # print("generate_request: waiting") + generate_request.cv.wait() + + # print("generate_request: done waiting") + + return generate_request.generation def prefill( self, batch: Batch, - generation_requests: Dict[int,GenerateRequest] + generate_requests: Dict[int,GenerateRequest] ) -> Optional[CachedBatch]: - + # print("prefill") generation, next_batch = self.service.Prefill( PrefillRequest(batch=batch) ) - self.filter_notify_update([generation], generation_requests) + self.filter_notify_update([generation], generate_requests) return self.filter_batch( batch=next_batch, - generation_requests=generation_requests + generate_requests=generate_requests ) - def decode(self): - pass + def decode( + self, + batches: List[CachedBatch], + generate_requests: Dict[int,GenerateRequest] + ) -> Optional[CachedBatch]: + # print("decode") + generations, next_batch = self.service.Decode( + DecodeRequest(batches=batches) + ) + + self.filter_notify_update(generations, generate_requests) + + return self.filter_batch( + batch=next_batch, + generate_requests=generate_requests + ) def filter_notify_update( self, generations: List[Generation], - generation_requests: Dict[int, GenerateRequest] + generate_requests: Dict[int, GenerateRequest] ): + # print("filter_notify_update") for generation in generations: request_id = generation.request_id # if we hit a stopping criteria if generation.generated_text is None: # remove from active requests and notify - stopped_generation_request = generation_requests.pop() - stopped_generation_request[request_id].cv.notify() + stopped_generate_request = generate_requests.pop(request_id) + with stopped_generate_request.cv: + stopped_generate_request.is_stopped = True + stopped_generate_request.cv.notify() # otherwise, update generation else: - generation_requests[request_id].generation += generation.generated_text + generate_requests[request_id].generation += generation.generated_text def filter_batch( self, - batch: CachedBatch, - generation_requests: Dict[int, GenerateRequest] + batch: Optional[CachedBatch], + generate_requests: Dict[int, GenerateRequest] ) -> Optional[CachedBatch]: + # print("filter_batch") + + # batch is already done + if batch is None: + return batch # no need to filter - if len(batch) == len(generation_requests): + if len(batch) == len(generate_requests): return batch # retain only requests that are still in active generation requests - batch.request_ids = [id for id in batch.request_ids if id in generation_requests] + batch.request_ids = [id for id in batch.request_ids if id in generate_requests] # if all requests complete, clear cache and return None if len(batch) == 0: @@ -83,18 +126,59 @@ class DeepSparseRouter: ) ) - def batching_task(self): - while True: - with self.cv: - while self.queue.is_empty(): - self.cv.wait() +def batching_task( + router: 
DeepSparseRouter
+) -> bool:
+    # infinite loop
+    while True:
+        # block while the queue is empty
+        # print("batching_task: about to acquire cv")
+        with router.cv:
+            while router.queue.is_empty():
+                # print(f"batching_task cv: waiting")
+                router.cv.wait()
+            # print(f"batching_task: done waiting")
+
+        # loop until all batches in the queue are processed
+        next_batch = router.queue.next_batch()
+        while next_batch is not None:
+            batch, generate_requests = next_batch
 
-            # loop until the queue is empty
-            next_batch = self.queue.next_batch()
-            while next_batch is not None:
-                cached_batch = self.prefill(*next_batch)
+            # hack to break out of the cycle
+            if batch.requests[0].prompt == "stop":
+                assert router.queue.is_empty()
+                assert len(router.service.cache) == 0
+                return True
+
+            cached_batch = router.prefill(
+                batch=batch,
+                generate_requests=generate_requests
+            )
+
+            # loop until we do not receive any cached batch from the service (== until
+            # all requests have met their stopping criteria)
+            while cached_batch is not None:
+                # print(f"batch_size = {len(cached_batch)}")
+                batches = [cached_batch]
 
-
-
-                next_batch = self.queue.next_batch()
-                
\ No newline at end of file
+                # try to get a new batch and run prefill on this batch
+                next_batch = router.queue.next_batch()
+                if next_batch is not None:
+                    new_batch, new_generate_requests = next_batch
+                    new_cached_batch = router.prefill(
+                        batch=new_batch,
+                        generate_requests=new_generate_requests
+                    )
+
+                    if new_cached_batch is not None:
+                        batches.append(new_cached_batch)
+                        assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
+                        generate_requests.update(new_generate_requests)
+
+                # run decode
+                cached_batch = router.decode(
+                    batches=batches,
+                    generate_requests=generate_requests
+                )
+
+            next_batch = router.queue.next_batch()
\ No newline at end of file
diff --git a/server/deepsparse/deepsparse_service.py b/server/deepsparse/deepsparse_service.py
index 9458a8a6..f4eae070 100644
--- a/server/deepsparse/deepsparse_service.py
+++ b/server/deepsparse/deepsparse_service.py
@@ -7,7 +7,7 @@ from server.deepsparse.deepsparse_requests import (
     Generation, CachedBatch
 )
 
-class BatchCache:
+class Cache:
     def __init__(self):
         self.cache: Dict[int, DeepSparseCausalLMBatch] = {}
 
@@ -37,7 +37,7 @@ class DeepSparseService:
         model: DeepSparseCausalLM
     ):
         self.model = model
-        self.cache = BatchCache()
+        self.cache = Cache()
 
     def ClearCache(self):
         self.cache.clear()
@@ -46,6 +46,7 @@ class DeepSparseService:
         self,
         request: FilterBatchRequest
    ) -> CachedBatch:
+
         ds_batch = self.cache.pop(request.batch_id)
         assert ds_batch is not None, "Batch ID {request.batch_id} not found in cache."
         filtered_batch = ds_batch.filter(request.request_ids)
@@ -57,6 +58,7 @@ class DeepSparseService:
         self,
         request: PrefillRequest
    ) -> [Generation, CachedBatch]:
+
         ds_batch = DeepSparseCausalLMBatch.from_batch(
             batch=request.batch,
             tokenizer=self.model.tokenizer
@@ -88,4 +90,4 @@ class DeepSparseService:
         generations, next_ds_batch = self.model.generate_token(ds_batch)
         self.cache.set(next_ds_batch)
 
-        return generations, next_ds_batch.to_batch() if next_ds_batch else None
\ No newline at end of file
+        return generations, (next_ds_batch.to_batch() if next_ds_batch else None)
\ No newline at end of file