From 96f83659965fe326c863632e9056a8a9efcfad5a Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Fri, 25 Aug 2023 15:26:16 +0000 Subject: [PATCH 1/3] refactored stopping criteria; started concept of GenerationParameters to control generation --- currently enabling passing max_new_tokens; next step --- expand next token chooser --- deepsparse/main.py | 68 -------------- deepsparse/router.py | 14 +-- deepsparse/service/causal_lm.py | 72 +++++++-------- deepsparse/service/service.py | 2 +- deepsparse/utils.py | 35 +++++++- server-dev.ipynb | 155 ++++++++++---------------------- 6 files changed, 124 insertions(+), 222 deletions(-) delete mode 100644 deepsparse/main.py diff --git a/deepsparse/main.py b/deepsparse/main.py deleted file mode 100644 index 3ce9a532..00000000 --- a/deepsparse/main.py +++ /dev/null @@ -1,68 +0,0 @@ -import fastapi, uvicorn -from contextlib import asynccontextmanager -from threading import Thread -from queue import Queue -from router import DeepSparseRouter, batching_task -from utils import GenerateRequest - -TOKENIZER_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment" -MODEL_PATH = "/home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx" - -def serve( - model_path=MODEL_PATH, - tokenizer_path=TOKENIZER_PATH, - host="0.0.0.0", - port=5543 -): - - router = None - - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("\n-------------------- Building Router --------------------\n") - router = DeepSparseRouter( - model_path=model_path, - tokenizer_path=tokenizer_path - ) - - print("\n-------------------- Starting Batching Task --------------------\n") - batching_thread = Thread(target=batching_task, args=[router]) - batching_thread.start() - - print("\n-------------------- Launching App --------------------\n") - yield - - print("\n-------------------- Joining Batching Task --------------------\n") - router.stop_batching_task() - batching_task.join() - - app = fastapi.FastAPI(lifespan=lifespan) - - @app.get("/generate/{prompt}") - async def generate(prompt:str): - response_stream = Queue() - router.submit_request( - GenerateRequest( - prompt=prompt, - max_generated_tokens=100, - response_stream=response_stream - ) - ) - - response_string = prompt - generation = response_stream.get() - while not generation.stopped: - response_string += generation.token - generation = response_stream.get() - - return response_string - - uvicorn.run( - app, - host=host, - port=port, - workers=1 - ) - -if __name__ == "__main__": - serve() \ No newline at end of file diff --git a/deepsparse/router.py b/deepsparse/router.py index 85e5099e..a5d0b2a7 100644 --- a/deepsparse/router.py +++ b/deepsparse/router.py @@ -2,7 +2,7 @@ from queue import Queue from typing import List, Dict, Optional, Tuple from service.service import DeepSparseService from service.causal_lm import DeepSparseCausalLM -from utils import CachedBatch, Batch, Generation, GenerateRequest, Request +from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria class DeepSparseRouter: def __init__( @@ -38,14 +38,14 @@ class DeepSparseRouter: # unblock the batching task with a dummy request if blocked self.queue.append(GenerateRequest( - prompt="dummy", - max_generated_tokens=1, + inputs="dummy", + max_new_tokens=1, response_stream=Queue() )) def prefill( - self, - batch: Batch, + self, + batch: Batch, generate_requests: Dict[int, GenerateRequest] ) -> 
Optional[CachedBatch]: @@ -175,8 +175,8 @@ class DeepSparseQueue: # format into request request = Request( id=self.next_request_id, - prompt=generate_request.prompt, - max_generated_tokens=generate_request.max_generated_tokens + inputs=generate_request.inputs, + max_new_tokens=generate_request.max_new_tokens ) self.next_request_id += 1 diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py index 8aee72a9..b08a6bdf 100644 --- a/deepsparse/service/causal_lm.py +++ b/deepsparse/service/causal_lm.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import numpy as np from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel -from utils import Request, Batch, CachedBatch, Generation +from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria DEEPSPARSE_SEQUENCE_LENGTH = 128 DEEPSPARSE_MULTITOKEN_LENGTH = 4 @@ -16,6 +16,7 @@ class DeepSparseCausalLMBatch: requests_idx_mapping: Dict[int,int] input_ids_list: List[np.ndarray] past_key_values_list: List[Optional[DeepSparsePastKeyValues]] + stopping_criteria_list: List[StoppingCriteria] @classmethod def from_batch( @@ -27,34 +28,40 @@ class DeepSparseCausalLMBatch: # parse batch requests_idx_mapping = {} input_ids_list = [] - - # setup tokenizer for deepsparse left padding - tokenizer.padding_side = "left" - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - padding, truncation = "longest", False + stopping_criteria_list = [] # loop through items in the batch for idx, r in enumerate(batch.requests): requests_idx_mapping[r.id] = idx - # setup inputs_ids, past_key_values + # setup inputs_ids, stopping crtieria tokenized_inputs = tokenizer( - r.prompt, + r.inputs, return_tensors="np", - padding=padding, - truncation=truncation, + padding="longest", + truncation=False, return_token_type_ids=False, max_length=DEEPSPARSE_SEQUENCE_LENGTH ) input_ids_list.append(tokenized_inputs["input_ids"]) + # deepsparse able to accept up to seq len tokens + num_input_tokens = tokenized_inputs["input_ids"].shape[1] + model_max_new_tokens = DEEPSPARSE_SEQUENCE_LENGTH - num_input_tokens + stopping_criteria_list.append( + StoppingCriteria( + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=min(r.max_new_tokens, model_max_new_tokens) + ) + ) + return cls( batch_id=batch.id, requests=batch.requests, requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, past_key_values_list=[None] * len(batch.requests), + stopping_criteria_list=stopping_criteria_list ) def to_cached_batch(self) -> CachedBatch: @@ -75,6 +82,7 @@ class DeepSparseCausalLMBatch: requests = [] input_ids_list = [] past_key_values_list = [] + stopping_criteria_list = [] # loop through requests, keep ones that should remain for new_idx, request_id in enumerate(request_ids): @@ -86,12 +94,14 @@ class DeepSparseCausalLMBatch: requests.append(self.requests[old_idx]) input_ids_list.append(self.input_ids_list[old_idx]) past_key_values_list.append(self.past_key_values_list[old_idx]) + stopping_criteria_list.append(self.stopping_criteria_list[old_idx]) # update batch state self.requests = requests self.requests_idx_mapping = requests_idx_mapping self.input_ids_list = input_ids_list self.past_key_values_list = past_key_values_list + self.stopping_criteria_list = stopping_criteria_list return self @@ -104,6 +114,7 @@ class DeepSparseCausalLMBatch: requests = [] input_ids_list = [] past_key_values_list = [] + stopping_criteria_list = [] start_index = 0 for i, batch in enumerate(batches): @@ 
-113,6 +124,7 @@ class DeepSparseCausalLMBatch: requests.extend(batch.requests) input_ids_list.extend(batch.input_ids_list) past_key_values_list.extend(batch.past_key_values_list) + stopping_criteria_list.extend(batch.stopping_criteria_list) # merge the request_id to index mapping if i == 0: @@ -128,14 +140,15 @@ class DeepSparseCausalLMBatch: requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, - past_key_values_list=past_key_values_list + past_key_values_list=past_key_values_list, + stopping_criteria_list=stopping_criteria_list ) class DeepSparseCausalLM: def __init__( - self, - model_path: str, - tokenizer_path: str, + self, + model_path: str, + tokenizer_path: str, ): # setup tokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) @@ -162,18 +175,6 @@ class DeepSparseCausalLM: # grab logits for the last item in the sequence # shape == (batch, seq, vocabulary_size) return np.argmax(logits[0,-1,:]) - - # TODO: switch to StoppingCriteria - def should_stop( - self, - num_tokens_processed: int, - generated_token_id: int - ): - if num_tokens_processed >= self.model.sequence_length: - return True - if generated_token_id == self.tokenizer.eos_token_id: - return True - return False def generate_token( self, @@ -188,13 +189,15 @@ class DeepSparseCausalLM: # b) sample and check stopping criteria # c) create generation # d) update batch - for i, (request, input_ids, past_key_values,) in enumerate( + for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate( zip( batch.requests, batch.input_ids_list, - batch.past_key_values_list + batch.past_key_values_list, + batch.stopping_criteria_list, ) ): + # assert input_ids is b=1 assert len(input_ids.shape) == 2 assert input_ids.shape[0] == 1 @@ -205,10 +208,8 @@ class DeepSparseCausalLM: # TODO: should use NextTokenChooser/StoppingCriteria (simple for now) generated_token_id = self.sample_token(logits) generated_token = self.tokenizer.decode(generated_token_id) - stop = self.should_stop( - num_tokens_processed=input_ids.shape[1] + 1, - generated_token_id = generated_token_id - ) + + stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id) if not stop: all_stopped = False @@ -217,11 +218,12 @@ class DeepSparseCausalLM: request_id=request.id, token=generated_token, token_id=generated_token_id, - stopped=stop + stopped=stop, + finish_reason=finish_reason )) # d) update batch - # TODO: this does not occur in place) + # TODO: this does not occur in place assert len(batch.input_ids_list[i].shape) == 2 assert batch.input_ids_list[i].shape[0] == 1 batch.input_ids_list[i] = np.append( diff --git a/deepsparse/service/service.py b/deepsparse/service/service.py index b3a9a8be..c49c8f6b 100644 --- a/deepsparse/service/service.py +++ b/deepsparse/service/service.py @@ -56,7 +56,7 @@ class DeepSparseService: assert len(generations) == 1 self.cache.set(next_ds_batch) - return generations[0], next_ds_batch.to_cached_batch() + return generations[0], (next_ds_batch.to_cached_batch() if next_ds_batch else None) def Decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]: assert len(batches) != 0, "Must provide at least one batch" diff --git a/deepsparse/utils.py b/deepsparse/utils.py index 32c898b3..b5ee6b05 100644 --- a/deepsparse/utils.py +++ b/deepsparse/utils.py @@ -1,12 +1,38 @@ from dataclasses import dataclass from queue import Queue +from enum import Enum from typing import List, Optional +class FinishReason(Enum): + FINISH_REASON_LENGTH = 1 + 
FINISH_REASON_EOS_TOKEN = 2 + +class StoppingCriteria: + def __init__( + self, + eos_token_id: int, + max_new_tokens: int, + ): + assert max_new_tokens > 0 + self.max_new_tokens = max_new_tokens + self.eos_token_id = eos_token_id + self.current_tokens = 0 + + def __call__(self, generated_token_id:int): + self.current_tokens += 1 + if self.current_tokens >= self.max_new_tokens: + return True, FinishReason.FINISH_REASON_LENGTH + + if generated_token_id == self.eos_token_id: + return True, FinishReason.FINISH_REASON_EOS_TOKEN + + return False, None + @dataclass class Request: id: int - prompt: str - max_generated_tokens: int + inputs: str + max_new_tokens: int @dataclass class Batch: @@ -27,9 +53,10 @@ class Generation: token: Optional[str] token_id: Optional[str] stopped: bool + finish_reason: FinishReason = None @dataclass class GenerateRequest: - prompt: str - max_generated_tokens: int + inputs: str + max_new_tokens: int response_stream: Queue[Generation] \ No newline at end of file diff --git a/server-dev.ipynb b/server-dev.ipynb index 440cf7ef..d2ef7b3b 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "id": "4ec49282-dafd-4d4b-af7a-a68af7fb0dc2", "metadata": {}, "outputs": [ @@ -21,136 +21,77 @@ "name": "stdout", "output_type": "stream", "text": [ - "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn = int(input(\\\"Enter the number: \\\"))\\nprint(fib(n))\\n\\n# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\\n\"" + "\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\\n if n == 0:\\n \"" ] } ], "source": [ "!curl 127.0.0.1:5543/generate \\\n", " -X POST \\\n", - " -d '{\"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"max_generated_tokens\":100}' \\\n", + " -d '{\"inputs\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\"generation_parameters\":{\"max_new_tokens\":10}}' \\\n", " -H 'Content-Type: application/json'" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "id": "6ea583b1-e2d3-4f35-b87f-630d097a2628", "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from threading import Thread\n", + "\n", + "url = \"http://127.0.0.1:5543/generate\"\n", + "# sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n", + "sequence = \"def fib(n):\"\n", + "\n", + "def request_task(max_new_tokens):\n", + " obj = {\n", + " \"inputs\":sequence,\n", + " \"generation_parameters\": {\n", + " \"max_new_tokens\":max_new_tokens\n", + " }\n", + " }\n", + " with requests.post(url, json=obj) as r:\n", + " print(max_new_tokens)\n", + " print(r.text)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9dec4413-afea-444a-90b7-c98b450d5fcc", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "b'\\n'\n", - "b' '\n", - "b'if'\n", - "b' n'\n", - "b' =='\n", - "b' 0'\n", - "b':'\n", - "b'\\n'\n", - "b' '\n", - "b'return'\n", - "b' 0'\n", - "b'\\n'\n", - "b' '\n", - "b'el'\n", - "b'if'\n", - "b' n'\n", - "b' =='\n", - "b' 1'\n", - "b':'\n", - "b'\\n'\n", - "b' '\n", - "b'return'\n", - "b' 1'\n", - "b'\\n'\n", - "b' '\n", - "b'else'\n", - "b':'\n", - "b'\\n'\n", - 
"b' '\n", - "b'return'\n", - "b' fib'\n", - "b'('\n", - "b'n'\n", - "b'-'\n", - "b'1'\n", - "b')'\n", - "b' +'\n", - "b' fib'\n", - "b'('\n", - "b'n'\n", - "b'-'\n", - "b'2'\n", - "b')'\n", - "b'\\n'\n", - "b'\\n'\n", - "b'#'\n", - "b' Driver'\n", - "b' function'\n", - "b' to'\n", - "b' test'\n", - "b' above'\n", - "b' function'\n", - "b'\\n'\n", - "b'n'\n", - "b' ='\n", - "b' int'\n", - "b'('\n", - "b'input'\n", - "b'(\"'\n", - "b'Enter'\n", - "b' the'\n", - "b' number'\n", - "b':'\n", - "b' \"'\n", - "b'))'\n", - "b'\\n'\n", - "b'print'\n", - "b'('\n", - "b'f'\n", - "b'ib'\n", - "b'('\n", - "b'n'\n", - "b'))'\n", - "b'\\n'\n", - "b'\\n'\n", - "b'#'\n", - "b' This'\n", - "b' code'\n", - "b' is'\n", - "b' contributed'\n", - "b' by'\n", - "b' Nik'\n", - "b'h'\n", - "b'il'\n", - "b' Kumar'\n", - "b' Singh'\n", - "b'('\n", - "b'nick'\n", - "b'z'\n", - "b'uck'\n", - "b'_'\n", - "b'007'\n", - "b')'\n", - "b'\\n'\n" + "100\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n", + "200\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n", + "300\n", + "\"def fib(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef fib3(n):\\n if n == 0:\\n return 0\\n elif n == 1\"\n" ] } ], "source": [ - "import requests\n", + "# max_new_tokens_lst = [50, 10, 100, 25, 15]\n", + "max_new_tokens_lst = [100, 200, 300]\n", "\n", - "url = \"http://127.0.0.1:5543/generate_stream\"\n", - "obj = {\n", - " \"prompt\":\"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\",\n", - " \"max_generated_tokens\":100\n", - "}\n", + "request_ts = [\n", + " Thread(target=request_task, args=[max_new_tokens]) for max_new_tokens in max_new_tokens_lst\n", + "]\n", "\n", - "with requests.post(url, json=obj, stream=True) as r:\n", - " for chunk in r.iter_content(16): # or, for line in r.iter_lines():\n", - " print(chunk)" + "import time\n", + "for request_t in request_ts:\n", + " request_t.start()\n", + " time.sleep(0.1)\n", + "\n", + "for request_t in request_ts:\n", + " request_t.join()" ] }, { From a875c05ccdbc9e3c6c7e47572447e21a2e5b51b2 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Fri, 25 Aug 2023 16:11:09 +0000 Subject: [PATCH 2/3] implemented temperature, repetition penalty --- server-dev.ipynb | 86 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/server-dev.ipynb b/server-dev.ipynb index d2ef7b3b..23a48e3b 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -11,6 +11,92 @@ "%autoreload 2" ] }, + { + "cell_type": "code", + "execution_count": 42, + "id": "d8cd5290-a55b-44c0-ab2c-b34299a1da7c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0.01 -0.01 0.01 -0.01 0.01 -0.01 0.01 -0.01]]\n", + "[[0 2 5 7]]\n", + "[[ 0.005 -0.01 0.005 -0.01 0.01 -0.02 0.01 -0.02 ]]\n", + "[[ 0.0025 -0.005 0.0025 -0.005 0.005 -0.01 0.005 -0.01 ]]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "class 
RepetitionPenaltyLogitsProcessor:\n", + " def __init__(self, penalty: float):\n", + " if not isinstance(penalty, float) or not (penalty > 0):\n", + " raise ValueError(f\"`penalty` has to be a strictly positive float, but is {penalty}\")\n", + "\n", + " self.penalty = penalty\n", + "\n", + " def __call__(self, scores: np.ndarray, input_ids: np.ndarray) -> np.ndarray:\n", + " # assert shape is [1, vocab_size]\n", + " assert len(scores.shape) == 2\n", + " assert scores.shape[0] == 1\n", + "\n", + " # assert shape is [1, seq_len]\n", + " assert len(input_ids.shape) == 2\n", + " assert input_ids.shape[0] == 1\n", + " \n", + " # TODO: update logic to handle b > 1\n", + " score = scores[:, input_ids[0]]\n", + " score = np.where(score < 0, score * self.penalty, score / self.penalty)\n", + " scores[:, input_ids[0]] = score\n", + "\n", + " return scores\n", + "\n", + "class TemperatureLogitsWarper:\n", + " def __init__(self, temperature: float):\n", + " if not isinstance(temperature, float) or not (temperature > 0):\n", + " except_msg = (\n", + " f\"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token \"\n", + " \"scores will be invalid.\"\n", + " )\n", + " if isinstance(temperature, float) and temperature == 0.0:\n", + " except_msg += \" If you're looking for greedy decoding strategies, set `do_sample=False`.\"\n", + " raise ValueError(except_msg)\n", + " \n", + " self.temperature = temperature\n", + "\n", + " def __call__(self, scores: np.ndarray) -> np.ndarray:\n", + " # assert shape is [1, vocab_size]\n", + " assert len(scores.shape) == 2\n", + " assert scores.shape[0] == 1\n", + "\n", + " return scores / self.temperature\n", + "\n", + "input_ids = np.array([[0,2,5,7]])\n", + "logits = np.array([[0.01, -0.01]*4])\n", + "\n", + "print(logits)\n", + "print(input_ids)\n", + "\n", + "processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)\n", + "logits = processor(scores=logits, input_ids=input_ids)\n", + "print(logits)\n", + "\n", + "warper = TemperatureLogitsWarper(temperature=2.0)\n", + "logits = warper(scores=logits)\n", + "print(logits)" + ] + }, + { + "cell_type": "markdown", + "id": "e548af40-7b71-4f96-98b5-f33f03ef3f66", + "metadata": {}, + "source": [ + "p# **Interacting with FastAPI Server**" + ] + }, { "cell_type": "code", "execution_count": 17, From bb124e40299fd2c95973a00e5880d56a23619c67 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Mon, 28 Aug 2023 02:05:43 +0000 Subject: [PATCH 3/3] successfully implemented and integrated nexttokenchooser with support for repetition penalty, do_sample, temperature, top_k, and top_p --- deepsparse/router.py | 20 +- deepsparse/service/causal_lm.py | 86 ++-- deepsparse/utils.py | 96 +++- server-dev.ipynb | 769 +++++++++++++++++++++++++++++++- 4 files changed, 921 insertions(+), 50 deletions(-) diff --git a/deepsparse/router.py b/deepsparse/router.py index a5d0b2a7..aaa973fa 100644 --- a/deepsparse/router.py +++ b/deepsparse/router.py @@ -2,7 +2,7 @@ from queue import Queue from typing import List, Dict, Optional, Tuple from service.service import DeepSparseService from service.causal_lm import DeepSparseCausalLM -from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria +from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters class DeepSparseRouter: def __init__( @@ -11,10 +11,7 @@ class DeepSparseRouter: model_path: Optional[str] = None, tokenizer_path: Optional[str] = None ): - assert ( - service is not None or - (model_path is not None 
and tokenizer_path is not None) - ) + assert service is not None or (model_path is not None and tokenizer_path is not None) if service is not None: self.service = service @@ -38,8 +35,8 @@ class DeepSparseRouter: # unblock the batching task with a dummy request if blocked self.queue.append(GenerateRequest( - inputs="dummy", - max_new_tokens=1, + inputs="stop", + generation_parameters=GenerationParameters(max_new_tokens=1), response_stream=Queue() )) @@ -166,9 +163,10 @@ class DeepSparseQueue: # if block = True, this blocks until something ready # if block = False, the queue has data (if not an exception is raised) - # while queue.empty() == False does not guarentee data - # the queue is only subscribed to by one thread (this one) - # since batching_task is the only function that calls next_batch + # while queue.empty() == False typically not guarentee data on next queue.get(), this + # queue is only subscribed to by one thread (this one) since batching_task is the only + # so it does in our case + generate_request = self.queue.get(block=block) generate_requests = {self.next_request_id: generate_request} @@ -176,7 +174,7 @@ class DeepSparseQueue: request = Request( id=self.next_request_id, inputs=generate_request.inputs, - max_new_tokens=generate_request.max_new_tokens + generation_parameters=generate_request.generation_parameters, ) self.next_request_id += 1 diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py index b08a6bdf..23d0021f 100644 --- a/deepsparse/service/causal_lm.py +++ b/deepsparse/service/causal_lm.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase import numpy as np from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel -from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria +from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser DEEPSPARSE_SEQUENCE_LENGTH = 128 DEEPSPARSE_MULTITOKEN_LENGTH = 4 @@ -17,6 +17,7 @@ class DeepSparseCausalLMBatch: input_ids_list: List[np.ndarray] past_key_values_list: List[Optional[DeepSparsePastKeyValues]] stopping_criteria_list: List[StoppingCriteria] + next_token_chooser_list: List[NextTokenChooser] @classmethod def from_batch( @@ -29,6 +30,7 @@ class DeepSparseCausalLMBatch: requests_idx_mapping = {} input_ids_list = [] stopping_criteria_list = [] + next_token_chooser_list = [] # loop through items in the batch for idx, r in enumerate(batch.requests): @@ -51,7 +53,19 @@ class DeepSparseCausalLMBatch: stopping_criteria_list.append( StoppingCriteria( eos_token_id=tokenizer.eos_token_id, - max_new_tokens=min(r.max_new_tokens, model_max_new_tokens) + max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens) + ) + ) + + # get next token chooser based on input + next_token_chooser_list.append( + NextTokenChooser( + repetition_penalty=r.generation_parameters.repetition_penalty, + temperature=r.generation_parameters.temperature, + top_k=r.generation_parameters.top_k, + top_p=r.generation_parameters.top_p, + do_sample=r.generation_parameters.do_sample, + seed=r.generation_parameters.seed, ) ) @@ -61,7 +75,8 @@ class DeepSparseCausalLMBatch: requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, past_key_values_list=[None] * len(batch.requests), - stopping_criteria_list=stopping_criteria_list + stopping_criteria_list=stopping_criteria_list, + next_token_chooser_list=next_token_chooser_list, ) def to_cached_batch(self) -> CachedBatch: @@ -78,11 +93,12 @@ class 
DeepSparseCausalLMBatch: def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]: assert(len(request_ids) > 0) - requests_idx_mapping = {} - requests = [] - input_ids_list = [] - past_key_values_list = [] - stopping_criteria_list = [] + requests_idx_mapping = {} + requests = [] + input_ids_list = [] + past_key_values_list = [] + stopping_criteria_list = [] + next_token_chooser_list = [] # loop through requests, keep ones that should remain for new_idx, request_id in enumerate(request_ids): @@ -95,6 +111,7 @@ class DeepSparseCausalLMBatch: input_ids_list.append(self.input_ids_list[old_idx]) past_key_values_list.append(self.past_key_values_list[old_idx]) stopping_criteria_list.append(self.stopping_criteria_list[old_idx]) + next_token_chooser_list.append(self.next_token_chooser_list[old_idx]) # update batch state self.requests = requests @@ -102,6 +119,7 @@ class DeepSparseCausalLMBatch: self.input_ids_list = input_ids_list self.past_key_values_list = past_key_values_list self.stopping_criteria_list = stopping_criteria_list + self.next_token_chooser_list = next_token_chooser_list return self @@ -110,11 +128,12 @@ class DeepSparseCausalLMBatch: def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch": assert len(batches) > 1, "must have more than 1 batch to concatenate" - requests_idx_mapping = {} - requests = [] - input_ids_list = [] - past_key_values_list = [] - stopping_criteria_list = [] + requests_idx_mapping = {} + requests = [] + input_ids_list = [] + past_key_values_list = [] + stopping_criteria_list = [] + next_token_chooser_list = [] start_index = 0 for i, batch in enumerate(batches): @@ -125,6 +144,7 @@ class DeepSparseCausalLMBatch: input_ids_list.extend(batch.input_ids_list) past_key_values_list.extend(batch.past_key_values_list) stopping_criteria_list.extend(batch.stopping_criteria_list) + next_token_chooser_list.extend(batch.next_token_chooser_list) # merge the request_id to index mapping if i == 0: @@ -141,7 +161,8 @@ class DeepSparseCausalLMBatch: requests_idx_mapping=requests_idx_mapping, input_ids_list=input_ids_list, past_key_values_list=past_key_values_list, - stopping_criteria_list=stopping_criteria_list + stopping_criteria_list=stopping_criteria_list, + next_token_chooser_list=next_token_chooser_list ) class DeepSparseCausalLM: @@ -164,18 +185,6 @@ class DeepSparseCausalLM: multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH, ) - # TODO: switch to NextTokenChooser - def sample_token( - self, - logits: np.ndarray - ): - # assert b=1 for now - assert(logits.shape[0] == 1) - - # grab logits for the last item in the sequence - # shape == (batch, seq, vocabulary_size) - return np.argmax(logits[0,-1,:]) - def generate_token( self, batch: DeepSparseCausalLMBatch, @@ -189,14 +198,21 @@ class DeepSparseCausalLM: # b) sample and check stopping criteria # c) create generation # d) update batch - for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate( - zip( - batch.requests, - batch.input_ids_list, - batch.past_key_values_list, - batch.stopping_criteria_list, - ) - ): + + iterator = zip( + batch.requests, + batch.input_ids_list, + batch.past_key_values_list, + batch.stopping_criteria_list, + batch.next_token_chooser_list, + ) + for i, ( + request, + input_ids, + past_key_values, + stopping_criteria, + next_token_chooser + ) in enumerate(iterator): # assert input_ids is b=1 assert len(input_ids.shape) == 2 assert input_ids.shape[0] == 1 @@ -206,7 +222,7 @@ class DeepSparseCausalLM: # b) sample token and 
check stopping criteria # TODO: should use NextTokenChooser/StoppingCriteria (simple for now) - generated_token_id = self.sample_token(logits) + generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:]) generated_token = self.tokenizer.decode(generated_token_id) stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id) diff --git a/deepsparse/utils.py b/deepsparse/utils.py index b5ee6b05..fda83570 100644 --- a/deepsparse/utils.py +++ b/deepsparse/utils.py @@ -1,8 +1,73 @@ +import numpy as np from dataclasses import dataclass +from pydantic import BaseModel, Field from queue import Queue from enum import Enum from typing import List, Optional +from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax) +# TODO: sample for b > 1 with vectorized code +class Greedy: + def __call__(self, logits: np.ndarray): + # assert b=1 for now + # shape == (batch, vocabulary_size) + assert(logits.shape[0] == 1) + assert(len(logits.shape) == 2) + + return np.argmax(logits[0,:]) + +# TODO: sample for b > 1 with vectorized code +# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a +class Sampling: + def __init__(self, seed:int=42): + self.generator = np.random.default_rng(seed=seed) + + def __call__(self, logits:np.ndarray): + # assert b=1 for now + # shape == (batch, vocabulary_size) + assert(logits.shape[0] == 1) + assert(len(logits.shape) == 2) + + probs = softmax(logits, axis=1) + return self.generator.choice(probs.shape[1], p=probs[0,:]) + +class NextTokenChooser: + def __init__( + self, + repetition_penalty: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + do_sample:bool = False, + seed: int = 42, + ): + self.repetition_processor = ( + RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) + if repetition_penalty and repetition_penalty != 1.0 else None + ) + + has_warpers = ( + (temperature is not None and temperature != 1.0) + or (top_k is not None and top_k != 0) + or (top_p is not None and top_p < 1.0) + ) + + if has_warpers: + self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p) + else: + self.warpers = None + + self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy() + + def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int: + if self.repetition_processor is not None: + scores = self.repetition_processor(input_ids=input_ids, scores=scores) + + if self.warpers is not None: + scores = self.warpers(scores=scores) + + return self.choice(scores) + class FinishReason(Enum): FINISH_REASON_LENGTH = 1 FINISH_REASON_EOS_TOKEN = 2 @@ -28,11 +93,28 @@ class StoppingCriteria: return False, None +class GenerationParameters(BaseModel): + max_new_tokens: int = Field(default=20) + repetition_penalty: float = Field(default=1) + do_sample: bool = Field(default=False) + temperature: float = Field(default=1.0) + top_k: Optional[int] = Field(default=None) + top_p: Optional[float] = Field(default=None) + seed: int = Field(default=42) + +class GenerateRequestInputs(BaseModel): + inputs: str + generation_parameters: GenerationParameters + +class GenerateRequestOutputs(BaseModel): + response_text: str = Field(default="") + finish_reason: Optional[FinishReason] = Field(default=None) + @dataclass class Request: id: int inputs: str - max_new_tokens: int + generation_parameters: GenerationParameters @dataclass class Batch: @@ -58,5 +140,13 @@ 
class Generation: @dataclass class GenerateRequest: inputs: str - max_new_tokens: int - response_stream: Queue[Generation] \ No newline at end of file + generation_parameters: GenerationParameters + response_stream: Queue[Generation] + + @classmethod + def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs): + return cls( + inputs=gr_inputs.inputs, + generation_parameters=gr_inputs.generation_parameters, + response_stream=Queue() + ) \ No newline at end of file diff --git a/server-dev.ipynb b/server-dev.ipynb index 23a48e3b..95d320cb 100644 --- a/server-dev.ipynb +++ b/server-dev.ipynb @@ -11,6 +11,773 @@ "%autoreload 2" ] }, + { + "cell_type": "code", + "execution_count": 89, + "id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "45\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " if n<=1:\n", + " return n\n", + " else:\n", + " return fib(n-1)+fib(n-2)\n", + " \n", + "print(fib(15))\n", + "\n", + "\n", + "50\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n - 1) + fib(n - 2)\n", + "\n", + "n = int(\n", + "55\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + "def fib(n):\n", + " x = 1\n", + " y = 1\n", + " if n < 2:\n", + " return 1\n", + " else:\n", + " for i in range(2, n):\n", + " z = x + y\n", + " x = y\n", + " y = z\n", + "\n" + ] + } + ], + "source": [ + "import requests\n", + "from threading import Thread\n", + "import json\n", + "\n", + "url = \"http://127.0.0.1:5543/generate\"\n", + "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n", + "# sequence = \"def fib(n):\"\n", + "\n", + "def request_task(seed, max_new_tokens):\n", + " obj = {\n", + " \"inputs\": sequence,\n", + " \"generation_parameters\": {\n", + " \"max_new_tokens\":max_new_tokens,\n", + " # \"repetition_penalty\": 1.1,\n", + " # \"do_sample\": True,\n", + " # \"temperature\": 1.1,\n", + " # \"top_k\": 3,\n", + " # \"top_p\": 0.9,\n", + " \"seed\": seed,\n", + " }\n", + " }\n", + " with requests.post(url, json=obj) as r:\n", + " print(max_new_tokens)\n", + " dct = json.loads(r.text)\n", + " print(f'{sequence}{dct[\"response_text\"]}')\n", + "\n", + "max_new_tokens_lst = [55, 50, 45]\n", + "seeds = [1,2,3]\n", + "# max_new_tokens_lst = [100, 200, 300]\n", + "\n", + "request_ts = [\n", + " Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n", + "]\n", + "\n", + "import time\n", + "for request_t in request_ts:\n", + " request_t.start()\n", + " time.sleep(0.1)\n", + "\n", + "for request_t in request_ts:\n", + " request_t.join()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "28964c0b-0822-48aa-bc4c-a519035fae9a", + "metadata": {}, + "outputs": [], + "source": [ + "obj = {\n", + " \"inputs\": sequence,\n", + " \"generation_parameters\": {\n", + " \"max_new_tokens\":100,\n", + " \"repetion_penalty\": 2.0,\n", + " # \"do_sample\": False,\n", + " # \"temperature\": 1.0,\n", + " # \"top_k\": 1,\n", + " # \"top_p\": 0.9,\n", + " # \"seed\": 43,\n", + " }\n", + " }\n", + "\n", + "resp = requests.post(url, json=obj)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "bdb67325-1341-4da4-808f-c1b12fdbf607", + "metadata": {}, + 
"outputs": [], + "source": [ + "a = \"\"\n", + "token = '\\n'\n", + "for i in range(5):\n", + " a += token" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "96aebc29-2c43-456e-ab62-3c422feb57c6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "85447c0d-7020-46af-826f-fec58c9097b8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n" + ] + } + ], + "source": [ + "print(resp.text)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ff94c668-5832-4096-a723-3fc673b90495", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/robertgshaw/deepsparse-continuous-batching/deepsparse\n" + ] + } + ], + "source": [ + "%cd deepsparse" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d314c9f2-3fa7-44c4-8609-3bb24d008bb0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-1. 1. 4. 3. 2. 1.]]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n", + "\n", + "scores = np.array([[-1.,1.,4.,3.,2.,1.]])\n", + "\n", + "sorted_indices = np.argsort(scores, axis=-1)\n", + "sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)\n", + "\n", + "# sort, grabbing all those outside top_p\n", + "cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)\n", + "sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)\n", + "\n", + "\n", + "# note: this relies on b=1\n", + "indices_to_remove = sorted_indices[sorted_indices_to_remove]\n", + "\n", + "# set removed indices logits to -Inf (never selected)\n", + "scores[:, indices_to_remove] = -float(\"Inf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "cced6917-8017-405a-ae93-c43105fce82d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-2. 1. 5. 3. 2. 1.]\n", + " [ 5. -2. 3. -1. 2. 1.]]\n", + "[[-inf -inf 5. 3. 2. -inf]\n", + " [ 5. -inf 3. -inf 2. -inf]]\n", + "[[0. 0. 0.84379473 0.1141952 0.04201007 0. ]\n", + " [0.84379473 0. 0.1141952 0. 0.04201007 0. ]]\n" + ] + } + ], + "source": [ + "from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n", + "\n", + "logits = np.array([[-2.,1.,5.,3.,2.,1.], [5.,-2.,3.,-1.,2.,1.]])\n", + "print(logits)\n", + "\n", + "top_k = TopKLogitsWarper(top_k=3)\n", + "new_logits = top_k(logits)\n", + "print(new_logits)\n", + "\n", + "print(softmax(new_logits, axis=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "44fb3e44-0888-49af-a5d6-2c908db324ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-1. 1. 4. 3. 2. 1.]]\n", + "[[0.00418629 0.03093274 0.62130062 0.22856372 0.0840839 0.03093274]]\n", + "[[-inf -inf 4. 3. 2. -inf]]\n", + "[[0. 0. 0.66524096 0.24472847 0.09003057 0. 
]]\n" + ] + } + ], + "source": [ + "logits = np.array([[-1.,1.,4.,3.,2.,1.]])\n", + "print(logits)\n", + "\n", + "print(softmax(logits, axis=-1))\n", + "\n", + "top_p = TopPLogitsWarper(top_p=0.9)\n", + "new_logits = top_p(logits)\n", + "print(new_logits)\n", + "\n", + "print(softmax(new_logits, axis=-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "f20813aa-0362-4820-a3b2-c32531658718", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-1. 1. 3. 3.5 2.5 2. 1. ]]\n", + "[[0.00468177 0.03459387 0.25561604 0.4214396 0.15503896 0.09403589\n", + " 0.03459387]]\n", + "[[-inf -inf 3. 3.5 2.5 2. -inf]]\n", + "[[-inf -inf 3. 3.5 2.5 -inf -inf]]\n", + "[[0. 0. 0.30719589 0.50648039 0.18632372 0.\n", + " 0. ]]\n" + ] + } + ], + "source": [ + "logits = np.array([[-1.,1.,3.,3.5,2.5,2.,1.]])\n", + "print(logits)\n", + "\n", + "print(softmax(logits, axis=-1))\n", + "\n", + "top_p = TopPLogitsWarper(top_p=0.9)\n", + "top_k = TopKLogitsWarper(top_k=3)\n", + "\n", + "new_logits = top_p(logits)\n", + "print(new_logits)\n", + "\n", + "new_new_logits = top_k(new_logits)\n", + "print(new_new_logits)\n", + "print(softmax(logits, axis=-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "6914e436-350c-4c6a-b5ca-4146f349b1f3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[-1. 1. 3. 5. 2.5 2. 1. ]]\n", + "[[0.00189751 0.01402082 0.10360061 0.76551075 0.06283695 0.03811254\n", + " 0.01402082]]\n", + "[[-inf 1. 3. 5. 2.5 2. 1. ]]\n", + "[[-inf -inf 3. 5. 2.5 -inf -inf]]\n" + ] + } + ], + "source": [ + "logits = np.array([[-1.,1.,3.,5,2.5,2.,1.]])\n", + "print(logits)\n", + "print(softmax(logits, axis=-1))\n", + "\n", + "top_p = TopPLogitsWarper(top_p=0.99)\n", + "top_k = TopKLogitsWarper(top_k=3)\n", + "\n", + "new_logits = top_p(logits)\n", + "print(new_logits)\n", + "\n", + "new_new_logits = top_k(new_logits)\n", + "print(new_new_logits)\n", + "# print(softmax(new_new_logits, axis=-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "19ccfb53-8e40-46b4-9987-3081add2ccdb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[False, False, False, False, False]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cumulative_probs <= (1 - top_p)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "483cea72-9955-4742-b599-6f7ec7a6c6d9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0, 4, 3, 1, 2]])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sorted_indices" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "38abf685-b4ae-4dc3-ab45-acf6dfd4036f", + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[65], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NextTokenChooser\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# picks largest logits\u001b[39;00m\n\u001b[1;32m 4\u001b[0m ntc \u001b[38;5;241m=\u001b[39m 
NextTokenChooser(do_sample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "from utils import NextTokenChooser\n", + "\n", + "# picks largest logits\n", + "ntc = NextTokenChooser(do_sample=False)\n", + "logits = np.array([[-2,1,3,2]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "ntc(input_ids, logits)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "a599ccae-6457-4bb9-8240-703aaf64474f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n", + "sample / actual 0: 0.004 / 0.004\n", + "sample / actual 1: 0.081 / 0.082\n", + "sample / actual 2: 0.615 / 0.608\n", + "sample / actual 3: 0.219 / 0.224\n", + "sample / actual 4: 0.081 / 0.082\n" + ] + } + ], + "source": [ + "# samples\n", + "ntc = NextTokenChooser(do_sample=True)\n", + "logits = np.array([[-2,1,3,2,1]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "\n", + "probs = ntc.choice.softmax(logits)\n", + "print(probs)\n", + "\n", + "iters = 10000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "76f4d55d-7314-46e8-91e8-4a237d9d2ec1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n", + "sample / actual 0: 0.066 / 0.004\n", + "sample / actual 1: 0.176 / 0.082\n", + "sample / actual 2: 0.342 / 0.608\n", + "sample / actual 3: 0.244 / 0.224\n", + "sample / actual 4: 0.172 / 0.082\n" + ] + } + ], + "source": [ + "# should pick logits that are less likely\n", + "ntc = NextTokenChooser(do_sample=True, temperature=3.0)\n", + "logits = np.array([[-2,1,3,2,1]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "\n", + "probs = ntc.choice.softmax(logits)\n", + "print(probs)\n", + "\n", + "iters = 10000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "e172586d-e0f9-4429-9b73-52dccb2117ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n", + "sample / actual 0: 0.197 / 0.004\n", + "sample / actual 1: 0.201 / 0.082\n", + "sample / actual 2: 0.205 / 0.608\n", + "sample / actual 3: 0.203 / 0.224\n", + "sample / actual 4: 0.195 / 0.082\n" + ] + } + ], + "source": [ + "# should approach uniform\n", + "ntc = NextTokenChooser(do_sample=True, temperature=100.0)\n", + "logits = np.array([[-2,1,3,2,1]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "\n", + "probs = ntc.choice.softmax(logits)\n", + "print(probs)\n", + "\n", + "iters = 10000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual {i}: 
{counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "606f2be7-4ae3-4898-8c52-ec71207a5ea3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n", + "sample / actual 0: 0.000 / 0.004\n", + "sample / actual 1: 0.015 / 0.082\n", + "sample / actual 2: 0.857 / 0.608\n", + "sample / actual 3: 0.111 / 0.224\n", + "sample / actual 4: 0.017 / 0.082\n" + ] + } + ], + "source": [ + "# should pick logits that are more likely\n", + "ntc = NextTokenChooser(do_sample=True, temperature=0.5)\n", + "logits = np.array([[-2,1,3,2,1]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "\n", + "probs = ntc.choice.softmax(logits)\n", + "print(probs)\n", + "\n", + "iters = 10000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "ea0c96a5-fb8f-4f21-a30e-7443d17a440c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n", + "sample / actual 0: 0.000 / 0.004\n", + "sample / actual 1: 0.000 / 0.082\n", + "sample / actual 2: 1.000 / 0.608\n", + "sample / actual 3: 0.000 / 0.224\n", + "sample / actual 4: 0.000 / 0.082\n" + ] + } + ], + "source": [ + "# should approximate greedy sampling\n", + "ntc = NextTokenChooser(do_sample=True, temperature=0.001)\n", + "logits = np.array([[-2,1,3,2,1]])\n", + "input_ids = np.array([[1,2,3,4]])\n", + "\n", + "probs = ntc.choice.softmax(logits)\n", + "print(probs)\n", + "\n", + "iters = 10000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "id": "c9dcebaa-fa27-4d32-ad95-6a566d808868", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4\n", + "4\n", + "2\n", + "4\n", + "3\n", + "1\n", + "sample / actual / original 0: 0.017 / 0.016 / 0.004\n", + "sample / actual / original 1: 0.119 / 0.122 / 0.032\n", + "sample / actual / original 2: 0.204 / 0.201 / 0.236\n", + "sample / actual / original 3: 0.329 / 0.331 / 0.087\n", + "sample / actual / original 4: 0.330 / 0.331 / 0.641\n" + ] + } + ], + "source": [ + "# should approximate greedy sampling\n", + "ntc = NextTokenChooser(do_sample=False)\n", + "\n", + "# should select index 4\n", + "logits = np.array([[-2.,1.,3.,2.,5.]])\n", + "input_ids = np.array([[3,4]])\n", + "\n", + "print(ntc(input_ids,logits))\n", + "\n", + "# should select index 4\n", + "ntc = NextTokenChooser(do_sample=False, repetition_penalty=1.0)\n", + "logits = np.array([[-2,1,3,2,5]])\n", + "input_ids = np.array([[3,4]])\n", + "print(ntc(input_ids,logits))\n", + "\n", + "# should select index 2\n", + "ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.0)\n", + "logits = np.array([[-2.,1.,3.,2.,5.]])\n", + "input_ids = np.array([[3,4]])\n", + "print(ntc(input_ids,logits))\n", + "\n", + "# should select index 4\n", + "logits = 
np.array([[-2.,1.,3.,2.,5.]])\n", + "input_ids = np.array([[2,4]])\n", + "ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.)\n", + "print(ntc(input_ids,logits))\n", + "\n", + "# should select index 3\n", + "logits = np.array([[-2.,1.,3.,2.,5.]])\n", + "input_ids = np.array([[2,4]])\n", + "ntc = NextTokenChooser(do_sample=False, repetition_penalty=3.)\n", + "print(ntc(input_ids,logits))\n", + "\n", + "# should select index 1\n", + "logits = np.array([[-2.,1.,3.,2.,5.]])\n", + "input_ids = np.array([[2,3,4]])\n", + "ntc = NextTokenChooser(do_sample=False, repetition_penalty=5.)\n", + "print(ntc(input_ids,logits))\n", + "\n", + "\n", + "# should make 2,4 less liekly\n", + "logits_og = np.array([[-1.,1.,3.,2.,4.]])\n", + "input_ids = np.array([[2,4]])\n", + "ntc = NextTokenChooser(do_sample=True, repetition_penalty=2.)\n", + "\n", + "original_ps = ntc.choice.softmax(logits_og,axis=1)\n", + "penalty = np.array([[1.,1.,2.,1.,2.]])\n", + "penalty_ps = ntc.choice.softmax(logits_og/penalty,axis=1)\n", + "\n", + "iters = 20000\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "\n", + "for _ in range(iters):\n", + " logits = logits_og.copy()\n", + " # print(logits)\n", + " pred = ntc(input_ids, logits)\n", + " counts[pred] += 1\n", + "\n", + "for i in range(logits.shape[1]):\n", + " print(f\"sample / actual / original {i}: {counts[i] / iters: 0.3f} / {penalty_ps[0,i]: 0.3f} / {original_ps[0,i]: 0.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "id": "afd3139e-924c-4ac0-a450-c462447a610a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-2. , 1. , 1.5, 2. , 2.5]])" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits / penalty" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ea9f993e-12f8-4bd0-9f8c-885bf8e56a7b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.00446236 0.08962882 0.24363641 0.66227241]]\n", + "{0: 0.0041, 1: 0.0909, 2: 0.2418, 3: 0.6632}\n" + ] + } + ], + "source": [ + "from utils import Sampling\n", + "import numpy as np\n", + "\n", + "sampling = Sampling(seed=1)\n", + "\n", + "logits = np.array([[-2,1,2,3]])\n", + "probs = sampling.softmax(logits, axis=1)\n", + "\n", + "print(probs)\n", + "counts = {a: 0 for a in range(logits.shape[1])}\n", + "iters = 10000\n", + "for i in range(iters):\n", + " counts[sampling(logits=logits)] +=1\n", + "\n", + "for key in counts:\n", + " counts[key] /= iters\n", + "\n", + "print(counts)" + ] + }, { "cell_type": "code", "execution_count": 42, @@ -141,7 +908,7 @@ " }\n", " with requests.post(url, json=obj) as r:\n", " print(max_new_tokens)\n", - " print(r.text)" + " print(r.text)\n" ] }, {
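
---

A minimal standalone sketch of the StoppingCriteria flow introduced in deepsparse/utils.py and wired into DeepSparseCausalLMBatch.from_batch in this series. The eos_token_id, prompt length, and requested budget below are made-up values; the clamping against the 128-token sequence length mirrors causal_lm.py. Note that the length check runs before the EOS check, so a request that exhausts its budget reports FINISH_REASON_LENGTH even if its final token happens to be EOS.

from enum import Enum
from typing import Optional, Tuple

class FinishReason(Enum):
    FINISH_REASON_LENGTH = 1
    FINISH_REASON_EOS_TOKEN = 2

class StoppingCriteria:
    # mirrors deepsparse/utils.py: stop on EOS or once max_new_tokens have been generated
    def __init__(self, eos_token_id: int, max_new_tokens: int):
        assert max_new_tokens > 0
        self.eos_token_id = eos_token_id
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, generated_token_id: int) -> Tuple[bool, Optional[FinishReason]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH
        if generated_token_id == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN
        return False, None

# from_batch clamps the requested budget to what still fits in the model's sequence length
SEQUENCE_LENGTH = 128            # DEEPSPARSE_SEQUENCE_LENGTH
num_input_tokens = 20            # hypothetical prompt length
requested_max_new_tokens = 300   # hypothetical GenerationParameters.max_new_tokens

criteria = StoppingCriteria(
    eos_token_id=0,              # hypothetical EOS token id
    max_new_tokens=min(requested_max_new_tokens, SEQUENCE_LENGTH - num_input_tokens),
)
print(criteria(generated_token_id=42))   # (False, None) for the first generated token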
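
The NextTokenChooser added in deepsparse/utils.py composes a repetition-penalty processor, the warpers from logits_process, and either Greedy or seeded Sampling. logits_process is not part of this series, so the sketch below folds the repetition-penalty and temperature steps in directly (top_k/top_p omitted) and keeps the b=1 assumption; the expected indices come from the checks in server-dev.ipynb.

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class NextTokenChooser:
    # b=1 sketch: repetition penalty -> temperature -> greedy argmax or seeded sampling
    def __init__(self, repetition_penalty=1.0, temperature=1.0, do_sample=False, seed=42):
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        # as in the patch, enabling any warper implies sampling
        self.do_sample = do_sample or temperature != 1.0
        self.generator = np.random.default_rng(seed=seed)

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> int:
        # scores: (1, vocab_size), input_ids: (1, seq_len)
        if self.repetition_penalty != 1.0:
            seen = scores[:, input_ids[0]]
            scores[:, input_ids[0]] = np.where(
                seen < 0, seen * self.repetition_penalty, seen / self.repetition_penalty
            )
        if self.temperature != 1.0:
            scores = scores / self.temperature
        if self.do_sample:
            probs = softmax(scores, axis=1)
            return int(self.generator.choice(probs.shape[1], p=probs[0]))
        return int(np.argmax(scores[0]))

logits = np.array([[-2.0, 1.0, 3.0, 2.0, 5.0]])
input_ids = np.array([[3, 4]])
print(NextTokenChooser()(input_ids, logits.copy()))                        # greedy -> 4
print(NextTokenChooser(repetition_penalty=2.0)(input_ids, logits.copy()))  # penalized -> 2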
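
The notebook cells above exercise TopKLogitsWarper and TopPLogitsWarper from logits_process, which is not included in this series. The NumPy functions below are a plausible reconstruction, following the manual top-p cell in server-dev.ipynb and reproducing its printed outputs (e.g. top_p=0.9 keeping only the 4/3/2 logits); the actual module may differ.

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def top_k_warp(scores: np.ndarray, top_k: int) -> np.ndarray:
    # keep the top_k highest logits per row, set the rest to -inf
    kth_largest = np.sort(scores, axis=-1)[:, -top_k, None]
    return np.where(scores < kth_largest, -np.inf, scores)

def top_p_warp(scores: np.ndarray, top_p: float) -> np.ndarray:
    # b=1 nucleus filter: drop the low-probability tail whose cumulative mass is <= 1 - top_p
    sorted_indices = np.argsort(scores, axis=-1)                  # ascending
    sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)
    cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    indices_to_remove = sorted_indices[sorted_indices_to_remove]  # relies on b=1
    scores = scores.copy()
    scores[:, indices_to_remove] = -np.inf
    return scores

logits = np.array([[-1.0, 1.0, 4.0, 3.0, 2.0, 1.0]])
print(top_k_warp(logits, top_k=3))    # [[-inf -inf 4. 3. 2. -inf]]
print(top_p_warp(logits, top_p=0.9))  # same nucleus for this example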
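
With GenerationParameters in place, a client can set any of the new sampling knobs per request. The route and port below match the notebook; the parameter values are arbitrary, and this assumes the FastAPI handler parses the body into GenerateRequestInputs and returns a GenerateRequestOutputs JSON object.

import requests

# every field maps onto GenerationParameters in deepsparse/utils.py; omitted fields
# fall back to the pydantic defaults (greedy decoding, 20 new tokens)
payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {
        "max_new_tokens": 64,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 40,
        "top_p": 0.95,
        "seed": 7,
    },
}

resp = requests.post("http://127.0.0.1:5543/generate", json=payload)
print(resp.json()["response_text"])   # GenerateRequestOutputs.response_text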