Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
added HTTP interface around the router
This commit is contained in:
parent
a973cf4922
commit
ff3275c442
@@ -1,42 +1,60 @@
-import uvicorn, fastapi
+import fastapi, uvicorn
 from threading import Thread
 from queue import Queue
 
 from router import DeepSparseRouter, batching_task
 from utils import GenerateRequest
 
 TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
 MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
 
-# setup router
-router = DeepSparseRouter(
+def serve(
     model_path=MODEL_PATH,
-    tokenizer_path=TOKENIZER_PATH
-)
+    tokenizer_path=TOKENIZER_PATH,
+    host="0.0.0.0",
+    port=5543
+):
+    # setup router
+    print("\n-------------------- Building Router --------------------\n")
+    router = DeepSparseRouter(
+        model_path=model_path,
+        tokenizer_path=tokenizer_path
+    )
 
-# start background routing task
-batching_thread = Thread(target=batching_task, args=[router])
-batching_thread.start()
+    # start background routing task
+    print("\n-------------------- Starting Batching Task --------------------\n")
+    batching_thread = Thread(target=batching_task, args=[router])
+    batching_thread.start()
 
-app = fastapi.FastAPI()
+    print("\n-------------------- Launching App --------------------\n")
+    app = fastapi.FastAPI()
 
-@app.post("/generate")
-def generate(prompt:str, max_generated_tokens:int):
-    response_stream = Queue()
-
-    # submit request to the router
-    router.submit_request(
-        generate_request=GenerateRequest(
-            prompt=prompt,
-            max_generated_tokens=max_generated_tokens,
-            response_stream=response_stream
-        )
-    )
+    @app.get("/generate/{prompt}")
+    async def generate(prompt:str):
+        response_stream = Queue()
+        router.submit_request(
+            GenerateRequest(
+                prompt=prompt,
+                max_generated_tokens=100,
+                response_stream=response_stream
+            )
+        )
 
-    response_string = prompt
-    generation = response_stream.get()
-    while not generation.stopped:
-        response_string += generation.token
-        generation = response_stream.get()
+        response_string = prompt
+        generation = response_stream.get()
+        while not generation.stopped:
+            response_string += generation.token
+            generation = response_stream.get()
 
-    return generation
+        return response_string
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        workers=1
+    )
+
+    batching_thread.join()
+
+if __name__ == "__main__":
+    serve()
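The hunk above wraps router construction, the batching thread, and the FastAPI app inside serve(), and the new GET endpoint blocks on a per-request Queue until the batching task marks the generation as stopped. A stripped-down sketch of that queue handoff, using stand-in types rather than the repo's actual GenerateRequest and Generation classes:

from dataclasses import dataclass
from queue import Queue
from threading import Thread

# Stand-in type: the repo's utils.Generation carries more fields than this.
@dataclass
class Generation:
    token: str
    stopped: bool

def fake_batching_task(request_queue: Queue) -> None:
    # Plays the role of batching_task: pull one request, stream tokens back on
    # its per-request response queue, then signal completion with stopped=True.
    prompt, response_stream = request_queue.get()
    for tok in [" world", "!", ""]:
        response_stream.put(Generation(token=tok, stopped=(tok == "")))

request_queue: Queue = Queue()
Thread(target=fake_batching_task, args=[request_queue], daemon=True).start()

# Mirrors the body of the /generate endpoint: submit, then drain until stopped.
response_stream: Queue = Queue()
request_queue.put(("hello", response_stream))

response_string = "hello"
generation = response_stream.get()
while not generation.stopped:
    response_string += generation.token
    generation = response_stream.get()

print(response_string)  # prints "hello world!"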
@@ -1,8 +1,8 @@
 from queue import Queue
 from typing import List, Dict, Optional, Tuple
-from server.deepsparse.service.service import DeepSparseService
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM
-from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from service.service import DeepSparseService
+from service.causal_lm import DeepSparseCausalLM
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
 
 # TODO: implement logic for maximum size of the queue based on memory usage
 class DeepSparseQueue:
@@ -50,9 +50,9 @@ class DeepSparseQueue:
 class DeepSparseRouter:
     def __init__(
         self,
-        service: Optional[DeepSparseService],
-        model_path: Optional[str],
-        tokenizer_path: Optional[str]
+        service: Optional[DeepSparseService] = None,
+        model_path: Optional[str] = None,
+        tokenizer_path: Optional[str] = None
     ):
         assert (
             service is not None or
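This hunk gives all three DeepSparseRouter constructor arguments None defaults; the assert, truncated in the diff, presumably requires either a ready DeepSparseService or both model_path and tokenizer_path. A minimal sketch of that either/or validation under that assumption, with a hypothetical _build_service helper standing in for whatever the real constructor does with the paths:

from typing import Optional

class RouterSketch:
    # Illustrative only: mirrors the validation implied by the diff, not the
    # actual DeepSparseRouter implementation.
    def __init__(
        self,
        service: Optional[object] = None,
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None
    ):
        assert (
            service is not None or
            (model_path is not None and tokenizer_path is not None)
        ), "pass either a service or both model_path and tokenizer_path"

        # hypothetical helper: stands in for however the real router builds
        # its service from the two paths when no service is passed
        self.service = service if service is not None else self._build_service(
            model_path, tokenizer_path
        )

    def _build_service(self, model_path: str, tokenizer_path: str) -> object:
        return (model_path, tokenizer_path)

# With the new defaults, callers can pass only the paths:
router = RouterSketch(model_path="model.onnx/model.onnx", tokenizer_path="deployment")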
@@ -3,8 +3,8 @@ from typing import List, Dict, Optional
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 
-from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
+from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
+from utils import Request, Batch, CachedBatch, Generation
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -1,6 +1,6 @@
 from typing import Dict, List, Tuple
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
-from server.deepsparse.utils import Generation, CachedBatch, Batch
+from service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
+from utils import Generation, CachedBatch, Batch
 
 class BatchCache:
     def __init__(self):
@@ -11,6 +11,24 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "6948b65b-8851-40f3-bf67-d7db2cf1955e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\""
+     ]
+    }
+   ],
+   "source": [
+    "!curl http://127.0.0.1:5543/generate/def%20fib%28n%29%3A"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
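The added notebook cell calls the new GET route with the URL-encoded prompt "def fib(n):" and prints the JSON-encoded string the server returns. An equivalent client call with the requests library (assumed installed), pointed at the server from the first hunk on port 5543:

import urllib.parse
import requests  # third-party HTTP client, assumed to be installed

# "def fib(n):" URL-encodes to "def%20fib%28n%29%3A", matching the curl cell above
prompt = "def fib(n):"
url = "http://127.0.0.1:5543/generate/" + urllib.parse.quote(prompt, safe="")

response = requests.get(url)
response.raise_for_status()

# FastAPI JSON-encodes the returned string, so decode it back to plain text
print(response.json())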