mirror of https://github.com/huggingface/text-generation-inference.git

built fastapi wrapper around router

commit 09f4fc40a5 (parent ff3275c442)
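The change below swaps eager setup inside serve() for FastAPI's lifespan hook, so the router and the background batching thread are created when the app starts serving and torn down when it exits. For reference, a minimal standalone sketch of that pattern (the prints are illustrative; the real hooks build DeepSparseRouter and start/join the batching thread):

import fastapi, uvicorn
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    # startup: build resources (router, batching thread) here
    print("startup")
    yield
    # shutdown: signal and join background work here
    print("shutdown")

app = fastapi.FastAPI(lifespan=lifespan)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5543, workers=1)

workers=1 matters here: the router and its thread live in this process, so a single worker keeps exactly one router per server.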
@@ -1,4 +1,5 @@
 import fastapi, uvicorn
+from contextlib import asynccontextmanager
 from threading import Thread
 from queue import Queue
 from router import DeepSparseRouter, batching_task
@@ -13,21 +14,30 @@ def serve(
     host="0.0.0.0",
     port=5543
 ):
-    # setup router
-    print("\n-------------------- Building Router --------------------\n")
-    router = DeepSparseRouter(
-        model_path=model_path,
-        tokenizer_path=tokenizer_path
-    )
-
-    # start background routing task
-    print("\n-------------------- Starting Batching Task --------------------\n")
-    batching_thread = Thread(target=batching_task, args=[router])
-    batching_thread.start()
-    router = None
+    @asynccontextmanager
+    async def lifespan(app: fastapi.FastAPI):
+        print("\n-------------------- Building Router --------------------\n")
+        router = DeepSparseRouter(
+            model_path=model_path,
+            tokenizer_path=tokenizer_path
+        )
 
-    print("\n-------------------- Launching App --------------------\n")
-    app = fastapi.FastAPI()
+        print("\n-------------------- Starting Batching Task --------------------\n")
+        batching_thread = Thread(target=batching_task, args=[router])
+        batching_thread.start()
+
+        print("\n-------------------- Launching App --------------------\n")
+        yield
+
+        print("\n-------------------- Joining Batching Task --------------------\n")
+        router.stop_batching_task()
+        batching_thread.join()  # join the Thread object (batching_task is only the target function)
+
+    app = fastapi.FastAPI(lifespan=lifespan)
 
     @app.get("/generate/{prompt}")
     async def generate(prompt: str):
         response_stream = Queue()
@@ -53,8 +63,6 @@ def serve(
         port=port,
         workers=1
     )
-
-    batching_thread.join()
 
 if __name__ == "__main__":
     serve()
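The /generate endpoint above allocates a response_stream Queue per request; the truncated part of the hunk presumably forwards generated tokens from the batching thread through that queue. A plausible sketch of the thread-to-async handoff, assuming a None sentinel ends the stream (produce/consume are hypothetical names, not from this diff):

import asyncio
from queue import Queue
from threading import Thread

def produce(stream: Queue):
    # stands in for the batching thread appending generated tokens
    for token in ["Hello", " ", "world"]:
        stream.put(token)
    stream.put(None)  # sentinel: generation finished

async def consume(stream: Queue):
    loop = asyncio.get_running_loop()
    while True:
        # Queue.get blocks, so run it in the default executor
        token = await loop.run_in_executor(None, stream.get)
        if token is None:
            break
        yield token

async def main():
    stream = Queue()
    Thread(target=produce, args=[stream]).start()
    async for token in consume(stream):
        print(token, end="")

asyncio.run(main())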
@@ -4,49 +4,6 @@ from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
 from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
 
-# TODO: implement logic for maximum size of the queue based on memory usage
-class DeepSparseQueue:
-    def __init__(self):
-        self.next_request_id: int = 0
-        self.next_batch_id: int = 0
-        self.queue: Queue[GenerateRequest] = Queue()
-
-    def append(self, generate_request: GenerateRequest):
-        self.queue.put(generate_request)
-
-    # TODO: enable multiple prefill requests in a batch
-    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
-
-        # if not blocking, return None if empty
-        if not block and self.queue.empty():
-            return None
-
-        # if block == True, this blocks until something is ready
-        # if block == False, the queue has data (if not, an exception is raised);
-        # queue.empty() == False does not guarantee data in general, but the
-        # queue is only consumed by one thread (this one), since batching_task
-        # is the only caller of next_batch
-        generate_request = self.queue.get(block=block)
-        generate_requests = {self.next_request_id: generate_request}
-
-        # format into a request
-        request = Request(
-            id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
-        )
-        self.next_request_id += 1
-
-        # format into a batch
-        batch = Batch(
-            id=self.next_batch_id,
-            requests=[request]
-        )
-        self.next_batch_id += 1
-
-        # return batch, generate_requests
-        return (batch, generate_requests)
-
 class DeepSparseRouter:
     def __init__(
         self,
@@ -70,10 +27,22 @@ class DeepSparseRouter:
         )
 
         self.queue: DeepSparseQueue = DeepSparseQueue()
+        self.batching_task_should_stop: bool = False
 
     def submit_request(self, generate_request: GenerateRequest):
        self.queue.append(generate_request)
 
+    def stop_batching_task(self):
+        # tell the batching task to stop
+        self.batching_task_should_stop = True
+
+        # unblock the batching task with a dummy request if it is blocked
+        self.queue.append(GenerateRequest(
+            prompt="dummy",
+            max_generated_tokens=1,
+            response_stream=Queue()
+        ))
+
     def prefill(
         self,
         batch: Batch,
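The dummy request in stop_batching_task is needed because a thread blocked inside Queue.get() never re-checks batching_task_should_stop until an item arrives. The same shutdown idiom, reduced to plain strings (names here are illustrative):

from queue import Queue
from threading import Thread
import time

should_stop = False
q = Queue()

def worker():
    while not should_stop:
        item = q.get()  # blocks here while the queue is empty
        print("got:", item)

t = Thread(target=worker)
t.start()

time.sleep(0.1)
should_stop = True  # the flag alone cannot wake the blocked get()
q.put("dummy")      # the dummy item unblocks the worker so it can exit
t.join()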
@@ -136,16 +105,14 @@
 # TODO: update to do more sophisticated logic as to when to do a prefill
 def batching_task(router: DeepSparseRouter):
-    while True:
+    # run while not signaled to stop
+    while not router.batching_task_should_stop:
 
         # loop until no requests are left to process (note: this blocks if the queue is empty)
         next_batch = router.queue.next_batch(block=True)
         while next_batch is not None:
             batch, generate_requests = next_batch
 
-            # HACK for development --- breaks out of the cycle
-            if batch.requests[0].prompt == "stop":
-                return
-
             # run prefill
             cached_batch = router.prefill(
                 batch=batch,
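batching_task thus has a two-level loop: block until the first request arrives, drain the queue without blocking, then sleep again. A simplified sketch of that control flow over a plain Queue (should_stop and process are hypothetical stand-ins for the router flag and for prefill/decode):

import queue

def batching_loop(q: queue.Queue, should_stop, process):
    while not should_stop():
        item = q.get(block=True)           # sleep until work arrives
        while True:
            process(item)                  # stands in for prefill/decode
            try:
                item = q.get(block=False)  # drain without sleeping
            except queue.Empty:
                break

# usage sketch: processes "a", "b", "c" in one pass, then stops
q = queue.Queue()
for prompt in ["a", "b", "c"]:
    q.put(prompt)
flags = iter([False, True])
batching_loop(q, should_stop=lambda: next(flags), process=print)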
@@ -178,4 +145,47 @@ def batching_task(router: DeepSparseRouter):
             generate_requests=generate_requests
         )
 
         next_batch = router.queue.next_batch(block=False)
+
+# TODO: implement logic for maximum size of the queue based on memory usage
+class DeepSparseQueue:
+    def __init__(self):
+        self.next_request_id: int = 0
+        self.next_batch_id: int = 0
+        self.queue: Queue[GenerateRequest] = Queue()
+
+    def append(self, generate_request: GenerateRequest):
+        self.queue.put(generate_request)
+
+    # TODO: enable multiple prefill requests in a batch
+    def next_batch(self, block=False) -> Optional[Tuple[Batch, Dict[int, GenerateRequest]]]:
+
+        # if not blocking, return None if empty
+        if not block and self.queue.empty():
+            return None
+
+        # if block == True, this blocks until something is ready
+        # if block == False, the queue has data (if not, an exception is raised);
+        # queue.empty() == False does not guarantee data in general, but the
+        # queue is only consumed by one thread (this one), since batching_task
+        # is the only caller of next_batch
+        generate_request = self.queue.get(block=block)
+        generate_requests = {self.next_request_id: generate_request}
+
+        # format into a request
+        request = Request(
+            id=self.next_request_id,
+            prompt=generate_request.prompt,
+            max_generated_tokens=generate_request.max_generated_tokens
+        )
+        self.next_request_id += 1
+
+        # format into a batch
+        batch = Batch(
+            id=self.next_batch_id,
+            requests=[request]
+        )
+        self.next_batch_id += 1
+
+        # return batch, generate_requests
+        return (batch, generate_requests)
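For reference, next_batch under both modes, exercised with stand-in dataclasses (assumptions: the DeepSparseQueue definition above is in scope, and these stubs mimic the utils types with only the fields used here):

from dataclasses import dataclass, field
from queue import Queue
from typing import Dict, List, Optional, Tuple

@dataclass
class GenerateRequest:
    prompt: str
    max_generated_tokens: int
    response_stream: Queue = field(default_factory=Queue)

@dataclass
class Request:
    id: int
    prompt: str
    max_generated_tokens: int

@dataclass
class Batch:
    id: int
    requests: List[Request]

dq = DeepSparseQueue()
assert dq.next_batch(block=False) is None    # empty + non-blocking -> None
dq.append(GenerateRequest("hi", max_generated_tokens=8))
batch, pending = dq.next_batch(block=False)  # queue now has data
assert batch.id == 0 and batch.requests[0].prompt == "hi"
assert list(pending) == [0]                  # maps request id -> GenerateRequest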
server-dev.ipynb (+175 lines)
@@ -11,6 +11,181 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
+     ]
+    }
+   ],
+   "source": [
+    "!curl 127.0.0.1:5543/generate \\\n",
+    "    -X POST \\\n",
+    "    -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+    "    -H 'Content-Type: application/json'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "b'\\n'\n",
+      "b' '\n",
+      "b'if'\n",
+      "b' n'\n",
+      "b' =='\n",
+      "b' 0'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' 0'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'el'\n",
+      "b'if'\n",
+      "b' n'\n",
+      "b' =='\n",
+      "b' 1'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' 1'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'else'\n",
+      "b':'\n",
+      "b'\\n'\n",
+      "b' '\n",
+      "b'return'\n",
+      "b' fib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'-'\n",
+      "b'1'\n",
+      "b')'\n",
+      "b' +'\n",
+      "b' fib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'-'\n",
+      "b'2'\n",
+      "b')'\n",
+      "b'\\n'\n",
+      "b'\\n'\n",
+      "b'#'\n",
+      "b' Driver'\n",
+      "b' function'\n",
+      "b' to'\n",
+      "b' test'\n",
+      "b' above'\n",
+      "b' function'\n",
+      "b'\\n'\n",
+      "b'n'\n",
+      "b' ='\n",
+      "b' int'\n",
+      "b'('\n",
+      "b'input'\n",
+      "b'(\"'\n",
+      "b'Enter'\n",
+      "b' the'\n",
+      "b' number'\n",
+      "b':'\n",
+      "b' \"'\n",
+      "b'))'\n",
+      "b'\\n'\n",
+      "b'print'\n",
+      "b'('\n",
+      "b'f'\n",
+      "b'ib'\n",
+      "b'('\n",
+      "b'n'\n",
+      "b'))'\n",
+      "b'\\n'\n",
+      "b'\\n'\n",
+      "b'#'\n",
+      "b' This'\n",
+      "b' code'\n",
+      "b' is'\n",
+      "b' contributed'\n",
+      "b' by'\n",
+      "b' Nik'\n",
+      "b'h'\n",
+      "b'il'\n",
+      "b' Kumar'\n",
+      "b' Singh'\n",
+      "b'('\n",
+      "b'nick'\n",
+      "b'z'\n",
+      "b'uck'\n",
+      "b'_'\n",
+      "b'007'\n",
+      "b')'\n",
+      "b'\\n'\n"
+     ]
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "\n",
+    "url = \"http://127.0.0.1:5543/generate_stream\"\n",
+    "obj = {\n",
+    "    \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
+    "    \"max_generated_tokens\":100\n",
+    "}\n",
+    "\n",
+    "with requests.post(url, json=obj, stream=True) as r:\n",
+    "    for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
+    "        print(chunk)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "dfb0e1de-17e2-4ff2-aecb-22b6f607ef9b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "    if n == 0:\n",
+      "        return 0\n",
+      "    elif n == 1:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        return fib(n-1) + fib(n-2)\n",
+      "\n",
+      "# Driver function to test above function\n",
+      "n = int(input(\"Enter the number: \"))\n",
+      "print(fib(n))\n",
+      "\n",
+      "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!curl 127.0.0.1:5543/generate_stream \\\n",
+    "    -X POST \\\n",
+    "    -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+    "    -H 'Content-Type: application/json'"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 13,
|
Loading…
Reference in New Issue
Block a user