added HTTP interface around the router

rsnm2 2023-08-24 19:10:51 +00:00
parent a973cf4922
commit ff3275c442
5 changed files with 72 additions and 36 deletions

View File

@@ -1,42 +1,60 @@
-import uvicorn, fastapi
+import fastapi, uvicorn
 from threading import Thread
 from queue import Queue
 
 from router import DeepSparseRouter, batching_task
 from utils import GenerateRequest
 
 TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
 MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
 
-# setup router
-router = DeepSparseRouter(
+def serve(
     model_path=MODEL_PATH,
-    tokenizer_path=TOKENIZER_PATH
-)
-
-# start background routing task
-batching_thread = Thread(target=batching_task, args=[router])
-batching_thread.start()
-
-app = fastapi.FastAPI()
-
-@app.post("/generate")
-def generate(prompt:str, max_generated_tokens:int):
-    response_stream = Queue()
-
-    # submit request to the router
-    router.submit_request(
-        generate_request=GenerateRequest(
-            prompt=prompt,
-            max_generated_tokens=max_generated_tokens,
-            response_stream=response_stream
-        )
+    tokenizer_path=TOKENIZER_PATH,
+    host="0.0.0.0",
+    port=5543
+):
+    # setup router
+    print("\n-------------------- Building Router --------------------\n")
+    router = DeepSparseRouter(
+        model_path=model_path,
+        tokenizer_path=tokenizer_path
     )
-
-    response_string = prompt
-    generation = response_stream.get()
-    while not generation.stopped:
-        response_string += generation.token
-        generation = response_stream.get()
+
+    # start background routing task
+    print("\n-------------------- Starting Batching Task --------------------\n")
+    batching_thread = Thread(target=batching_task, args=[router])
+    batching_thread.start()
 
-    return generation
+    print("\n-------------------- Launching App --------------------\n")
+    app = fastapi.FastAPI()
+    @app.get("/generate/{prompt}")
+    async def generate(prompt:str):
+        response_stream = Queue()
+        router.submit_request(
+            GenerateRequest(
+                prompt=prompt,
+                max_generated_tokens=100,
+                response_stream=response_stream
+            )
+        )
+
+        response_string = prompt
+        generation = response_stream.get()
+        while not generation.stopped:
+            response_string += generation.token
+            generation = response_stream.get()
+
+        return response_string
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        workers=1
+    )
+
+    batching_thread.join()
+
+if __name__ == "__main__":
+    serve()
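For reference, a minimal client sketch (not part of this commit) for exercising the new GET /generate/{prompt} endpoint once serve() is running; the port 5543 default comes from serve() above, while the requests dependency and the example prompt are assumptions:

    # hypothetical client for the endpoint added in this commit
    # assumes the server is running locally on the default port 5543
    import urllib.parse
    import requests

    prompt = "def fib(n):"
    url = "http://127.0.0.1:5543/generate/" + urllib.parse.quote(prompt, safe="")

    # the handler blocks until generation stops, then returns the prompt plus generated text
    resp = requests.get(url)
    print(resp.json())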

View File

@@ -1,8 +1,8 @@
 from queue import Queue
 from typing import List, Dict, Optional, Tuple
-from server.deepsparse.service.service import DeepSparseService
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM
-from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from service.service import DeepSparseService
+from service.causal_lm import DeepSparseCausalLM
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
 
 # TODO: implement logic for maximum size of the queue based on memory usage
 class DeepSparseQueue:
@@ -50,9 +50,9 @@ class DeepSparseQueue:
 class DeepSparseRouter:
     def __init__(
         self,
-        service: Optional[DeepSparseService],
-        model_path: Optional[str],
-        tokenizer_path: Optional[str]
+        service: Optional[DeepSparseService] = None,
+        model_path: Optional[str] = None,
+        tokenizer_path: Optional[str] = None
     ):
        assert (
            service is not None or
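Aside (not in the diff): with these new defaults the router can be built either from raw paths, as the serve() entry point above does, or from a pre-constructed DeepSparseService; a hedged sketch of the two call patterns, using placeholder paths:

    from router import DeepSparseRouter

    # 1) let the router build its own service from model/tokenizer paths (placeholders)
    router = DeepSparseRouter(
        model_path="/path/to/model.onnx",
        tokenizer_path="/path/to/deployment",
    )

    # 2) or hand in an already constructed DeepSparseService instead
    # router = DeepSparseRouter(service=existing_service)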

View File

@@ -3,8 +3,8 @@ from typing import List, Dict, Optional
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 
-from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
+from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
+from utils import Request, Batch, CachedBatch, Generation
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4

View File

@@ -1,6 +1,6 @@
 from typing import Dict, List, Tuple
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
-from server.deepsparse.utils import Generation, CachedBatch, Batch
+from service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
+from utils import Generation, CachedBatch, Batch
 
 class BatchCache:
     def __init__(self):

View File

@@ -11,6 +11,24 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "6948b65b-8851-40f3-bf67-d7db2cf1955e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\""
+     ]
+    }
+   ],
+   "source": [
+    "!curl http://127.0.0.1:5543/generate/def%20fib%28n%29%3A"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",