built fastapi wrapper around router

rsnm2 2023-08-24 20:37:38 +00:00
parent ff3275c442
commit 09f4fc40a5
3 changed files with 256 additions and 63 deletions

View File

@@ -1,4 +1,5 @@
 import fastapi, uvicorn
+from contextlib import asynccontextmanager
 from threading import Thread
 from queue import Queue
 from router import DeepSparseRouter, batching_task
@@ -13,21 +14,30 @@ def serve(
     host="0.0.0.0",
     port=5543
 ):
-    # setup router
-    print("\n-------------------- Building Router --------------------\n")
-    router = DeepSparseRouter(
-        model_path=model_path,
-        tokenizer_path=tokenizer_path
-    )
-
-    # start background routing task
-    print("\n-------------------- Starting Batching Task --------------------\n")
-    batching_thread = Thread(target=batching_task, args=[router])
-    batching_thread.start()
-
-    print("\n-------------------- Launching App --------------------\n")
-    app = fastapi.FastAPI()
+    router = None
+
+    @asynccontextmanager
+    async def lifespan(app: fastapi.FastAPI):
+        # rebind serve()'s router so the endpoint closures see the built instance
+        nonlocal router
+
+        print("\n-------------------- Building Router --------------------\n")
+        router = DeepSparseRouter(
+            model_path=model_path,
+            tokenizer_path=tokenizer_path
+        )
+
+        print("\n-------------------- Starting Batching Task --------------------\n")
+        batching_thread = Thread(target=batching_task, args=[router])
+        batching_thread.start()
+
+        print("\n-------------------- Launching App --------------------\n")
+        yield
+
+        print("\n-------------------- Joining Batching Task --------------------\n")
+        router.stop_batching_task()
+        batching_thread.join()
+
+    app = fastapi.FastAPI(lifespan=lifespan)

     @app.get("/generate/{prompt}")
     async def generate(prompt:str):
         response_stream = Queue()
@@ -53,8 +63,6 @@ def serve(
         port=port,
         workers=1
     )
-
-    batching_thread.join()

 if __name__ == "__main__":
     serve()
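The flow above relies on FastAPI's lifespan protocol: everything before the yield runs once before the server starts accepting requests, and everything after it runs once at shutdown, which is what lets the router and batching thread be built, started, and joined around the app's lifetime. A minimal self-contained sketch of the same pattern (the resource name is illustrative, not from this repo):

    from contextlib import asynccontextmanager
    import fastapi

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        # startup: runs once before the server begins serving requests
        app.state.resource = {"ready": True}  # hypothetical expensive resource
        yield
        # shutdown: runs once after the server stops serving requests
        app.state.resource = None

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.get("/health")
    async def health():
        return {"ready": app.state.resource is not None}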

View File

@@ -4,49 +4,6 @@ from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
 from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
-
-# TODO: implement logic for maximum size of the queue based on memory usage
-class DeepSparseQueue:
-    def __init__(self):
-        self.next_request_id: int = 0
-        self.next_batch_id: int = 0
-        self.queue: Queue[GenerateRequest] = Queue()
-
-    def append(self, generate_request: GenerateRequest):
-        self.queue.put(generate_request)
-
-    # TODO: enable multiple prefill requests in a batch
-    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
-        # if not blocking, return None if the queue is empty
-        if not block and self.queue.empty():
-            return None
-
-        # if block=True, this call blocks until an item is ready;
-        # if block=False, the queue must have data (otherwise an exception is raised);
-        # queue.empty() == False does not guarantee data in general, but this
-        # queue is consumed by only one thread (this one), since batching_task
-        # is the only caller of next_batch
-        generate_request = self.queue.get(block=block)
-        generate_requests = {self.next_request_id: generate_request}
-
-        # format into a request
-        request = Request(
-            id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
-        )
-        self.next_request_id += 1
-
-        # format into a batch
-        batch = Batch(
-            id=self.next_batch_id,
-            requests=[request]
-        )
-        self.next_batch_id += 1
-
-        return (batch, generate_requests)

 class DeepSparseRouter:
     def __init__(
         self,
@@ -70,10 +27,22 @@ class DeepSparseRouter:
         )
         self.queue: DeepSparseQueue = DeepSparseQueue()
+        self.batching_task_should_stop: bool = False

     def submit_request(self, generate_request: GenerateRequest):
         self.queue.append(generate_request)

+    def stop_batching_task(self):
+        # tell the batching task to stop
+        self.batching_task_should_stop = True
+
+        # unblock the batching task with a dummy request in case it is
+        # blocked on an empty queue
+        self.queue.append(GenerateRequest(
+            prompt="dummy",
+            max_generated_tokens=1,
+            response_stream=Queue()
+        ))
+
     def prefill(
         self,
         batch: Batch,
@@ -136,16 +105,14 @@ class DeepSparseRouter:

 # TODO: update to do more sophisticated logic as to when to do a prefill
 def batching_task(router: DeepSparseRouter):
-    while True:
+    # while not signaled to stop
+    while not router.batching_task_should_stop:
         # loop until no requests to process (note: this blocks if queue is empty)
         next_batch = router.queue.next_batch(block=True)
         while next_batch is not None:
             batch, generate_requests = next_batch

-            # HACK for development --- breaks out of the cycle
-            if batch.requests[0].prompt == "stop":
-                return
-
             # run prefill
             cached_batch = router.prefill(
                 batch=batch,
@@ -178,4 +145,47 @@ def batching_task(router: DeepSparseRouter):
             generate_requests=generate_requests
         )

         next_batch = router.queue.next_batch(block=False)
+
+# TODO: implement logic for maximum size of the queue based on memory usage
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Queue[GenerateRequest] = Queue()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.put(generate_request)
+
+    # TODO: enable multiple prefill requests in a batch
+    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+        # if not blocking, return None if the queue is empty
+        if not block and self.queue.empty():
+            return None
+
+        # if block=True, this call blocks until an item is ready;
+        # if block=False, the queue must have data (otherwise an exception is raised);
+        # queue.empty() == False does not guarantee data in general, but this
+        # queue is consumed by only one thread (this one), since batching_task
+        # is the only caller of next_batch
+        generate_request = self.queue.get(block=block)
+        generate_requests = {self.next_request_id: generate_request}
+
+        # format into a request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into a batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        return (batch, generate_requests)
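The shutdown handshake added above (a stop flag plus a dummy request) is a common way to stop a consumer thread that blocks on Queue.get: the flag alone is not enough, because the thread may be parked inside the blocking get and would never re-check it. A stripped-down sketch of the pattern, independent of this repo's types:

    from queue import Queue
    from threading import Thread

    class Consumer:
        def __init__(self):
            self.queue: Queue = Queue()
            self.should_stop: bool = False

        def stop(self):
            self.should_stop = True
            self.queue.put(None)  # dummy item wakes a get() blocked on an empty queue

        def run(self):
            while not self.should_stop:
                item = self.queue.get(block=True)  # blocks while the queue is empty
                if item is not None:
                    print("processing:", item)

    consumer = Consumer()
    thread = Thread(target=consumer.run)
    thread.start()
    consumer.queue.put("request-1")
    consumer.stop()
    thread.join()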

View File

@@ -11,6 +11,181 @@
 "%autoreload 2"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
+]
+}
+],
+"source": [
+"!curl 127.0.0.1:5543/generate \\\n",
+" -X POST \\\n",
+" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+" -H 'Content-Type: application/json'"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"b'\\n'\n",
+"b' '\n",
+"b'if'\n",
+"b' n'\n",
+"b' =='\n",
+"b' 0'\n",
+"b':'\n",
+"b'\\n'\n",
+"b' '\n",
+"b'return'\n",
+"b' 0'\n",
+"b'\\n'\n",
+"b' '\n",
+"b'el'\n",
+"b'if'\n",
+"b' n'\n",
+"b' =='\n",
+"b' 1'\n",
+"b':'\n",
+"b'\\n'\n",
+"b' '\n",
+"b'return'\n",
+"b' 1'\n",
+"b'\\n'\n",
+"b' '\n",
+"b'else'\n",
+"b':'\n",
+"b'\\n'\n",
+"b' '\n",
+"b'return'\n",
+"b' fib'\n",
+"b'('\n",
+"b'n'\n",
+"b'-'\n",
+"b'1'\n",
+"b')'\n",
+"b' +'\n",
+"b' fib'\n",
+"b'('\n",
+"b'n'\n",
+"b'-'\n",
+"b'2'\n",
+"b')'\n",
+"b'\\n'\n",
+"b'\\n'\n",
+"b'#'\n",
+"b' Driver'\n",
+"b' function'\n",
+"b' to'\n",
+"b' test'\n",
+"b' above'\n",
+"b' function'\n",
+"b'\\n'\n",
+"b'n'\n",
+"b' ='\n",
+"b' int'\n",
+"b'('\n",
+"b'input'\n",
+"b'(\"'\n",
+"b'Enter'\n",
+"b' the'\n",
+"b' number'\n",
+"b':'\n",
+"b' \"'\n",
+"b'))'\n",
+"b'\\n'\n",
+"b'print'\n",
+"b'('\n",
+"b'f'\n",
+"b'ib'\n",
+"b'('\n",
+"b'n'\n",
+"b'))'\n",
+"b'\\n'\n",
+"b'\\n'\n",
+"b'#'\n",
+"b' This'\n",
+"b' code'\n",
+"b' is'\n",
+"b' contributed'\n",
+"b' by'\n",
+"b' Nik'\n",
+"b'h'\n",
+"b'il'\n",
+"b' Kumar'\n",
+"b' Singh'\n",
+"b'('\n",
+"b'nick'\n",
+"b'z'\n",
+"b'uck'\n",
+"b'_'\n",
+"b'007'\n",
+"b')'\n",
+"b'\\n'\n"
+]
+}
+],
+"source": [
+"import requests\n",
+"\n",
+"url = \"http://127.0.0.1:5543/generate_stream\"\n",
+"obj = {\n",
+" \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
+" \"max_generated_tokens\":100\n",
+"}\n",
+"\n",
+"with requests.post(url, json=obj, stream=True) as r:\n",
+" for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
+" print(chunk)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"id": "dfb0e1de-17e2-4ff2-aecb-22b6f607ef9b",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+" if n == 0:\n",
+" return 0\n",
+" elif n == 1:\n",
+" return 1\n",
+" else:\n",
+" return fib(n-1) + fib(n-2)\n",
+"\n",
+"# Driver function to test above function\n",
+"n = int(input(\"Enter the number: \"))\n",
+"print(fib(n))\n",
+"\n",
+"# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n"
+]
+}
+],
+"source": [
+"!curl 127.0.0.1:5543/generate_stream \\\n",
+" -X POST \\\n",
+" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+" -H 'Content-Type: application/json'"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 13,