From 96f83659965fe326c863632e9056a8a9efcfad5a Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Fri, 25 Aug 2023 15:26:16 +0000 Subject: [PATCH] refactored stopping criteria; started concept of GenerationParameters to control generation --- currently enabling passing max_new_tokens; next step --- expand next token chooser --- deepsparse/main.py | 68 -------------- deepsparse/router.py | 14 +-- deepsparse/service/causal_lm.py | 72 +++++++-------- deepsparse/service/service.py | 2 +- deepsparse/utils.py | 35 +++++++- server-dev.ipynb | 155 ++++++++++---------------------- 6 files changed, 124 insertions(+), 222 deletions(-) delete mode 100644 deepsparse/main.py diff --git a/deepsparse/main.py b/deepsparse/main.py deleted file mode 100644 index 3ce9a532..00000000 --- a/deepsparse/main.py +++ /dev/null @@ -1,68 +0,0 @@ -import fastapi, uvicorn -from contextlib import asynccontextmanager -from threading import Thread -from queue import Queue -from router import DeepSparseRouter, batching_task -from utils import GenerateRequest - -TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment" -MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx" - -def serve( - model_path=MODEL_PATH, - tokenizer_path=TOKENIZER_PATH, - host="0.0.0.0", - port=5543 -): - - router = None - - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("\n-------------------- Building Router --------------------\n") - router = DeepSparseRouter( - model_path=model_path, - tokenizer_path=tokenizer_path - ) - - print("\n-------------------- Starting Batching Task --------------------\n") - batching_thread = Thread(target=batching_task, args=[router]) - batching_thread.start() - - print("\n-------------------- Launching App --------------------\n") - yield - - print("\n-------------------- Joining Batching Task --------------------\n") - router.stop_batching_task() - batching_task.join() - - app = fastapi.FastAPI(lifespan=lifespan) - - @app.get("/generate/{prompt}") - async def generate(prompt:str): - response_stream = Queue() - router.submit_request( - GenerateRequest( - prompt=prompt, - max_generated_tokens=100, - response_stream=response_stream - ) - ) - - response_string = prompt - generation = response_stream.get() - while not generation.stopped: - response_string += generation.token - generation = response_stream.get() - - return response_string - - uvicorn.run( - app, - host=host, - port=port, - workers=1 - ) - -if __name__ == "__main__": - serve() \ No newline at end of file diff --git a/deepsparse/router.py b/deepsparse/router.py index 85e5099e..a5d0b2a7 100644 --- a/deepsparse/router.py +++ b/deepsparse/router.py @@ -2,7 +2,7 @@ from queue import Queue from typing import List, Dict, Optional, Tuple from service.service import DeepSparseService from service.causal_lm import DeepSparseCausalLM -from utils import CachedBatch, Batch, Generation, GenerateRequest, Request +from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria class DeepSparseRouter: def __init__( @@ -38,14 +38,14 @@ class DeepSparseRouter: # unblock the batching task with a dummy request if blocked self.queue.append(GenerateRequest( - prompt="dummy", - max_generated_tokens=1, + inputs="dummy", + max_new_tokens=1, response_stream=Queue() )) def prefill( - self, - batch: Batch, + self, + batch: Batch, generate_requests: Dict[int, GenerateRequest] ) -> Optional[CachedBatch]: @@ -175,8 +175,8 @@ class DeepSparseQueue: # format into request request = Request( id=self.next_request_id, - prompt=generate_request.prompt, - max_generated_tokens=generate_request.max_generated_tokens + inputs=generate_request.inputs, + max_new_tokens=generate_request.max_new_tokens ) self.next_request_id += 1 diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py index 8aee72a9..b08a6bdf 100644 --- a/deepsparse/service/causal_lm.py +++ b/deepsparse/service/causal_lm.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import numpy as np from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel -from utils import Request, Batch, CachedBatch, Generation +from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria DEEPSPARSE_SEQUENCE_LENGTH = 128 DEEPSPARSE_MULTITOKEN_LENGTH = 4 @@ -16,6 +16,7 @@ class DeepSparseCausalLMBatch: requests_idx_mapping: Dict[int,int] input_ids_list: List[np.ndarray] past_key_values_list: List[Optional[DeepSparsePastKeyValues]] + stopping_criteria_list: List[StoppingCriteria] @classmethod def from_batch( @@ -27,34 +28,40 @@ class DeepSparseCausalLMBatch: # parse batch requests_idx_mapping = {} input_ids_list = [] - - # setup tokenizer for deepsparse left padding - tokenizer.padding_side = "left" - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - padding, truncation = "longest", False + stopping_criteria_list = [] # loop through items in the batch for idx, r in enumerate(batch.requests): requests_idx_mapping[r.id] = idx - # setup inputs_ids, past_key_values + # setup inputs_ids, stopping crtieria tokenized_inputs = tokenizer( - r.prompt, + r.inputs, return_tensors="np", - padding=padding, - truncation=truncation, + padding="longest", + truncation=False, return_token_type_ids=False, max_length=DEEPSPARSE_SEQUENCE_LENGTH ) input_ids_list.append(tokenized_inputs["input_ids"]) + # deepsparse able to accept up to seq len tokens + num_input_tokens = tokenized_inputs["input_ids"].shape[1] + model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens + stopping_criteria_list.append( + StoppingCriteria( + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=min(r.max_new_tokens, model_max_new_tokens) + ) + ) + return cls( batch_id=batch.id, requests=batch.requests, requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, past_key_values_list=[None] * len(batch.requests), + stopping_criteria_list=stopping_criteria_list ) def to_cached_batch(self) -> CachedBatch: @@ -75,6 +82,7 @@ class DeepSparseCausalLMBatch: requests = [] input_ids_list = [] past_key_values_list = [] + stopping_criteria_list = [] # loop through requests, keep ones that should remain for new_idx, request_id in enumerate(request_ids): @@ -86,12 +94,14 @@ class DeepSparseCausalLMBatch: requests.append(self.requests[old_idx]) input_ids_list.append(self.input_ids_list[old_idx]) past_key_values_list.append(self.past_key_values_list[old_idx]) + stopping_criteria_list.append(self.stopping_criteria_list[old_idx]) # update batch state self.requests = requests self.requests_idx_mapping = requests_idx_mapping self.input_ids_list = input_ids_list self.past_key_values_list = past_key_values_list + self.stopping_criteria_list = stopping_criteria_list return self @@ -104,6 +114,7 @@ class DeepSparseCausalLMBatch: requests = [] input_ids_list = [] past_key_values_list = [] + stopping_criteria_list = [] start_index = 0 for i, batch in enumerate(batches): @@ -113,6 +124,7 @@ class DeepSparseCausalLMBatch: requests.extend(batch.requests) input_ids_list.extend(batch.input_ids_list) past_key_values_list.extend(batch.past_key_values_list) + stopping_criteria_list.extend(batch.stopping_criteria_list) # merge the request_id to index mapping if i == 0: @@ -128,14 +140,15 @@ class DeepSparseCausalLMBatch: requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, - past_key_values_list=past_key_values_list + past_key_values_list=past_key_values_list, + stopping_criteria_list=stopping_criteria_list ) class DeepSparseCausalLM: def __init__( - self, - model_path: str, - tokenizer_path: str, + self, + model_path: str, + tokenizer_path: str, ): # setup tokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) @@ -162,18 +175,6 @@ class DeepSparseCausalLM: # grab logits for the last item in the sequence # shape == (batch, seq, vocabulary_size) return np.argmax(logits[0,-1,:]) - - # TODO: switch to StoppingCriteria - def should_stop( - self, - num_tokens_processed: int, - generated_token_id: int - ): - if num_tokens_processed >= self.model.sequence_length: - return True - if generated_token_id == self.tokenizer.eos_token_id: - return True - return False def generate_token( self, @@ -188,13 +189,15 @@ class DeepSparseCausalLM: # b) sample and check stopping criteria # c) create generation # d) update batch - for i, (request, input_ids, past_key_values,) in enumerate( + for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate( zip( batch.requests, batch.input_ids_list, - batch.past_key_values_list + batch.past_key_values_list, + batch.stopping_criteria_list, ) ): + # assert input_ids is b=1 assert len(input_ids.shape) == 2 assert input_ids.shape[0] == 1 @@ -205,10 +208,8 @@ class DeepSparseCausalLM: # TODO: should use NextTokenChooser/StoppingCriteria (simple for now) generated_token_id = self.sample_token(logits) generated_token = self.tokenizer.decode(generated_token_id) - stop = self.should_stop( - num_tokens_processed=input_ids.shape[1] + 1, - generated_token_id = generated_token_id - ) + + stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id) if not stop: all_stopped = False @@ -217,11 +218,12 @@ class DeepSparseCausalLM: request_id=request.id, token=generated_token, token_id=generated_token_id, - stopped=stop + stopped=stop, + finish_reason=finish_reason )) # d) update batch - # TODO: this does not occur in place) + # TODO: this does not occur in place assert len(batch.input_ids_list[i].shape) == 2 assert batch.input_ids_list[i].shape[0] == 1 batch.input_ids_list[i] = np.append( diff --git a/deepsparse/service/service.py b/deepsparse/service/service.py index b3a9a8be..c49c8f6b 100644 --- a/deepsparse/service/service.py +++ b/deepsparse/service/service.py @@ -56,7 +56,7 @@ class DeepSparseService: assert len(generations) == 1 self.cache.set(next_ds_batch) - return generations[0], next_ds_batch.to_cached_batch() + return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None) def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]: assert len(batches) != 0, "Must provide at least one batch" diff --git a/deepsparse/utils.py b/deepsparse/utils.py index 32c898b3..b5ee6b05 100644 --- a/deepsparse/utils.py +++ b/deepsparse/utils.py @@ -1,12 +1,38 @@ from dataclasses import dataclass from queue import Queue +from enum import Enum from typing import List, Optional +class FinishReason(Enum): + FINISH_REASON_LENGTH = 1 + FINISH_REASON_EOS_TOKEN = 2 + +class StoppingCriteria: + def __init__( + self, + eos_token_id: int, + max_new_tokens: int, + ): + assert max_new_tokens > 0 + self.max_new_tokens = max_new_tokens + self.eos_token_id = eos_token_id + self.current_tokens = 0 + + def __call__(self, generated_token_id:int): + self.current_tokens += 1 + if self.current_tokens >= self.max_new_tokens: + return True, FinishReason.FINISH_REASON_LENGTH + + if generated_token_id == self.eos_token_id: + return True, FinishReason.FINISH_REASON_EOS_TOKEN + + return False, None + @dataclass class Request: id: int - prompt: str - max_generated_tokens: int + inputs: str + max_new_tokens: int @dataclass class Batch: @@ -27,9 +53,10 @@ class Generation: token: Optional[str] token_id: Optional[str] stopped: bool + finish_reason: FinishReason = None @dataclass class GenerateRequest: - prompt: str - max_generated_tokens: int + inputs: str + max_new_tokens: int response_stream: Queue[Generation] \ No newline at end of file diff --git a/server-dev.ipynb b/server-dev.ipynb index 440cf7ef..d2ef7b3b 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2", "metadata": {}, "outputs": [ @@ -21,136 +21,77 @@ "name": "stdout", "output_type": "stream", "text": [ - "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\"" + "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n \"" ] } ], "source": [ "!curl 127.0.0.1:5543/generate \\\n", " -X POST \\\n", - " -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n", + " -d '{\"inputs\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"generation_parameters\":{\"max_new_tokens\":10}}' \\\n", " -H 'Content-Type: application/json'" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "id": "6ea583b1-e2d3-4f35-b87f-630d097a2628", "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from threading import Thread\n", + "\n", + "url = \"http://127.0.0.1:5543/generate\"\n", + "# sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n", + "sequence = \"def fib(n):\"\n", + "\n", + "def request_task(max_new_tokens):\n", + " obj = {\n", + " \"inputs\":sequence,\n", + " \"generation_parameters\": {\n", + " \"max_new_tokens\":max_new_tokens\n", + " }\n", + " }\n", + " with requests.post(url, json=obj) as r:\n", + " print(max_new_tokens)\n", + " print(r.text)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9dec4413-afea-444a-90b7-c98b450d5fcc", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "b'\\n'\n", - "b' '\n", - "b'if'\n", - "b' n'\n", - "b' =='\n", - "b' 0'\n", - "b':'\n", - "b'\\n'\n", - "b' '\n", - "b'return'\n", - "b' 0'\n", - "b'\\n'\n", - "b' '\n", - "b'el'\n", - "b'if'\n", - "b' n'\n", - "b' =='\n", - "b' 1'\n", - "b':'\n", - "b'\\n'\n", - "b' '\n", - "b'return'\n", - "b' 1'\n", - "b'\\n'\n", - "b' '\n", - "b'else'\n", - "b':'\n", - "b'\\n'\n", - "b' '\n", - "b'return'\n", - "b' fib'\n", - "b'('\n", - "b'n'\n", - "b'-'\n", - "b'1'\n", - "b')'\n", - "b' +'\n", - "b' fib'\n", - "b'('\n", - "b'n'\n", - "b'-'\n", - "b'2'\n", - "b')'\n", - "b'\\n'\n", - "b'\\n'\n", - "b'#'\n", - "b' Driver'\n", - "b' function'\n", - "b' to'\n", - "b' test'\n", - "b' above'\n", - "b' function'\n", - "b'\\n'\n", - "b'n'\n", - "b' ='\n", - "b' int'\n", - "b'('\n", - "b'input'\n", - "b'(\"'\n", - "b'Enter'\n", - "b' the'\n", - "b' number'\n", - "b':'\n", - "b' \"'\n", - "b'))'\n", - "b'\\n'\n", - "b'print'\n", - "b'('\n", - "b'f'\n", - "b'ib'\n", - "b'('\n", - "b'n'\n", - "b'))'\n", - "b'\\n'\n", - "b'\\n'\n", - "b'#'\n", - "b' This'\n", - "b' code'\n", - "b' is'\n", - "b' contributed'\n", - "b' by'\n", - "b' Nik'\n", - "b'h'\n", - "b'il'\n", - "b' Kumar'\n", - "b' Singh'\n", - "b'('\n", - "b'nick'\n", - "b'z'\n", - "b'uck'\n", - "b'_'\n", - "b'007'\n", - "b')'\n", - "b'\\n'\n" + "100\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n", + "200\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n", + "300\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n" ] } ], "source": [ - "import requests\n", + "# max_new_tokens_lst = [50, 10, 100, 25, 15]\n", + "max_new_tokens_lst = [100, 200, 300]\n", "\n", - "url = \"http://127.0.0.1:5543/generate_stream\"\n", - "obj = {\n", - " \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n", - " \"max_generated_tokens\":100\n", - "}\n", + "request_ts = [\n", + " Thread(target=request_task, args=[max_new_tokens]) for max_new_tokens in max_new_tokens_lst\n", + "]\n", "\n", - "with requests.post(url, json=obj, stream=True) as r:\n", - " for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n", - " print(chunk)" + "import time\n", + "for request_t in request_ts:\n", + " request_t.start()\n", + " time.sleep(0.1)\n", + "\n", + "for request_t in request_ts:\n", + " request_t.join()" ] }, {