diff --git a/deepsparse/router.py b/deepsparse/router.py
index a5d0b2a7..aaa973fa 100644
--- a/deepsparse/router.py
+++ b/deepsparse/router.py
@@ -2,7 +2,7 @@ from queue import Queue
 from typing import List, Dict, Optional, Tuple
 from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters
 
 class DeepSparseRouter:
     def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
         model_path: Optional[str] = None,
         tokenizer_path: Optional[str] = None
     ):
-        assert (
-            service is not None or
-            (model_path is not None and tokenizer_path is not None)
-        )
+        assert service is not None or (model_path is not None and tokenizer_path is not None)
 
         if service is not None:
             self.service = service
@@ -38,8 +35,8 @@ class DeepSparseRouter:
 
         # unblock the batching task with a dummy request if blocked
         self.queue.append(GenerateRequest(
-            inputs="dummy",
-            max_new_tokens=1,
+            inputs="stop",
+            generation_parameters=GenerationParameters(max_new_tokens=1),
             response_stream=Queue()
         ))
 
@@ -166,9 +163,10 @@ class DeepSparseQueue:
 
         # if block = True, this blocks until something ready
         # if block = False, the queue has data (if not an exception is raised)
-        # while queue.empty() == False does not guarentee data
-        # the queue is only subscribed to by one thread (this one)
-        # since batching_task is the only function that calls next_batch
+        # note: queue.empty() == False does not in general guarantee that the next
+        # queue.get() will return data, but this queue is consumed by a single thread
+        # (this one, since batching_task is the only caller of next_batch), so it does here
+
         generate_request = self.queue.get(block=block)
         generate_requests = {self.next_request_id: generate_request}
 
@@ -176,7 +174,7 @@ class DeepSparseQueue:
         request = Request(
             id=self.next_request_id,
             inputs=generate_request.inputs,
-            max_new_tokens=generate_request.max_new_tokens
+            generation_parameters=generate_request.generation_parameters,
         )
         self.next_request_id += 1
 
diff --git a/deepsparse/service/causal_lm.py b/deepsparse/service/causal_lm.py
index b08a6bdf..23d0021f 100644
--- a/deepsparse/service/causal_lm.py
+++ b/deepsparse/service/causal_lm.py
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 
 from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -17,6 +17,7 @@ class DeepSparseCausalLMBatch:
     input_ids_list: List[np.ndarray]
     past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
     stopping_criteria_list: List[StoppingCriteria]
+    next_token_chooser_list: List[NextTokenChooser]
 
     @classmethod
     def from_batch(
@@ -29,6 +30,7 @@ class DeepSparseCausalLMBatch:
         requests_idx_mapping = {}
         input_ids_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []
 
         # loop through items in the batch
         for idx, r in enumerate(batch.requests):
@@ -51,7 +53,19 @@ class DeepSparseCausalLMBatch:
             stopping_criteria_list.append(
                 StoppingCriteria(
                     eos_token_id=tokenizer.eos_token_id,
-                    max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
+                    max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
                 )
             )
+
+            # get next token chooser based on input
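+            # note: per NextTokenChooser in utils.py below, passing temperature/top_k/top_p
+            # switches token selection from Greedy to Sampling even when do_sample=False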
+            next_token_chooser_list.append(
+                NextTokenChooser(
+                    repetition_penalty=r.generation_parameters.repetition_penalty,
+                    temperature=r.generation_parameters.temperature,
+                    top_k=r.generation_parameters.top_k,
+                    top_p=r.generation_parameters.top_p,
+                    do_sample=r.generation_parameters.do_sample,
+                    seed=r.generation_parameters.seed,
+                )
+            )
 
@@ -61,7 +75,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=[None] * len(batch.requests),
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list,
         )
 
     def to_cached_batch(self) -> CachedBatch:
@@ -78,11 +93,12 @@ class DeepSparseCausalLMBatch:
     def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
         assert(len(request_ids) > 0)
 
-        requests_idx_mapping = {}
-        requests = []
-        input_ids_list = []
-        past_key_values_list = []
-        stopping_criteria_list = []
+        requests_idx_mapping    = {}
+        requests                = []
+        input_ids_list          = []
+        past_key_values_list    = []
+        stopping_criteria_list  = []
+        next_token_chooser_list = []
 
         # loop through requests, keep ones that should remain
         for new_idx, request_id in enumerate(request_ids):
@@ -95,6 +111,7 @@ class DeepSparseCausalLMBatch:
                 input_ids_list.append(self.input_ids_list[old_idx])
                 past_key_values_list.append(self.past_key_values_list[old_idx])
                 stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
+                next_token_chooser_list.append(self.next_token_chooser_list[old_idx])
 
         # update batch state
         self.requests = requests
@@ -102,6 +119,7 @@ class DeepSparseCausalLMBatch:
         self.input_ids_list = input_ids_list
         self.past_key_values_list = past_key_values_list
         self.stopping_criteria_list = stopping_criteria_list
+        self.next_token_chooser_list = next_token_chooser_list
 
         return self
 
@@ -110,11 +128,12 @@ class DeepSparseCausalLMBatch:
     def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
         assert len(batches) > 1, "must have more than 1 batch to concatenate"
 
-        requests_idx_mapping = {}
-        requests = []
-        input_ids_list = []
-        past_key_values_list = []
-        stopping_criteria_list = []
+        requests_idx_mapping    = {}
+        requests                = []
+        input_ids_list          = []
+        past_key_values_list    = []
+        stopping_criteria_list  = []
+        next_token_chooser_list = []
 
         start_index = 0
         for i, batch in enumerate(batches):
@@ -125,6 +144,7 @@ class DeepSparseCausalLMBatch:
             input_ids_list.extend(batch.input_ids_list)
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
+            next_token_chooser_list.extend(batch.next_token_chooser_list)
 
             # merge the request_id to index mapping
             if i == 0:
@@ -141,7 +161,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=past_key_values_list,
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list
         )
 
 class DeepSparseCausalLM:
@@ -164,18 +185,6 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )
 
-    # TODO: switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        # assert b=1 for now
-        assert(logits.shape[0] == 1)
-
-        # grab logits for the last item in the sequence
-        # shape == (batch, seq, vocabulary_size)
-        return np.argmax(logits[0,-1,:])
-
     def generate_token(
         self,
         batch: DeepSparseCausalLMBatch,
@@ -189,14 +198,21 @@ class DeepSparseCausalLM:
         # b) sample and check stopping criteria
         # c) create generation
         # d) update batch
-        for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate(
-            zip(
-                batch.requests,
-                batch.input_ids_list,
-                batch.past_key_values_list,
-                batch.stopping_criteria_list,
-            )
-        ):
+
+        iterator = zip(
+            batch.requests,
+            batch.input_ids_list,
+            batch.past_key_values_list,
+            batch.stopping_criteria_list,
+            batch.next_token_chooser_list,
+        )
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+            stopping_criteria,
+            next_token_chooser
+        ) in enumerate(iterator):
             # assert input_ids is b=1
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1
@@ -206,7 +222,7 @@ class DeepSparseCausalLM:
 
             # b) sample token and check stopping criteria
-            # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
-            generated_token_id = self.sample_token(logits)
+            # select the next token from the last position's logits
+            generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
             generated_token = self.tokenizer.decode(generated_token_id)
 
             stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
 
diff --git a/deepsparse/utils.py b/deepsparse/utils.py
index b5ee6b05..fda83570 100644
--- a/deepsparse/utils.py
+++ b/deepsparse/utils.py
@@ -1,8 +1,73 @@
+import numpy as np
 from dataclasses import dataclass
+from pydantic import BaseModel, Field
 from queue import Queue
 from enum import Enum
 from typing import List, Optional
 
+from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
+
+# TODO: sample for b > 1 with vectorized code
+class Greedy:
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        return np.argmax(logits[0,:])
+
+# TODO: sample for b > 1 with vectorized code
+# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
+class Sampling:
+    def __init__(self, seed:int=42):
+        self.generator = np.random.default_rng(seed=seed)
+
+    def __call__(self, logits:np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        probs = softmax(logits, axis=1)
+        return self.generator.choice(probs.shape[1], p=probs[0,:])
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        repetition_penalty: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        do_sample:bool = False,
+        seed: int = 42,
+    ):
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0 else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+        )
+
+        if has_warpers:
+            self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
+        else:
+            self.warpers = None
+
+        self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
+
+    def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
+        if self.repetition_processor is not None:
+            scores = self.repetition_processor(input_ids=input_ids, scores=scores)
+
+        if self.warpers is not None:
+            scores = self.warpers(scores=scores)
+
+        return self.choice(scores)
+
 class FinishReason(Enum):
     FINISH_REASON_LENGTH = 1
     FINISH_REASON_EOS_TOKEN = 2
@@ -28,11 +93,28 @@ class StoppingCriteria:
 
         return False, None
 
+class GenerationParameters(BaseModel):
+    max_new_tokens: int = Field(default=20)
+    repetition_penalty: float = Field(default=1)
+    do_sample: bool = Field(default=False)
+    temperature: float = Field(default=1.0)
+    top_k: Optional[int] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    seed: int = Field(default=42)
+
+class GenerateRequestInputs(BaseModel):
+    inputs: str
+    generation_parameters: GenerationParameters
+
+class GenerateRequestOutputs(BaseModel):
+    response_text: str = Field(default="")
+    finish_reason: Optional[FinishReason] = Field(default=None)
+
 @dataclass
 class Request:
     id: int
     inputs: str
-    max_new_tokens: int
+    generation_parameters: GenerationParameters
 
 @dataclass
 class Batch:
@@ -58,5 +140,13 @@ class Generation:
 
 @dataclass
 class GenerateRequest:
     inputs: str
-    max_new_tokens: int
-    response_stream: Queue[Generation]
\ No newline at end of file
+    generation_parameters: GenerationParameters
+    response_stream: Queue[Generation]
+
+    @classmethod
+    def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
+        return cls(
+            inputs=gr_inputs.inputs,
+            generation_parameters=gr_inputs.generation_parameters,
+            response_stream=Queue()
+        )
\ No newline at end of file
diff --git a/server-dev.ipynb b/server-dev.ipynb
index 23a48e3b..95d320cb 100644
--- a/server-dev.ipynb
+++ b/server-dev.ipynb
@@ -11,6 +11,773 @@
     "%autoreload 2"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "45\n",
+      "Finish the following function for computing a fibonacci sequence: \n",
+      "\n",
+      "def fib(n):\n",
+      "    if n<=1:\n",
+      "        return n\n",
+      "    else:\n",
+      "        return fib(n-1)+fib(n-2)\n",
+      "    \n",
+      "print(fib(15))\n",
+      "\n",
+      "\n",
+      "50\n",
+      "Finish the following function for computing a fibonacci sequence: \n",
+      "\n",
+      "def fib(n):\n",
+      "    if n == 0:\n",
+      "        return 0\n",
+      "    elif n == 1:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        return fib(n - 1) + fib(n - 2)\n",
+      "\n",
+      "n = int(\n",
+      "55\n",
+      "Finish the following function for computing a fibonacci sequence: \n",
+      "\n",
+      "def fib(n):\n",
+      "    x = 1\n",
+      "    y = 1\n",
+      "    if n < 2:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        for i in range(2, n):\n",
+      "            z = x + y\n",
+      "            x = y\n",
+      "            y = z\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "from threading import Thread\n",
+    "import json\n",
+    "\n",
+    "url = \"http://127.0.0.1:5543/generate\"\n",
+    "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
+    "# sequence = \"def fib(n):\"\n",
+    "\n",
+    "def request_task(seed, max_new_tokens):\n",
+    "    obj = {\n",
+    "        \"inputs\": sequence,\n",
+    "        \"generation_parameters\": {\n",
+    "            \"max_new_tokens\":max_new_tokens,\n",
+    "            # \"repetition_penalty\": 1.1,\n",
+    "            # \"do_sample\": True,\n",
+    "            # \"temperature\": 1.1,\n",
+    "            # \"top_k\": 3,\n",
+    "            # \"top_p\": 0.9,\n",
+    "            \"seed\": seed,\n",
+    "        }\n",
+    "    }\n",
+    "    with requests.post(url, json=obj) as r:\n",
+    "        print(max_new_tokens)\n",
+    "        dct = json.loads(r.text)\n",
+    "        print(f'{sequence}{dct[\"response_text\"]}')\n",
+    "\n",
+    "max_new_tokens_lst = [55, 50, 45]\n",
+    "seeds = [1,2,3]\n",
+    "# max_new_tokens_lst = [100, 200, 300]\n",
+    "\n",
+    "request_ts = [\n",
+    "    Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n",
+    "]\n",
+    "\n",
+    "import time\n",
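+    "# start the client threads slightly staggered, so the server receives the\n",
+    "# three requests in a deterministic order\n",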
+    "for request_t in request_ts:\n",
+    "    request_t.start()\n",
+    "    time.sleep(0.1)\n",
+    "\n",
+    "for request_t in request_ts:\n",
+    "    request_t.join()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "id": "28964c0b-0822-48aa-bc4c-a519035fae9a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "obj = {\n",
+    "    \"inputs\": sequence,\n",
+    "    \"generation_parameters\": {\n",
+    "        \"max_new_tokens\":100,\n",
+    "        \"repetition_penalty\": 2.0,\n",
+    "        # \"do_sample\": False,\n",
+    "        # \"temperature\": 1.0,\n",
+    "        # \"top_k\": 1,\n",
+    "        # \"top_p\": 0.9,\n",
+    "        # \"seed\": 43,\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "resp = requests.post(url, json=obj)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "id": "bdb67325-1341-4da4-808f-c1b12fdbf607",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "a = \"\"\n",
+    "token = '\\n'\n",
+    "for i in range(5):\n",
+    "    a += token"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "id": "96aebc29-2c43-456e-ab62-3c422feb57c6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(a)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "id": "85447c0d-7020-46af-826f-fec58c9097b8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(resp.text)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ff94c668-5832-4096-a723-3fc673b90495",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/home/robertgshaw/deepsparse-continuous-batching/deepsparse\n"
+     ]
+    }
+   ],
+   "source": [
+    "%cd deepsparse"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "d314c9f2-3fa7-44c4-8609-3bb24d008bb0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[-1.  1.  4.  3.  2.  1.]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
+    "\n",
+    "scores = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
+    "\n",
+    "sorted_indices = np.argsort(scores, axis=-1)\n",
+    "sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)\n",
+    "\n",
+    "# sort, grabbing all those outside top_p\n",
+    "cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)\n",
+    "sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)\n",
+    "\n",
+    "\n",
+    "# note: this relies on b=1\n",
+    "indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
+    "\n",
+    "# set removed indices logits to -Inf (never selected)\n",
+    "scores[:, indices_to_remove] = -float(\"Inf\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 93,
+   "id": "cced6917-8017-405a-ae93-c43105fce82d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[-2.  1.  5.  3.  2.  1.]\n",
+      " [ 5. -2.  3. -1.  2.  1.]]\n",
+      "[[-inf -inf  5.  3.  2. -inf]\n",
+      " [ 5. -inf  3. -inf  2. -inf]]\n",
+      "[[0.         0.         0.84379473 0.1141952  0.04201007 0.        ]\n",
+      " [0.84379473 0.         0.1141952  0.         0.04201007 0.        ]]\n"
+     ]
+    }
+   ],
+   "source": [
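+    "# TopKLogitsWarper should keep the 3 largest logits in each row and mask\n",
+    "# the rest to -inf before the softmax\n",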
+    "from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
+    "\n",
+    "logits = np.array([[-2.,1.,5.,3.,2.,1.], [5.,-2.,3.,-1.,2.,1.]])\n",
+    "print(logits)\n",
+    "\n",
+    "top_k = TopKLogitsWarper(top_k=3)\n",
+    "new_logits = top_k(logits)\n",
+    "print(new_logits)\n",
+    "\n",
+    "print(softmax(new_logits, axis=1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 97,
+   "id": "44fb3e44-0888-49af-a5d6-2c908db324ec",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[-1.  1.  4.  3.  2.  1.]]\n",
+      "[[0.00418629 0.03093274 0.62130062 0.22856372 0.0840839  0.03093274]]\n",
+      "[[-inf -inf  4.  3.  2. -inf]]\n",
+      "[[0.         0.         0.66524096 0.24472847 0.09003057 0.        ]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "logits = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
+    "print(logits)\n",
+    "\n",
+    "print(softmax(logits, axis=-1))\n",
+    "\n",
+    "top_p = TopPLogitsWarper(top_p=0.9)\n",
+    "new_logits = top_p(logits)\n",
+    "print(new_logits)\n",
+    "\n",
+    "print(softmax(new_logits, axis=-1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 102,
+   "id": "f20813aa-0362-4820-a3b2-c32531658718",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[-1.   1.   3.   3.5  2.5  2.   1. ]]\n",
+      "[[0.00468177 0.03459387 0.25561604 0.4214396  0.15503896 0.09403589\n",
+      "  0.03459387]]\n",
+      "[[-inf -inf  3.   3.5  2.5  2.  -inf]]\n",
+      "[[-inf -inf  3.   3.5  2.5 -inf -inf]]\n",
+      "[[0.         0.         0.30719589 0.50648039 0.18632372 0.\n",
+      "  0.        ]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "logits = np.array([[-1.,1.,3.,3.5,2.5,2.,1.]])\n",
+    "print(logits)\n",
+    "\n",
+    "print(softmax(logits, axis=-1))\n",
+    "\n",
+    "top_p = TopPLogitsWarper(top_p=0.9)\n",
+    "top_k = TopKLogitsWarper(top_k=3)\n",
+    "\n",
+    "new_logits = top_p(logits)\n",
+    "print(new_logits)\n",
+    "\n",
+    "new_new_logits = top_k(new_logits)\n",
+    "print(new_new_logits)\n",
+    "print(softmax(new_new_logits, axis=-1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 113,
+   "id": "6914e436-350c-4c6a-b5ca-4146f349b1f3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[-1.   1.   3.   5.   2.5  2.   1. ]]\n",
+      "[[0.00189751 0.01402082 0.10360061 0.76551075 0.06283695 0.03811254\n",
+      "  0.01402082]]\n",
+      "[[-inf  1.   3.   5.   2.5  2.   1. ]]\n",
+      "[[-inf -inf  3.   5.   2.5 -inf -inf]]\n"
+     ]
+    }
+   ],
+   "source": [
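+    "# the warpers compose: top_p=0.99 should drop only the lowest-probability\n",
+    "# token here, then top_k=3 keeps the 3 largest surviving logits\n",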
+    "logits = np.array([[-1.,1.,3.,5,2.5,2.,1.]])\n",
+    "print(logits)\n",
+    "print(softmax(logits, axis=-1))\n",
+    "\n",
+    "top_p = TopPLogitsWarper(top_p=0.99)\n",
+    "top_k = TopKLogitsWarper(top_k=3)\n",
+    "\n",
+    "new_logits = top_p(logits)\n",
+    "print(new_logits)\n",
+    "\n",
+    "new_new_logits = top_k(new_logits)\n",
+    "print(new_new_logits)\n",
+    "# print(softmax(new_new_logits, axis=-1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "19ccfb53-8e40-46b4-9987-3081add2ccdb",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[False, False, False, False, False]])"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "cumulative_probs <= (1 - top_p)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "483cea72-9955-4742-b599-6f7ec7a6c6d9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[0, 4, 3, 1, 2]])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sorted_indices"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "id": "38abf685-b4ae-4dc3-ab45-acf6dfd4036f",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'utils'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[65], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NextTokenChooser\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m# picks largest logits\u001b[39;00m\n\u001b[1;32m      4\u001b[0m ntc \u001b[38;5;241m=\u001b[39m NextTokenChooser(do_sample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'"
+     ]
+    }
+   ],
+   "source": [
+    "from utils import NextTokenChooser\n",
+    "\n",
+    "# picks largest logits\n",
+    "ntc = NextTokenChooser(do_sample=False)\n",
+    "logits = np.array([[-2,1,3,2]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "ntc(input_ids, logits)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "id": "a599ccae-6457-4bb9-8240-703aaf64474f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.0040953  0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+      "sample / actual 0: 0.004 / 0.004\n",
+      "sample / actual 1: 0.081 / 0.082\n",
+      "sample / actual 2: 0.615 / 0.608\n",
+      "sample / actual 3: 0.219 / 0.224\n",
+      "sample / actual 4: 0.081 / 0.082\n"
+     ]
+    }
+   ],
+   "source": [
+    "# samples\n",
+    "ntc = NextTokenChooser(do_sample=True)\n",
+    "logits = np.array([[-2,1,3,2,1]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "\n",
+    "probs = ntc.choice.softmax(logits)\n",
+    "print(probs)\n",
+    "\n",
+    "iters = 10000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "id": "76f4d55d-7314-46e8-91e8-4a237d9d2ec1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.0040953  0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+      "sample / actual 0: 0.066 / 0.004\n",
+      "sample / actual 1: 0.176 / 0.082\n",
+      "sample / actual 2: 0.342 / 0.608\n",
+      "sample / actual 3: 0.244 / 0.224\n",
+      "sample / actual 4: 0.172 / 0.082\n"
+     ]
+    }
+   ],
+   "source": [
+    "# should pick logits that are less likely\n",
+    "ntc = NextTokenChooser(do_sample=True, temperature=3.0)\n",
+    "logits = np.array([[-2,1,3,2,1]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "\n",
+    "probs = ntc.choice.softmax(logits)\n",
+    "print(probs)\n",
+    "\n",
+    "iters = 10000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "id": "e172586d-e0f9-4429-9b73-52dccb2117ca",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.0040953  0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+      "sample / actual 0: 0.197 / 0.004\n",
+      "sample / actual 1: 0.201 / 0.082\n",
+      "sample / actual 2: 0.205 / 0.608\n",
+      "sample / actual 3: 0.203 / 0.224\n",
+      "sample / actual 4: 0.195 / 0.082\n"
+     ]
+    }
+   ],
+   "source": [
+    "# should approach uniform\n",
+    "ntc = NextTokenChooser(do_sample=True, temperature=100.0)\n",
+    "logits = np.array([[-2,1,3,2,1]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "\n",
+    "probs = ntc.choice.softmax(logits)\n",
+    "print(probs)\n",
+    "\n",
+    "iters = 10000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "id": "606f2be7-4ae3-4898-8c52-ec71207a5ea3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.0040953  0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+      "sample / actual 0: 0.000 / 0.004\n",
+      "sample / actual 1: 0.015 / 0.082\n",
+      "sample / actual 2: 0.857 / 0.608\n",
+      "sample / actual 3: 0.111 / 0.224\n",
+      "sample / actual 4: 0.017 / 0.082\n"
+     ]
+    }
+   ],
+   "source": [
+    "# should pick logits that are more likely\n",
+    "ntc = NextTokenChooser(do_sample=True, temperature=0.5)\n",
+    "logits = np.array([[-2,1,3,2,1]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "\n",
+    "probs = ntc.choice.softmax(logits)\n",
+    "print(probs)\n",
+    "\n",
+    "iters = 10000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "id": "ea0c96a5-fb8f-4f21-a30e-7443d17a440c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.0040953  0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+      "sample / actual 0: 0.000 / 0.004\n",
+      "sample / actual 1: 0.000 / 0.082\n",
+      "sample / actual 2: 1.000 / 0.608\n",
+      "sample / actual 3: 0.000 / 0.224\n",
+      "sample / actual 4: 0.000 / 0.082\n"
+     ]
+    }
+   ],
+   "source": [
+    "# as temperature -> 0, softmax(logits / temperature) concentrates on the argmax,\n",
+    "# so this should approximate greedy sampling\n",
+    "ntc = NextTokenChooser(do_sample=True, temperature=0.001)\n",
+    "logits = np.array([[-2,1,3,2,1]])\n",
+    "input_ids = np.array([[1,2,3,4]])\n",
+    "\n",
+    "probs = ntc.choice.softmax(logits)\n",
+    "print(probs)\n",
+    "\n",
+    "iters = 10000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 134,
+   "id": "c9dcebaa-fa27-4d32-ad95-6a566d808868",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "4\n",
+      "4\n",
+      "2\n",
+      "4\n",
+      "3\n",
+      "1\n",
+      "sample / actual / original 0: 0.017 / 0.016 / 0.004\n",
+      "sample / actual / original 1: 0.119 / 0.122 / 0.032\n",
+      "sample / actual / original 2: 0.204 / 0.201 / 0.236\n",
+      "sample / actual / original 3: 0.329 / 0.331 / 0.087\n",
+      "sample / actual / original 4: 0.330 / 0.331 / 0.641\n"
+     ]
+    }
+   ],
+   "source": [
+    "# repetition penalty checks (greedy selection)\n",
+    "ntc = NextTokenChooser(do_sample=False)\n",
+    "\n",
+    "# should select index 4\n",
+    "logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+    "input_ids = np.array([[3,4]])\n",
+    "\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "# should select index 4\n",
+    "ntc = NextTokenChooser(do_sample=False, repetition_penalty=1.0)\n",
+    "logits = np.array([[-2,1,3,2,5]])\n",
+    "input_ids = np.array([[3,4]])\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "# should select index 2\n",
+    "ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.0)\n",
+    "logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+    "input_ids = np.array([[3,4]])\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "# should select index 4\n",
+    "logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+    "input_ids = np.array([[2,4]])\n",
+    "ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.)\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "# should select index 3\n",
+    "logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+    "input_ids = np.array([[2,4]])\n",
+    "ntc = NextTokenChooser(do_sample=False, repetition_penalty=3.)\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "# should select index 1\n",
+    "logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+    "input_ids = np.array([[2,3,4]])\n",
+    "ntc = NextTokenChooser(do_sample=False, repetition_penalty=5.)\n",
+    "print(ntc(input_ids,logits))\n",
+    "\n",
+    "\n",
+    "# should make 2,4 less likely\n",
+    "logits_og = np.array([[-1.,1.,3.,2.,4.]])\n",
+    "input_ids = np.array([[2,4]])\n",
+    "ntc = NextTokenChooser(do_sample=True, repetition_penalty=2.)\n",
+    "\n",
+    "original_ps = ntc.choice.softmax(logits_og,axis=1)\n",
+    "penalty = np.array([[1.,1.,2.,1.,2.]])\n",
+    "penalty_ps = ntc.choice.softmax(logits_og/penalty,axis=1)\n",
+    "\n",
+    "iters = 20000\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "\n",
+    "for _ in range(iters):\n",
+    "    logits = logits_og.copy()\n",
+    "    # print(logits)\n",
+    "    pred = ntc(input_ids, logits)\n",
+    "    counts[pred] += 1\n",
+    "\n",
+    "for i in range(logits.shape[1]):\n",
+    "    print(f\"sample / actual / original {i}: {counts[i] / iters: 0.3f} / {penalty_ps[0,i]: 0.3f} / {original_ps[0,i]: 0.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 111,
+   "id": "afd3139e-924c-4ac0-a450-c462447a610a",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[-2. ,  1. ,  1.5,  2. ,  2.5]])"
+      ]
+     },
+     "execution_count": 111,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "logits / penalty"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "ea9f993e-12f8-4bd0-9f8c-885bf8e56a7b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[0.00446236 0.08962882 0.24363641 0.66227241]]\n",
+      "{0: 0.0041, 1: 0.0909, 2: 0.2418, 3: 0.6632}\n"
+     ]
+    }
+   ],
+   "source": [
+    "from utils import Sampling\n",
+    "import numpy as np\n",
+    "\n",
+    "sampling = Sampling(seed=1)\n",
+    "\n",
+    "logits = np.array([[-2,1,2,3]])\n",
+    "probs = sampling.softmax(logits, axis=1)\n",
+    "\n",
+    "print(probs)\n",
+    "counts = {a: 0 for a in range(logits.shape[1])}\n",
+    "iters = 10000\n",
+    "for i in range(iters):\n",
+    "    counts[sampling(logits=logits)] +=1\n",
+    "\n",
+    "for key in counts:\n",
+    "    counts[key] /= iters\n",
+    "\n",
+    "print(counts)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 42,
@@ -141,7 +908,7 @@
     "    }\n",
     "    with requests.post(url, json=obj) as r:\n",
     "        print(max_new_tokens)\n",
-    "        print(r.text)"
+    "        print(r.text)\n"
    ]
   },
   {
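For reference, a minimal, self-contained sketch of the token-selection pipeline this change introduces: repetition penalty first, then warpers, then a greedy or sampled choice. The helpers below (`softmax`, `apply_repetition_penalty`, `apply_top_k`, `choose`) are illustrative stand-ins written for this note, not the `logits_process` implementations the PR imports; the printed expectations mirror the repetition-penalty checks in the notebook above.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # numerically stable softmax
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def apply_repetition_penalty(input_ids: np.ndarray, scores: np.ndarray, penalty: float) -> np.ndarray:
    # penalize every token id already present in the sequence: positive logits
    # are divided by the penalty, negative ones multiplied (HF-style rule)
    seen = np.unique(input_ids)
    vals = scores[0, seen]
    scores[0, seen] = np.where(vals > 0, vals / penalty, vals * penalty)
    return scores

def apply_top_k(scores: np.ndarray, top_k: int) -> np.ndarray:
    # mask everything below the top_k-th largest logit per row to -inf
    kth_largest = np.sort(scores, axis=-1)[:, -top_k][:, None]
    return np.where(scores < kth_largest, -float("inf"), scores)

def choose(input_ids, scores, *, repetition_penalty=1.0, temperature=1.0,
           top_k=None, do_sample=False, seed=42) -> int:
    scores = scores.astype(float)  # work on a copy
    if repetition_penalty != 1.0:
        scores = apply_repetition_penalty(input_ids, scores, repetition_penalty)
    if temperature != 1.0:
        scores = scores / temperature
    if top_k:
        scores = apply_top_k(scores, top_k)
    if do_sample:
        probs = softmax(scores, axis=-1)
        return int(np.random.default_rng(seed).choice(probs.shape[1], p=probs[0, :]))
    return int(np.argmax(scores[0, :]))

logits = np.array([[-2., 1., 3., 2., 5.]])
print(choose(np.array([[3, 4]]), logits))                           # greedy: 4
print(choose(np.array([[3, 4]]), logits, repetition_penalty=2.0))   # 2 (5/2 < 3)
print(choose(np.array([[2, 4]]), logits, repetition_penalty=3.0))   # 3
```

As in `NextTokenChooser`, greedy selection is just an argmax over the final (possibly warped) scores, and sampling draws from the softmax of those same scores.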