refactored stopping criteria; started the concept of GenerationParameters to control generation --- currently enables passing max_new_tokens; next step --- expand the next token chooser

This commit is contained in:
rsnm2 2023-08-25 15:26:16 +00:00
parent 02952c511f
commit 96f8365996
6 changed files with 124 additions and 222 deletions
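For context, a minimal sketch of what the GenerationParameters concept mentioned in the commit message could look like. Only max_new_tokens is wired through in this commit, so the class layout, defaults, and the extra sampling fields below are assumptions about the planned next-token-chooser work, not code from this diff.

from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationParameters:
    # the only knob currently passed end-to-end in this commit
    max_new_tokens: int = 20
    # hypothetical placeholders for the "expand next token chooser" follow-up
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None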

View File

@@ -1,68 +0,0 @@
import fastapi, uvicorn
from contextlib import asynccontextmanager
from threading import Thread
from queue import Queue

from router import DeepSparseRouter, batching_task
from utils import GenerateRequest

TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment"
MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx"

def serve(
    model_path=MODEL_PATH,
    tokenizer_path=TOKENIZER_PATH,
    host="0.0.0.0",
    port=5543
):
    router = None

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        print("\n-------------------- Building Router --------------------\n")
        router = DeepSparseRouter(
            model_path=model_path,
            tokenizer_path=tokenizer_path
        )

        print("\n-------------------- Starting Batching Task --------------------\n")
        batching_thread = Thread(target=batching_task, args=[router])
        batching_thread.start()

        print("\n-------------------- Launching App --------------------\n")
        yield

        print("\n-------------------- Joining Batching Task --------------------\n")
        router.stop_batching_task()
        batching_task.join()

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.get("/generate/{prompt}")
    async def generate(prompt: str):
        response_stream = Queue()

        router.submit_request(
            GenerateRequest(
                prompt=prompt,
                max_generated_tokens=100,
                response_stream=response_stream
            )
        )

        response_string = prompt
        generation = response_stream.get()
        while not generation.stopped:
            response_string += generation.token
            generation = response_stream.get()

        return response_string

    uvicorn.run(
        app,
        host=host,
        port=port,
        workers=1
    )

if __name__ == "__main__":
    serve()

View File

@@ -2,7 +2,7 @@ from queue import Queue
from typing import List, Dict, Optional, Tuple
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria

class DeepSparseRouter:
    def __init__(
@@ -38,14 +38,14 @@ class DeepSparseRouter:
        # unblock the batching task with a dummy request if blocked
        self.queue.append(GenerateRequest(
-            prompt="dummy",
-            max_generated_tokens=1,
+            inputs="dummy",
+            max_new_tokens=1,
            response_stream=Queue()
        ))

    def prefill(
        self,
        batch: Batch,
        generate_requests: Dict[int, GenerateRequest]
    ) -> Optional[CachedBatch]:
@@ -175,8 +175,8 @@ class DeepSparseQueue:
        # format into request
        request = Request(
            id=self.next_request_id,
-            prompt=generate_request.prompt,
-            max_generated_tokens=generate_request.max_generated_tokens
+            inputs=generate_request.inputs,
+            max_new_tokens=generate_request.max_new_tokens
        )
        self.next_request_id += 1

View File

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
import numpy as np
from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria

DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -16,6 +16,7 @@ class DeepSparseCausalLMBatch:
    requests_idx_mapping: Dict[int,int]
    input_ids_list: List[np.ndarray]
    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
+    stopping_criteria_list: List[StoppingCriteria]

    @classmethod
    def from_batch(
@@ -27,34 +28,40 @@ class DeepSparseCausalLMBatch:
        # parse batch
        requests_idx_mapping = {}
        input_ids_list = []
+        stopping_criteria_list = []

-        # setup tokenizer for deepsparse left padding
-        tokenizer.padding_side = "left"
-        if not tokenizer.pad_token:
-            tokenizer.pad_token = tokenizer.eos_token
-        padding, truncation = "longest", False

        # loop through items in the batch
        for idx, r in enumerate(batch.requests):
            requests_idx_mapping[r.id] = idx

-            # setup inputs_ids, past_key_values
+            # setup inputs_ids, stopping criteria
            tokenized_inputs = tokenizer(
-                r.prompt,
+                r.inputs,
                return_tensors="np",
-                padding=padding,
-                truncation=truncation,
+                padding="longest",
+                truncation=False,
                return_token_type_ids=False,
                max_length=DEEPSPARSE_SEQUENCE_LENGTH
            )
            input_ids_list.append(tokenized_inputs["input_ids"])

+            # deepsparse able to accept up to seq len tokens
+            num_input_tokens = tokenized_inputs["input_ids"].shape[1]
+            model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
+            stopping_criteria_list.append(
+                StoppingCriteria(
+                    eos_token_id=tokenizer.eos_token_id,
+                    max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
+                )
+            )

        return cls(
            batch_id=batch.id,
            requests=batch.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
            past_key_values_list=[None] * len(batch.requests),
+            stopping_criteria_list=stopping_criteria_list
        )

    def to_cached_batch(self) -> CachedBatch:
@@ -75,6 +82,7 @@ class DeepSparseCausalLMBatch:
        requests = []
        input_ids_list = []
        past_key_values_list = []
+        stopping_criteria_list = []

        # loop through requests, keep ones that should remain
        for new_idx, request_id in enumerate(request_ids):
@@ -86,12 +94,14 @@ class DeepSparseCausalLMBatch:
            requests.append(self.requests[old_idx])
            input_ids_list.append(self.input_ids_list[old_idx])
            past_key_values_list.append(self.past_key_values_list[old_idx])
+            stopping_criteria_list.append(self.stopping_criteria_list[old_idx])

        # update batch state
        self.requests = requests
        self.requests_idx_mapping = requests_idx_mapping
        self.input_ids_list = input_ids_list
        self.past_key_values_list = past_key_values_list
+        self.stopping_criteria_list = stopping_criteria_list

        return self
@@ -104,6 +114,7 @@ class DeepSparseCausalLMBatch:
        requests = []
        input_ids_list = []
        past_key_values_list = []
+        stopping_criteria_list = []

        start_index = 0
        for i, batch in enumerate(batches):
@@ -113,6 +124,7 @@ class DeepSparseCausalLMBatch:
            requests.extend(batch.requests)
            input_ids_list.extend(batch.input_ids_list)
            past_key_values_list.extend(batch.past_key_values_list)
+            stopping_criteria_list.extend(batch.stopping_criteria_list)

            # merge the request_id to index mapping
            if i == 0:
@@ -128,14 +140,15 @@ class DeepSparseCausalLMBatch:
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids_list=input_ids_list,
-            past_key_values_list=past_key_values_list
+            past_key_values_list=past_key_values_list,
+            stopping_criteria_list=stopping_criteria_list
        )

class DeepSparseCausalLM:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
    ):
        # setup tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -162,18 +175,6 @@ class DeepSparseCausalLM:
        # grab logits for the last item in the sequence
        # shape == (batch, seq, vocabulary_size)
        return np.argmax(logits[0,-1,:])

-    # TODO: switch to StoppingCriteria
-    def should_stop(
-        self,
-        num_tokens_processed: int,
-        generated_token_id: int
-    ):
-        if num_tokens_processed >= self.model.sequence_length:
-            return True
-        if generated_token_id == self.tokenizer.eos_token_id:
-            return True
-        return False

    def generate_token(
        self,
@@ -188,13 +189,15 @@
        # b) sample and check stopping criteria
        # c) create generation
        # d) update batch
-        for i, (request, input_ids, past_key_values,) in enumerate(
+        for i, (request, input_ids, past_key_values, stopping_criteria,) in enumerate(
            zip(
                batch.requests,
                batch.input_ids_list,
-                batch.past_key_values_list
+                batch.past_key_values_list,
+                batch.stopping_criteria_list,
            )
        ):
+            # assert input_ids is b=1
            assert len(input_ids.shape) == 2
            assert input_ids.shape[0] == 1
@@ -205,10 +208,8 @@
            # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
            generated_token_id = self.sample_token(logits)
            generated_token = self.tokenizer.decode(generated_token_id)
-            stop = self.should_stop(
-                num_tokens_processed=input_ids.shape[1] + 1,
-                generated_token_id = generated_token_id
-            )
+            stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)

            if not stop:
                all_stopped = False
@@ -217,11 +218,12 @@
                request_id=request.id,
                token=generated_token,
                token_id=generated_token_id,
-                stopped=stop
+                stopped=stop,
+                finish_reason=finish_reason
            ))

            # d) update batch
-            # TODO: this does not occur in place)
+            # TODO: this does not occur in place
            assert len(batch.input_ids_list[i].shape) == 2
            assert batch.input_ids_list[i].shape[0] == 1
            batch.input_ids_list[i] = np.append(
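The from_batch change above caps each request's generation budget by whatever the fixed engine sequence length leaves after the prompt. Below is a minimal self-contained sketch of that clamp, assuming the same DEEPSPARSE_SEQUENCE_LENGTH = 128 constant; the helper name is illustrative, not part of this diff.

DEEPSPARSE_SEQUENCE_LENGTH = 128

def clamp_max_new_tokens(num_input_tokens: int, requested_max_new_tokens: int) -> int:
    # the engine processes at most sequence_length tokens in total,
    # so generation gets whatever budget the prompt leaves over
    model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens
    return min(requested_max_new_tokens, model_max_new_tokens)

# e.g. a 20-token prompt asking for 200 new tokens is clamped to 108
assert clamp_max_new_tokens(20, 200) == 108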

View File

@@ -56,7 +56,7 @@ class DeepSparseService:
        assert len(generations) == 1
        self.cache.set(next_ds_batch)

-        return generations[0], next_ds_batch.to_cached_batch()
+        return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None)

    def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        assert len(batches) != 0, "Must provide at least one batch"

View File

@@ -1,12 +1,38 @@
from dataclasses import dataclass
from queue import Queue
+from enum import Enum
from typing import List, Optional

+class FinishReason(Enum):
+    FINISH_REASON_LENGTH = 1
+    FINISH_REASON_EOS_TOKEN = 2

+class StoppingCriteria:
+    def __init__(
+        self,
+        eos_token_id: int,
+        max_new_tokens: int,
+    ):
+        assert max_new_tokens > 0
+        self.max_new_tokens = max_new_tokens
+        self.eos_token_id = eos_token_id
+        self.current_tokens = 0

+    def __call__(self, generated_token_id: int):
+        self.current_tokens += 1
+        if self.current_tokens >= self.max_new_tokens:
+            return True, FinishReason.FINISH_REASON_LENGTH
+        if generated_token_id == self.eos_token_id:
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN
+        return False, None

@dataclass
class Request:
    id: int
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    max_new_tokens: int

@dataclass
class Batch:
@@ -27,9 +53,10 @@ class Generation:
    token: Optional[str]
    token_id: Optional[str]
    stopped: bool
+    finish_reason: FinishReason = None

@dataclass
class GenerateRequest:
-    prompt: str
-    max_generated_tokens: int
+    inputs: str
+    max_new_tokens: int
    response_stream: Queue[Generation]
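A short usage sketch of the StoppingCriteria added above: the server builds one instance per request and calls it once per generated token, getting back a (stop, finish_reason) pair. The import assumes the utils module from this repo, and eos_token_id=2 is just an example value.

from utils import StoppingCriteria, FinishReason

criteria = StoppingCriteria(eos_token_id=2, max_new_tokens=3)
print(criteria(generated_token_id=17))  # (False, None)
print(criteria(generated_token_id=2))   # (True, FinishReason.FINISH_REASON_EOS_TOKEN)

# the length check fires first once the token budget is used up
criteria = StoppingCriteria(eos_token_id=2, max_new_tokens=1)
print(criteria(generated_token_id=17))  # (True, FinishReason.FINISH_REASON_LENGTH)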

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
-"execution_count": 10,
+"execution_count": 17,
"id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2",
"metadata": {},
"outputs": [
@@ -21,136 +21,77 @@
"name": "stdout",
"output_type": "stream",
"text": [
-"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\""
+"\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n \""
]
}
],
"source": [
"!curl 127.0.0.1:5543/generate \\\n",
" -X POST \\\n",
-" -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n",
+" -d '{\"inputs\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"generation_parameters\":{\"max_new_tokens\":10}}' \\\n",
" -H 'Content-Type: application/json'"
]
},
{
"cell_type": "code",
-"execution_count": 7,
+"execution_count": 20,
"id": "6ea583b1-e2d3-4f35-b87f-630d097a2628",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from threading import Thread\n",
"\n",
"url = \"http://127.0.0.1:5543/generate\"\n",
"# sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
"sequence = \"def fib(n):\"\n",
"\n",
"def request_task(max_new_tokens):\n",
" obj = {\n",
" \"inputs\":sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":max_new_tokens\n",
" }\n",
" }\n",
" with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n",
" print(r.text)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9dec4413-afea-444a-90b7-c98b450d5fcc",
"metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"b'\\n'\n", "100\n",
"b' '\n", "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n",
"b'if'\n", "200\n",
"b' n'\n", "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n",
"b' =='\n", "300\n",
"b' 0'\n", "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n"
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 0'\n",
"b'\\n'\n",
"b' '\n",
"b'el'\n",
"b'if'\n",
"b' n'\n",
"b' =='\n",
"b' 1'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' 1'\n",
"b'\\n'\n",
"b' '\n",
"b'else'\n",
"b':'\n",
"b'\\n'\n",
"b' '\n",
"b'return'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'1'\n",
"b')'\n",
"b' +'\n",
"b' fib'\n",
"b'('\n",
"b'n'\n",
"b'-'\n",
"b'2'\n",
"b')'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' Driver'\n",
"b' function'\n",
"b' to'\n",
"b' test'\n",
"b' above'\n",
"b' function'\n",
"b'\\n'\n",
"b'n'\n",
"b' ='\n",
"b' int'\n",
"b'('\n",
"b'input'\n",
"b'(\"'\n",
"b'Enter'\n",
"b' the'\n",
"b' number'\n",
"b':'\n",
"b' \"'\n",
"b'))'\n",
"b'\\n'\n",
"b'print'\n",
"b'('\n",
"b'f'\n",
"b'ib'\n",
"b'('\n",
"b'n'\n",
"b'))'\n",
"b'\\n'\n",
"b'\\n'\n",
"b'#'\n",
"b' This'\n",
"b' code'\n",
"b' is'\n",
"b' contributed'\n",
"b' by'\n",
"b' Nik'\n",
"b'h'\n",
"b'il'\n",
"b' Kumar'\n",
"b' Singh'\n",
"b'('\n",
"b'nick'\n",
"b'z'\n",
"b'uck'\n",
"b'_'\n",
"b'007'\n",
"b')'\n",
"b'\\n'\n"
+"100\n",
+"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n",
+"200\n",
+"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n",
+"300\n",
+"\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n"
]
}
],
"source": [
-"import requests\n",
-"\n",
-"url = \"http://127.0.0.1:5543/generate_stream\"\n",
-"obj = {\n",
-" \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n",
-" \"max_generated_tokens\":100\n",
-"}\n",
-"\n",
-"with requests.post(url, json=obj, stream=True) as r:\n",
-" for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n",
-" print(chunk)"
+"# max_new_tokens_lst = [50, 10, 100, 25, 15]\n",
+"max_new_tokens_lst = [100, 200, 300]\n",
+"\n",
+"request_ts = [\n",
+" Thread(target=request_task, args=[max_new_tokens]) for max_new_tokens in max_new_tokens_lst\n",
+"]\n",
+"\n",
+"import time\n",
+"for request_t in request_ts:\n",
+" request_t.start()\n",
+" time.sleep(0.1)\n",
+"\n",
+"for request_t in request_ts:\n",
+" request_t.join()"
]
},
{