From 09f4fc40a5606276d952ca8f130fce02235c312d Mon Sep 17 00:00:00 2001
From: rsnm2
Date: Thu, 24 Aug 2023 20:37:38 +0000
Subject: [PATCH] built fastapi wrapper around router

---
 deepsparse/main.py   |  36 +++++----
 deepsparse/router.py | 108 ++++++++++++++------------
 server-dev.ipynb     | 175 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 256 insertions(+), 63 deletions(-)

diff --git a/deepsparse/main.py b/deepsparse/main.py
index 28bc6588..3ce9a532 100644
--- a/deepsparse/main.py
+++ b/deepsparse/main.py
@@ -1,4 +1,5 @@
 import fastapi, uvicorn
+from contextlib import asynccontextmanager
 from threading import Thread
 from queue import Queue
 from router import DeepSparseRouter, batching_task
@@ -13,21 +14,30 @@ def serve(
     host="0.0.0.0",
     port=5543
 ):
-    # setup router
-    print("\n-------------------- Building Router --------------------\n")
-    router = DeepSparseRouter(
-        model_path=model_path,
-        tokenizer_path=tokenizer_path
-    )
 
-    # start background routing task
-    print("\n-------------------- Starting Batching Task --------------------\n")
-    batching_thread = Thread(target=batching_task, args=[router])
-    batching_thread.start()
+    router = None
+
+    @asynccontextmanager
+    async def lifespan(app: fastapi.FastAPI):
+        print("\n-------------------- Building Router --------------------\n")
+        router = DeepSparseRouter(
+            model_path=model_path,
+            tokenizer_path=tokenizer_path
+        )
 
-    print("\n-------------------- Launching App --------------------\n")
-    app = fastapi.FastAPI()
+        print("\n-------------------- Starting Batching Task --------------------\n")
+        batching_thread = Thread(target=batching_task, args=[router])
+        batching_thread.start()
 
+        print("\n-------------------- Launching App --------------------\n")
+        yield
+
+        print("\n-------------------- Joining Batching Task --------------------\n")
+        router.stop_batching_task()
+        batching_thread.join()
+
+    app = fastapi.FastAPI(lifespan=lifespan)
+
     @app.get("/generate/{prompt}")
     async def generate(prompt:str):
         response_stream = Queue()
@@ -53,8 +63,6 @@ def serve(
         port=port,
         workers=1
     )
-
-    batching_thread.join()
 
 if __name__ == "__main__":
     serve()
\ No newline at end of file
diff --git a/deepsparse/router.py b/deepsparse/router.py
index d91107c1..85e5099e 100644
--- a/deepsparse/router.py
+++ b/deepsparse/router.py
@@ -4,49 +4,6 @@ from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
 from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
 
-# TODO: implement logic for maximum size of the queue based on memory usage
-class DeepSparseQueue:
-    def __init__(self):
-        self.next_request_id: int = 0
-        self.next_batch_id: int = 0
-        self.queue: Queue[GenerateRequest] = Queue()
-
-    def append(self, generate_request: GenerateRequest):
-        self.queue.put(generate_request)
-
-    # TODO: enable multiple prefill requests in a batch
-    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
-
-        # if not blocking, return none if empty
-        if not block and self.queue.empty():
-            return None
-
-        # if block = True, this blocks until something ready
-        # if block = False, the queue has data (if not an exception is raised)
-        #   while queue.empty() == False does not guarentee data
-        #   the queue is only subscribed to by one thread (this one)
-        #   since batching_task is the only function that calls next_batch
-        generate_request = self.queue.get(block=block)
-        generate_requests = {self.next_request_id: generate_request}
-
-        # format into request
-        request = Request(
-            id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
-        )
-        self.next_request_id += 1
-
-        # format into batch
-        batch = Batch(
-            id = self.next_batch_id,
-            requests=[request]
-        )
-        self.next_batch_id += 1
-
-        # return batch, generate_requests
-        return (batch, generate_requests)
-
 class DeepSparseRouter:
     def __init__(
         self,
@@ -70,10 +27,22 @@ class DeepSparseRouter:
         )
 
         self.queue: DeepSparseQueue = DeepSparseQueue()
+        self.batching_task_should_stop: bool = False
 
     def submit_request(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)
 
