refactored stopping criteria; started concept of GenerationParameters to control generation --- currently only max_new_tokens can be passed; next step --- expand the next token chooser

This commit is contained in:
rsnm2 2023-08-25 15:26:16 +00:00
parent 02952c511f
commit 96f8365996
6 changed files with 124 additions and 222 deletions
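The GenerationParameters class itself is not part of this diff; the sketch below is one guess at the request-side shape implied by the commit message and by the notebook request body further down ({"inputs": ..., "generation_parameters": {"max_new_tokens": ...}}). The helper name and the default value of 20 are assumptions, not code from this commit.

from dataclasses import dataclass

@dataclass
class GenerationParameters:
    # only max_new_tokens is wired through in this commit; any other
    # sampling fields would be future additions
    max_new_tokens: int = 20

def parse_generation_parameters(body: dict) -> GenerationParameters:
    # body mirrors the JSON used in the notebook below:
    # {"inputs": "...", "generation_parameters": {"max_new_tokens": 10}}
    params = body.get("generation_parameters") or {}
    return GenerationParameters(max_new_tokens=params.get("max_new_tokens", 20))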

View File

@@ -1,68 +0,0 @@
import fastapi, uvicorn
from contextlib import asynccontextmanager
from threading import Thread
from queue import Queue
from router import DeepSparseRouter, batching_task
from utils import GenerateRequest
TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"
def serve(
model_path=MODEL_PATH,
tokenizer_path=TOKENIZER_PATH,
host="0.0.0.0",
port=5543
):
router = None
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
print("\n-------------------- Building Router --------------------\n")
router = DeepSparseRouter(
model_path=model_path,
tokenizer_path=tokenizer_path
)
print("\n-------------------- Starting Batching Task --------------------\n")
batching_thread = Thread(target=batching_task, args=[router])
batching_thread.start()
print("\n-------------------- Launching App --------------------\n")
yield
print("\n-------------------- Joining Batching Task --------------------\n")
router.stop_batching_task()
batching_thread.join()
app = fastapi.FastAPI(lifespan=lifespan)
@app.get("/generate/{prompt}")
async def generate(prompt:str):
response_stream = Queue()
router.submit_request(
GenerateRequest(
prompt=prompt,
max_generated_tokens=100,
response_stream=response_stream
)
)
response_string = prompt
generation = response_stream.get()
while not generation.stopped:
response_string += generation.token
generation = response_stream.get()
return response_string
uvicorn.run(
app,
host=host,
port=port,
workers=1
)
if __name__ == "__main__":
serve()

View File

@@ -2,7 +2,7 @@ from queue import Queue
from typing import List, Dict, Optional, Tuple
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria
class DeepSparseRouter:
def __init__(
@@ -38,14 +38,14 @@ class DeepSparseRouter:
# unblock the batching task with a dummy request if blocked
self.queue.append(GenerateRequest(
prompt="dummy",
max_generated_tokens=1,
inputs="dummy",
max_new_tokens=1,
response_stream=Queue()
))
def prefill(
self,
batch: Batch,
self,
batch: Batch,
generate_requests: Dict[int, GenerateRequest]
) -> Optional[CachedBatch]:
@@ -175,8 +175,8 @@ class DeepSparseQueue:
# format into request
request = Request(
id=self.next_request_id,
prompt=generate_request.prompt,
max_generated_tokens=generate_request.max_generated_tokens
inputs=generate_request.inputs,
max_new_tokens=generate_request.max_new_tokens
)
self.next_request_id += 1

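To show the renamed request fields (inputs, max_new_tokens) in use, here is a small synchronous driver sketch in the style of the server code deleted above; generate_sync is a hypothetical helper, and it assumes DeepSparseRouter still exposes submit_request and that Generation carries token/stopped as in utils.

from queue import Queue
from utils import GenerateRequest

def generate_sync(router, inputs: str, max_new_tokens: int) -> str:
    # submit a request and drain its response_stream until the router
    # signals that generation has stopped
    response_stream = Queue()
    router.submit_request(GenerateRequest(
        inputs=inputs,
        max_new_tokens=max_new_tokens,
        response_stream=response_stream,
    ))
    text = inputs
    generation = response_stream.get()
    while not generation.stopped:
        text += generation.token
        generation = response_stream.get()
    return text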
View File

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
import numpy as np
from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
from utils import Request, Batch, CachedBatch, Generation
from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria
DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -16,6 +16,7 @@ class DeepSparseCausalLMBatch:
requests_idx_mapping: Dict[int,int]
input_ids_list: List[np.ndarray]
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
stopping_criteria_list: List[StoppingCriteria]
@classmethod
def from_batch(
@@ -27,34 +28,40 @@ class DeepSparseCausalLMBatch:
# parse batch
requests_idx_mapping = {}
input_ids_list = []
# setup tokenizer for deepsparse left padding
tokenizer.padding_side = "left"
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
padding, truncation = "longest", False
stopping_criteria_list = []
# loop through items in the batch
for idx, r in enumerate(batch.requests):
requests_idx_mapping[r.id] = idx
# setup inputs_ids, past_key_values
# setup input_ids, stopping criteria
tokenized_inputs = tokenizer(
r.prompt,
r.inputs,
return_tensors="np",
padding=padding,
truncation=truncation,
padding="longest",
truncation=False,
return_token_type_ids=False,
max_length=DEEPSPARSE_SEQUENCE_LENGTH
)
input_ids_list.append(tokenized_inputs["input_ids"])
# deepsparse can accept up to sequence_length total tokens
num_input_tokens = tokenized_inputs["input_ids"].shape[1]
model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
stopping_criteria_list.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
)
)
return cls(
batch_id=batch.id,
requests=batch.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=[None] * len(batch.requests),
stopping_criteria_list=stopping_criteria_list
)
def to_cached_batch(self) -> CachedBatch:
@@ -75,6 +82,7 @@ class DeepSparseCausalLMBatch:
requests = []
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
# loop through requests, keep ones that should remain
for new_idx, request_id in enumerate(request_ids):
@@ -86,12 +94,14 @@ class DeepSparseCausalLMBatch:
requests.append(self.requests[old_idx])
input_ids_list.append(self.input_ids_list[old_idx])
past_key_values_list.append(self.past_key_values_list[old_idx])
stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
# update batch state
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids_list = input_ids_list
self.past_key_values_list = past_key_values_list
self.stopping_criteria_list = stopping_criteria_list
return self
@@ -104,6 +114,7 @@ class DeepSparseCausalLMBatch:
requests = []
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
start_index = 0
for i, batch in enumerate(batches):
@@ -113,6 +124,7 @@ class DeepSparseCausalLMBatch:
requests.extend(batch.requests)
input_ids_list.extend(batch.input_ids_list)
past_key_values_list.extend(batch.past_key_values_list)
stopping_criteria_list.extend(batch.stopping_criteria_list)
# merge the request_id to index mapping
if i == 0:
@@ -128,14 +140,15 @@ class DeepSparseCausalLMBatch:
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=past_key_values_list
past_key_values_list=past_key_values_list,
stopping_criteria_list=stopping_criteria_list
)
class DeepSparseCausalLM:
def __init__(
self,
model_path: str,
tokenizer_path: str,
self,
model_path: str,
tokenizer_path: str,
):
# setup tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -162,18 +175,6 @@ class DeepSparseCausalLM:
# grab logits for the last item in the sequence
# shape == (batch, seq, vocabulary_size)
return np.argmax(logits[0,-1,:])
# TODO: switch to StoppingCriteria
def should_stop(
self,
num_tokens_processed: int,
generated_token_id: int
):
if num_tokens_processed >= self.model.sequence_length:
return True
if generated_token_id == self.tokenizer.eos_token_id:
return True
return False
def generate_token(
self,
@@ -188,13 +189,15 @@ class DeepSparseCausalLM:
# b) sample and check stopping criteria
# c) create generation
# d) update batch
for i, (request, input_ids, past_key_values,) in enumerate(
for i, (request, input_ids, past_key_values, stopping_criteria,) in enumerate(
zip(
batch.requests,
batch.input_ids_list,
batch.past_key_values_list
batch.past_key_values_list,
batch.stopping_criteria_list,
)
):
# assert input_ids has batch size 1
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
@@ -205,10 +208,8 @@ class DeepSparseCausalLM:
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
generated_token_id = self.sample_token(logits)
generated_token = self.tokenizer.decode(generated_token_id)
stop = self.should_stop(
num_tokens_processed=input_ids.shape[1] + 1,
generated_token_id = generated_token_id
)
stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
if not stop:
all_stopped = False
@@ -217,11 +218,12 @@ class DeepSparseCausalLM:
request_id=request.id,
token=generated_token,
token_id=generated_token_id,
stopped=stop
stopped=stop,
finish_reason=finish_reason
))
# d) update batch
# TODO: this does not occur in place)
# TODO: this does not occur in place
assert len(batch.input_ids_list[i].shape) == 2
assert batch.input_ids_list[i].shape[0] == 1
batch.input_ids_list[i] = np.append(

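The min(r.max_new_tokens, model_max_new_tokens) clamp in from_batch above just reserves whatever room is left in the model's fixed sequence length; a quick check with illustrative numbers (the 100-token prompt is made up):

DEEPSPARSE_SEQUENCE_LENGTH = 128   # as defined at the top of this file
num_input_tokens = 100             # hypothetical prompt length
requested_max_new_tokens = 100     # what the client asked for

# only 28 slots remain in the sequence, so generation is capped there
effective_max_new_tokens = min(
    requested_max_new_tokens,
    DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens,
)
assert effective_max_new_tokens == 28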
View File

@@ -56,7 +56,7 @@ class DeepSparseService:
assert len(generations) == 1
self.cache.set(next_ds_batch)
return generations[0], next_ds_batch.to_cached_batch()
return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)
def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
assert len(batches) != 0, "Must provide at least one batch"

View File

@@ -1,12 +1,38 @@
from dataclasses import dataclass
from queue import Queue
from enum import Enum
from typing import List, Optional
class FinishReason(Enum):
FINISH_REASON_LENGTH = 1
FINISH_REASON_EOS_TOKEN = 2
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
max_new_tokens: int,
):
assert max_new_tokens > 0
self.max_new_tokens = max_new_tokens
self.eos_token_id = eos_token_id
self.current_tokens = 0
def __call__(self, generated_token_id:int):
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if generated_token_id == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
return False, None
@dataclass
class Request:
id: int
prompt: str
max_generated_tokens: int
inputs: str
max_new_tokens: int
@dataclass
class Batch:
@@ -27,9 +53,10 @@ class Generation:
token: Optional[str]
token_id: Optional[str]
stopped: bool
finish_reason: Optional[FinishReason] = None
@dataclass
class GenerateRequest:
prompt: str
max_generated_tokens: int
inputs: str
max_new_tokens: int
response_stream: Queue[Generation]

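For reference, a quick usage sketch of the new StoppingCriteria defined above; the token ids (including the eos id 50256) are arbitrary illustrative values, not output from a real tokenizer.

from utils import StoppingCriteria, FinishReason

# length-based stop: the third call reaches max_new_tokens
criteria = StoppingCriteria(eos_token_id=50256, max_new_tokens=3)
print(criteria(generated_token_id=11))     # (False, None)
print(criteria(generated_token_id=198))    # (False, None)
print(criteria(generated_token_id=25))     # stops with FINISH_REASON_LENGTH

# EOS-based stop: matching the eos_token_id stops on the first call
criteria = StoppingCriteria(eos_token_id=50256, max_new_tokens=100)
print(criteria(generated_token_id=50256))  # stops with FINISH_REASON_EOS_TOKEN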
View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
"metadata": {},
"outputs": [
@@ -21,136 +21,77 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n \""
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate \\\n",
" -X POST \\\n",
" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
" -d '{\"inputs\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"generation_parameters\":{\"max_new_tokens\":10}}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 20,
"id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from threading import Thread\n",
"\n",
"url = \"http://127.0.0.1:5543/generate\"\n",
"# sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
"sequence = \"def fib(n):\"\n",
"\n",
"def request_task(max_new_tokens):\n",
" obj = {\n",
" \"inputs\":sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":max_new_tokens\n",
" }\n",
" }\n",
" with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n",
" print(r.text)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9dec4413-afea-444a-90b7-c98b450d5fcc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'\\n'\n",
"b' '\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 0'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 0'\n",
"b'\\n'\n",
"b' '\n",
"b'el'\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 1'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 1'\n",
"b'\\n'\n",
"b' '\n",
"b'else'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'1'\n",
"b')'\n",
"b' +'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'2'\n",
"b')'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' Driver'\n",
"b' function'\n",
"b' to'\n",
"b' test'\n",
"b' above'\n",
"b' function'\n",
"b'\\n'\n",
"b'n'\n",
"b' ='\n",
"b' int'\n",
"b'('\n",
"b'input'\n",
"b'(\"'\n",
"b'Enter'\n",
"b' the'\n",
"b' number'\n",
"b':'\n",
"b' \"'\n",
"b'))'\n",
"b'\\n'\n",
"b'print'\n",
"b'('\n",
"b'f'\n",
"b'ib'\n",
"b'('\n",
"b'n'\n",
"b'))'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' This'\n",
"b' code'\n",
"b' is'\n",
"b' contributed'\n",
"b' by'\n",
"b' Nik'\n",
"b'h'\n",
"b'il'\n",
"b' Kumar'\n",
"b' Singh'\n",
"b'('\n",
"b'nick'\n",
"b'z'\n",
"b'uck'\n",
"b'_'\n",
"b'007'\n",
"b')'\n",
"b'\\n'\n"
"100\n",
"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n",
"200\n",
"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n",
"300\n",
"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n"
]
}
],
"source": [
"import requests\n",
"# max_new_tokens_lst = [50, 10, 100, 25, 15]\n",
"max_new_tokens_lst = [100, 200, 300]\n",
"\n",
"url = \"http://127.0.0.1:5543/generate_stream\"\n",
"obj = {\n",
" \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
" \"max_generated_tokens\":100\n",
"}\n",
"request_ts = [\n",
" Thread(target=request_task, args=[max_new_tokens]) for max_new_tokens in max_new_tokens_lst\n",
"]\n",
"\n",
"with requests.post(url, json=obj, stream=True) as r:\n",
" for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
" print(chunk)"
"import time\n",
"for request_t in request_ts:\n",
" request_t.start()\n",
" time.sleep(0.1)\n",
"\n",
"for request_t in request_ts:\n",
" request_t.join()"
]
},
{