Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00

Commit ff3275c442 (parent a973cf4922): added HTTP interface around the router
@@ -1,42 +1,60 @@
-import uvicorn, fastapi
+import fastapi, uvicorn
 from threading import Thread
 from queue import Queue
 
 from router import DeepSparseRouter, batching_task
 from utils import GenerateRequest
 
 TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
 MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
 
-# setup router
-router = DeepSparseRouter(
-    model_path=MODEL_PATH,
-    tokenizer_path=TOKENIZER_PATH
-)
-
-# start background routing task
-batching_thread = Thread(target=batching_task, args=[router])
-batching_thread.start()
-
-app = fastapi.FastAPI()
-
-@app.post("/generate")
-def generate(prompt:str, max_generated_tokens:int):
-    response_stream = Queue()
-
-    # submit request to the router
-    router.submit_request(
-        generate_request=GenerateRequest(
-            prompt=prompt,
-            max_generated_tokens=max_generated_tokens,
-            response_stream=response_stream
-        )
-    )
-
-    response_string = prompt
-    generation = response_stream.get()
-    while not generation.stopped:
-        response_string += generation.token
-        generation = response_stream.get()
-
-    return generation
+def serve(
+    model_path=MODEL_PATH,
+    tokenizer_path=TOKENIZER_PATH,
+    host="0.0.0.0",
+    port=5543
+):
+    # setup router
+    print("\n-------------------- Building Router --------------------\n")
+    router = DeepSparseRouter(
+        model_path=model_path,
+        tokenizer_path=tokenizer_path
+    )
+
+    # start background routing task
+    print("\n-------------------- Starting Batching Task --------------------\n")
+    batching_thread = Thread(target=batching_task, args=[router])
+    batching_thread.start()
+
+    print("\n-------------------- Launching App --------------------\n")
+    app = fastapi.FastAPI()
+
+    @app.get("/generate/{prompt}")
+    async def generate(prompt:str):
+        response_stream = Queue()
+        router.submit_request(
+            GenerateRequest(
+                prompt=prompt,
+                max_generated_tokens=100,
+                response_stream=response_stream
+            )
+        )
+
+        response_string = prompt
+        generation = response_stream.get()
+        while not generation.stopped:
+            response_string += generation.token
+            generation = response_stream.get()
+
+        return response_string
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        workers=1
+    )
+
+    batching_thread.join()
+
+if __name__ == "__main__":
+    serve()
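Note (not part of the commit): a minimal client sketch for the new GET endpoint above, assuming the server started by serve() is reachable on the default host/port. Only the standard library is used, and the prompt and URL below are illustrative.

import urllib.parse
import urllib.request

# URL-encode the prompt so it can travel as the path segment of /generate/{prompt}
prompt = "def fib(n):"
url = "http://127.0.0.1:5543/generate/" + urllib.parse.quote(prompt, safe="")

# the endpoint responds with the prompt plus up to 100 generated tokens,
# serialized by FastAPI as a JSON string
with urllib.request.urlopen(url) as response:
    print(response.read().decode())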
@@ -1,8 +1,8 @@
 from queue import Queue
 from typing import List, Dict, Optional, Tuple
-from server.deepsparse.service.service import DeepSparseService
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM
-from server.deepsparse.utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from service.service import DeepSparseService
+from service.causal_lm import DeepSparseCausalLM
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
 
 # TODO: implement logic for maximum size of the queue based on memory usage
 class DeepSparseQueue:

@@ -50,9 +50,9 @@ class DeepSparseQueue:
 class DeepSparseRouter:
     def __init__(
         self,
-        service: Optional[DeepSparseService],
-        model_path: Optional[str],
-        tokenizer_path: Optional[str]
+        service: Optional[DeepSparseService] = None,
+        model_path: Optional[str] = None,
+        tokenizer_path: Optional[str] = None
     ):
         assert (
             service is not None or
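Note (not part of the commit): with the constructor arguments now defaulting to None, DeepSparseRouter can be built either from an existing DeepSparseService or directly from model/tokenizer paths; a rough sketch follows, where existing_service is a hypothetical, already-constructed DeepSparseService and the real validation is the assert inside __init__.

# build directly from paths (the route taken by serve() above)
router = DeepSparseRouter(
    model_path=MODEL_PATH,
    tokenizer_path=TOKENIZER_PATH,
)

# or wrap an already-built service (existing_service is hypothetical)
router = DeepSparseRouter(service=existing_service)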
@@ -3,8 +3,8 @@ from typing import List, Dict, Optional
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 
-from server.deepsparse.service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from server.deepsparse.utils import Request, Batch, CachedBatch, Generation
+from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
+from utils import Request, Batch, CachedBatch, Generation
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -1,6 +1,6 @@
 from typing import Dict, List, Tuple
-from server.deepsparse.service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
-from server.deepsparse.utils import Generation, CachedBatch, Batch
+from service.causal_lm import DeepSparseCausalLM, DeepSparseCausalLMBatch
+from utils import Generation, CachedBatch, Batch
 
 class BatchCache:
     def __init__(self):
@@ -11,6 +11,24 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "6948b65b-8851-40f3-bf67-d7db2cf1955e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\""
+     ]
+    }
+   ],
+   "source": [
+    "!curl http://127.0.0.1:5543/generate/def%20fib%28n%29%3A"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "a19786b8-e72c-43c1-964f-45d92fd171e9",
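Note (not part of the commit): the path segment in the curl call above is just the URL-encoded prompt; decoding it back shows what was sent to /generate/{prompt}.

import urllib.parse

print(urllib.parse.unquote("def%20fib%28n%29%3A"))  # prints: def fib(n):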