+    def stop_batching_task(self):
+        # tell batching task to stop
+        self.batching_task_should_stop = True
+
+        # unblock the batching task with a dummy request if blocked
+        self.queue.append(GenerateRequest(
+            prompt="dummy",
+            max_generated_tokens=1,
+            response_stream=Queue()
+        ))
+
     def prefill(
         self,
         batch: Batch,
@@ -136,16 +105,14 @@ class DeepSparseRouter:
 
 # TODO: update to do more sophisticated logic as to when to do a prefill
 def batching_task(router: DeepSparseRouter):
-    while True:
+    # while not signaled to stop
+    while not router.batching_task_should_stop:
+        # loop until no requests to process (note: this blocks if queue is empty)
         next_batch = router.queue.next_batch(block=True)
 
         while next_batch is not None:
             batch, generate_requests = next_batch
 
-            # HACK for development --- breaks out of the cycle
-            if batch.requests[0].prompt == "stop":
-                return
-
             # run prefill
             cached_batch = router.prefill(
                 batch=batch,
@@ -178,4 +145,47 @@ def batching_task(router: DeepSparseRouter):
                 generate_requests=generate_requests
             )
 
-            next_batch = router.queue.next_batch(block=False)
\ No newline at end of file
+            next_batch = router.queue.next_batch(block=False)
+
+# TODO: implement logic for maximum size of the queue based on memory usage
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Queue[GenerateRequest] = Queue()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.put(generate_request)
+
+    # TODO: enable multiple prefill requests in a batch
+    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+
+        # if not blocking, return none if empty
+        if not block and self.queue.empty():
+            return None
+
+        # if block = True, this blocks until something ready
+        # if block = False, the queue has data (if not an exception is raised)
+        #   while queue.empty() == False does not guarantee data
+        #   the queue is only subscribed to by one thread (this one)
+        #   since batching_task is the only function that calls next_batch
+        generate_request = self.queue.get(block=block)
+        generate_requests = {self.next_request_id: generate_request}
+
+        # format into request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        # return batch, generate_requests
+        return (batch, generate_requests)
\ No newline at end of file
diff --git a/server-dev.ipynb b/server-dev.ipynb
index 902b467a..440cf7ef 100644
--- a/server-dev.ipynb
+++ b/server-dev.ipynb
@@ -11,6 +11,181 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
+     ]
+    }
+   ],
+   "source": [
+    "!curl 127.0.0.1:5543/generate \\\n",
+    "    -X POST \\\n",
+    "    -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+    "    -H 'Content-Type: application/json'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'\\n'\n",
+      "b' '\n",
+      "b'if'\n",
+      "b' n'\n",
+      "b' =='\n",
+      "b' 0'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' 0'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'el'\n",
+      "b'if'\n",
+      "b' n'\n",
+      "b' =='\n",
+      "b' 1'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' 1'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'else'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' fib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'-'\n",
+      "b'1'\n",
+      "b')'\n",
+      "b' +'\n",
+      "b' fib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'-'\n",
+      "b'2'\n",
+      "b')'\n",
+      "b'\\n'\n",
+      "b'\\n'\n",
+      "b'#'\n",
+      "b' Driver'\n",
+      "b' function'\n",
+      "b' to'\n",
+      "b' test'\n",
+      "b' above'\n",
+      "b' function'\n",
+      "b'\\n'\n",
+      "b'n'\n",
+      "b' ='\n",
+      "b' int'\n",
+      "b'('\n",
+      "b'input'\n",
+      "b'(\"'\n",
+      "b'Enter'\n",
+      "b' the'\n",
+      "b' number'\n",
+      "b':'\n",
+      "b' \"'\n",
+      "b'))'\n",
+      "b'\\n'\n",
+      "b'print'\n",
+      "b'('\n",
+      "b'f'\n",
+      "b'ib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'))'\n",
+      "b'\\n'\n",
+      "b'\\n'\n",
+      "b'#'\n",
+      "b' This'\n",
+      "b' code'\n",
+      "b' is'\n",
+      "b' contributed'\n",
+      "b' by'\n",
+      "b' Nik'\n",
+      "b'h'\n",
+      "b'il'\n",
+      "b' Kumar'\n",
+      "b' Singh'\n",
+      "b'('\n",
+      "b'nick'\n",
+      "b'z'\n",
+      "b'uck'\n",
+      "b'_'\n",
+      "b'007'\n",
+      "b')'\n",
+      "b'\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "\n",
+    "url = \"http://127.0.0.1:5543/generate_stream\"\n",
+    "obj = {\n",
+    "    \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
+    "    \"max_generated_tokens\":100\n",
+    "}\n",
+    "\n",
+    "with requests.post(url, json=obj, stream=True) as r:\n",
+    "    for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
+    "        print(chunk)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "dfb0e1de-17e2-4ff2-aecb-22b6f607ef9b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "    if n == 0:\n",
+      "        return 0\n",
+      "    elif n == 1:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        return fib(n-1) + fib(n-2)\n",
+      "\n",
+      "# Driver function to test above function\n",
+      "n = int(input(\"Enter the number: \"))\n",
+      "print(fib(n))\n",
+      "\n",
+      "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!curl 127.0.0.1:5543/generate_stream \\\n",
+    "    -X POST \\\n",
+    "    -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+    "    -H 'Content-Type: application/json'"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 13,