built FastAPI wrapper around router

rsnm2 2023-08-24 20:37:38 +00:00
parent ff3275c442
commit 09f4fc40a5
3 changed files with 256 additions and 63 deletions

View File

@@ -1,4 +1,5 @@
 import fastapi, uvicorn
+from contextlib import asynccontextmanager
 from threading import Thread
 from queue import Queue
 from router import DeepSparseRouter, batching_task
@@ -13,21 +14,30 @@ def serve(
     host="0.0.0.0",
     port=5543
 ):
-    # setup router
-    print("\n-------------------- Building Router --------------------\n")
-    router = DeepSparseRouter(
-        model_path=model_path,
-        tokenizer_path=tokenizer_path
-    )
-
-    # start background routing task
-    print("\n-------------------- Starting Batching Task --------------------\n")
-    batching_thread = Thread(target=batching_task, args=[router])
-    batching_thread.start()
-
-    print("\n-------------------- Launching App --------------------\n")
-    app = fastapi.FastAPI()
+    router = None
+
+    @asynccontextmanager
+    async def lifespan(app: fastapi.FastAPI):
+        print("\n-------------------- Building Router --------------------\n")
+        router = DeepSparseRouter(
+            model_path=model_path,
+            tokenizer_path=tokenizer_path
+        )
+
+        print("\n-------------------- Starting Batching Task --------------------\n")
+        batching_thread = Thread(target=batching_task, args=[router])
+        batching_thread.start()
+
+        print("\n-------------------- Launching App --------------------\n")
+        yield
+
+        print("\n-------------------- Joining Batching Task --------------------\n")
+        router.stop_batching_task()
+        batching_thread.join()
+
+    app = fastapi.FastAPI(lifespan=lifespan)

     @app.get("/generate/{prompt}")
     async def generate(prompt: str):
         response_stream = Queue()
@@ -53,8 +63,6 @@ def serve(
         port=port,
         workers=1
     )
-
-    batching_thread.join()

 if __name__ == "__main__":
     serve()
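For context on the pattern adopted above: FastAPI runs the code before `yield` in a lifespan context manager once at startup, and the code after `yield` once at shutdown, which is why the router construction, thread start, and thread join all move inside `lifespan`. A minimal standalone sketch of that pattern (the resource names here are illustrative, not from this repo):

import fastapi, uvicorn
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    # startup: runs once before the server starts accepting requests
    app.state.resources = {"ready": True}   # hypothetical startup work
    yield
    # shutdown: runs once after the server stops accepting requests
    app.state.resources = None              # hypothetical cleanup work

app = fastapi.FastAPI(lifespan=lifespan)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5543)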

View File

@@ -4,49 +4,6 @@ from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
 from utils import CachedBatch, Batch, Generation, GenerateRequest, Request

-# TODO: implement logic for maximum size of the queue based on memory usage
-class DeepSparseQueue:
-    def __init__(self):
-        self.next_request_id: int = 0
-        self.next_batch_id: int = 0
-        self.queue: Queue[GenerateRequest] = Queue()
-
-    def append(self, generate_request: GenerateRequest):
-        self.queue.put(generate_request)
-
-    # TODO: enable multiple prefill requests in a batch
-    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
-        # if not blocking, return None if the queue is empty
-        if not block and self.queue.empty():
-            return None
-
-        # if block == True, this blocks until something is ready
-        # if block == False, the queue has data (if not, an exception is raised);
-        #   queue.empty() == False does not guarantee get() will succeed in general,
-        #   but the queue is only consumed by one thread (this one), since
-        #   batching_task is the only function that calls next_batch
-        generate_request = self.queue.get(block=block)
-        generate_requests = {self.next_request_id: generate_request}
-
-        # format into a request
-        request = Request(
-            id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
-        )
-        self.next_request_id += 1
-
-        # format into a batch
-        batch = Batch(
-            id=self.next_batch_id,
-            requests=[request]
-        )
-        self.next_batch_id += 1
-
-        return (batch, generate_requests)
-
 class DeepSparseRouter:
     def __init__(
         self,
@@ -70,10 +27,22 @@ class DeepSparseRouter:
         )
         self.queue: DeepSparseQueue = DeepSparseQueue()
+        self.batching_task_should_stop: bool = False

     def submit_request(self, generate_request: GenerateRequest):
         self.queue.append(generate_request)

+    def stop_batching_task(self):
+        # tell the batching task to stop
+        self.batching_task_should_stop = True
+
+        # unblock the batching task with a dummy request if it is blocked
+        self.queue.append(GenerateRequest(
+            prompt="dummy",
+            max_generated_tokens=1,
+            response_stream=Queue()
+        ))
+
     def prefill(
         self,
         batch: Batch,
@@ -136,16 +105,14 @@
 # TODO: update to do more sophisticated logic as to when to do a prefill
 def batching_task(router: DeepSparseRouter):
-    while True:
+    # run while not signaled to stop
+    while not router.batching_task_should_stop:
         # loop until there are no requests to process (note: this blocks if the queue is empty)
         next_batch = router.queue.next_batch(block=True)
         while next_batch is not None:
             batch, generate_requests = next_batch

-            # HACK for development --- breaks out of the cycle
-            if batch.requests[0].prompt == "stop":
-                return
-
             # run prefill
             cached_batch = router.prefill(
                 batch=batch,
@@ -178,4 +145,47 @@ def batching_task(router: DeepSparseRouter):
                 generate_requests=generate_requests
             )

             next_batch = router.queue.next_batch(block=False)
+
+# TODO: implement logic for maximum size of the queue based on memory usage
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Queue[GenerateRequest] = Queue()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.put(generate_request)
+
+    # TODO: enable multiple prefill requests in a batch
+    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+        # if not blocking, return None if the queue is empty
+        if not block and self.queue.empty():
+            return None
+
+        # if block == True, this blocks until something is ready
+        # if block == False, the queue has data (if not, an exception is raised);
+        #   queue.empty() == False does not guarantee get() will succeed in general,
+        #   but the queue is only consumed by one thread (this one), since
+        #   batching_task is the only function that calls next_batch
+        generate_request = self.queue.get(block=block)
+        generate_requests = {self.next_request_id: generate_request}
+
+        # format into a request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into a batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        return (batch, generate_requests)
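The stop_batching_task / batching_task pair above uses a standard shutdown idiom for a blocking consumer thread: set a stop flag, then enqueue a dummy item so a consumer blocked on get() wakes up and re-checks the flag. A minimal standalone sketch of that idiom (all names here are illustrative, not from this repo):

from queue import Queue
from threading import Thread

class Worker:
    def __init__(self):
        self.queue: Queue = Queue()
        self.should_stop: bool = False

    def stop(self):
        self.should_stop = True   # signal the loop to exit
        self.queue.put(None)      # dummy item unblocks a waiting get()

def consume(worker: Worker):
    while not worker.should_stop:
        item = worker.queue.get(block=True)  # blocks until an item arrives
        if item is None:
            continue              # dummy item: loop back and re-check the flag
        print("processing", item)

worker = Worker()
thread = Thread(target=consume, args=[worker])
thread.start()
worker.queue.put("request-1")
worker.stop()
thread.join()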

View File

@@ -11,6 +11,181 @@
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate \\\n",
" -X POST \\\n",
" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'\\n'\n",
"b' '\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 0'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 0'\n",
"b'\\n'\n",
"b' '\n",
"b'el'\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 1'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 1'\n",
"b'\\n'\n",
"b' '\n",
"b'else'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'1'\n",
"b')'\n",
"b' +'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'2'\n",
"b')'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' Driver'\n",
"b' function'\n",
"b' to'\n",
"b' test'\n",
"b' above'\n",
"b' function'\n",
"b'\\n'\n",
"b'n'\n",
"b' ='\n",
"b' int'\n",
"b'('\n",
"b'input'\n",
"b'(\"'\n",
"b'Enter'\n",
"b' the'\n",
"b' number'\n",
"b':'\n",
"b' \"'\n",
"b'))'\n",
"b'\\n'\n",
"b'print'\n",
"b'('\n",
"b'f'\n",
"b'ib'\n",
"b'('\n",
"b'n'\n",
"b'))'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' This'\n",
"b' code'\n",
"b' is'\n",
"b' contributed'\n",
"b' by'\n",
"b' Nik'\n",
"b'h'\n",
"b'il'\n",
"b' Kumar'\n",
"b' Singh'\n",
"b'('\n",
"b'nick'\n",
"b'z'\n",
"b'uck'\n",
"b'_'\n",
"b'007'\n",
"b')'\n",
"b'\\n'\n"
]
}
],
"source": [
"import requests\n",
"\n",
"url = \"http://127.0.0.1:5543/generate_stream\"\n",
"obj = {\n",
" \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"max_generated_tokens\":100\n",
"}\n",
"\n",
"with requests.post(url, json=obj, stream=True) as r:\n",
" for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
" print(chunk)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dfb0e1de-17e2-4ff2-aecb-22b6f607ef9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n-1) + fib(n-2)\n",
"\n",
"# Driver function to test above function\n",
"n = int(input(\"Enter the number: \"))\n",
"print(fib(n))\n",
"\n",
"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n"
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate_stream \\\n",
" -X POST \\\n",
" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
"execution_count": 13,