Successfully implemented and integrated NextTokenChooser with support for repetition_penalty, do_sample, temperature, top_k, and top_p

This commit is contained in:
rsnm2 2023-08-28 02:05:43 +00:00
parent a875c05ccd
commit bb124e4029
4 changed files with 921 additions and 50 deletions

View File

@ -2,7 +2,7 @@ from queue import Queue
from typing import List, Dict, Optional, Tuple
from service.service import DeepSparseService
from service.causal_lm import DeepSparseCausalLM
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters
class DeepSparseRouter:
def __init__(
@ -11,10 +11,7 @@ class DeepSparseRouter:
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None
):
assert (
service is not None or
(model_path is not None and tokenizer_path is not None)
)
assert service is not None or (model_path is not None and tokenizer_path is not None)
if service is not None:
self.service = service
@ -38,8 +35,8 @@ class DeepSparseRouter:
# unblock the batching task with a dummy request if blocked
self.queue.append(GenerateRequest(
inputs="dummy",
max_new_tokens=1,
inputs="stop",
generation_parameters=GenerationParameters(max_new_tokens=1),
response_stream=Queue()
))
@ -166,9 +163,10 @@ class DeepSparseQueue:
# if block = True, this blocks until something ready
# if block = False, the queue has data (if not an exception is raised)
# while queue.empty() == False does not guarantee data
# the queue is only subscribed to by one thread (this one)
# since batching_task is the only function that calls next_batch
# while queue.empty() == False typically does not guarantee data on the next queue.get(), this
# queue is only subscribed to by one thread (this one) since batching_task is the only
# function that calls next_batch, so it does in our case
generate_request = self.queue.get(block=block)
generate_requests = {self.next_request_id: generate_request}
@ -176,7 +174,7 @@ class DeepSparseQueue:
request = Request(
id=self.next_request_id,
inputs=generate_request.inputs,
max_new_tokens=generate_request.max_new_tokens
generation_parameters=generate_request.generation_parameters,
)
self.next_request_id += 1

View File

@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
import numpy as np
from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria
from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser
DEEPSPARSE_SEQUENCE_LENGTH = 128
DEEPSPARSE_MULTITOKEN_LENGTH = 4
@ -17,6 +17,7 @@ class DeepSparseCausalLMBatch:
input_ids_list: List[np.ndarray]
past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
stopping_criteria_list: List[StoppingCriteria]
next_token_chooser_list: List[NextTokenChooser]
@classmethod
def from_batch(
@ -29,6 +30,7 @@ class DeepSparseCausalLMBatch:
requests_idx_mapping = {}
input_ids_list = []
stopping_criteria_list = []
next_token_chooser_list = []
# loop through items in the batch
for idx, r in enumerate(batch.requests):
@ -51,7 +53,19 @@ class DeepSparseCausalLMBatch:
stopping_criteria_list.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
)
)
# get next token chooser based on input
next_token_chooser_list.append(
NextTokenChooser(
repetition_penalty=r.generation_parameters.repetition_penalty,
temperature=r.generation_parameters.temperature,
top_k=r.generation_parameters.top_k,
top_p=r.generation_parameters.top_p,
do_sample=r.generation_parameters.do_sample,
seed=r.generation_parameters.seed,
)
)
@ -61,7 +75,8 @@ class DeepSparseCausalLMBatch:
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=[None] * len(batch.requests),
stopping_criteria_list=stopping_criteria_list
stopping_criteria_list=stopping_criteria_list,
next_token_chooser_list=next_token_chooser_list,
)
def to_cached_batch(self) -> CachedBatch:
@ -83,6 +98,7 @@ class DeepSparseCausalLMBatch:
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
next_token_chooser_list = []
# loop through requests, keep ones that should remain
for new_idx, request_id in enumerate(request_ids):
@ -95,6 +111,7 @@ class DeepSparseCausalLMBatch:
input_ids_list.append(self.input_ids_list[old_idx])
past_key_values_list.append(self.past_key_values_list[old_idx])
stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
next_token_chooser_list.append(self.next_token_chooser_list[old_idx])
# update batch state
self.requests = requests
@ -102,6 +119,7 @@ class DeepSparseCausalLMBatch:
self.input_ids_list = input_ids_list
self.past_key_values_list = past_key_values_list
self.stopping_criteria_list = stopping_criteria_list
self.next_token_chooser_list = next_token_chooser_list
return self
@ -115,6 +133,7 @@ class DeepSparseCausalLMBatch:
input_ids_list = []
past_key_values_list = []
stopping_criteria_list = []
next_token_chooser_list = []
start_index = 0
for i, batch in enumerate(batches):
@ -125,6 +144,7 @@ class DeepSparseCausalLMBatch:
input_ids_list.extend(batch.input_ids_list)
past_key_values_list.extend(batch.past_key_values_list)
stopping_criteria_list.extend(batch.stopping_criteria_list)
next_token_chooser_list.extend(batch.next_token_chooser_list)
# merge the request_id to index mapping
if i == 0:
@ -141,7 +161,8 @@ class DeepSparseCausalLMBatch:
requests_idx_mapping=requests_idx_mapping,
input_ids_list=input_ids_list,
past_key_values_list=past_key_values_list,
stopping_criteria_list=stopping_criteria_list
stopping_criteria_list=stopping_criteria_list,
next_token_chooser_list=next_token_chooser_list
)
class DeepSparseCausalLM:
@ -164,18 +185,6 @@ class DeepSparseCausalLM:
multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
)
# TODO: switch to NextTokenChooser
def sample_token(
self,
logits: np.ndarray
):
# assert b=1 for now
assert(logits.shape[0] == 1)
# grab logits for the last item in the sequence
# shape == (batch, seq, vocabulary_size)
return np.argmax(logits[0,-1,:])
def generate_token(
self,
batch: DeepSparseCausalLMBatch,
@ -189,14 +198,21 @@ class DeepSparseCausalLM:
# b) sample and check stopping criteria
# c) create generation
# d) update batch
for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate(
zip(
iterator = zip(
batch.requests,
batch.input_ids_list,
batch.past_key_values_list,
batch.stopping_criteria_list,
batch.next_token_chooser_list,
)
):
for i, (
request,
input_ids,
past_key_values,
stopping_criteria,
next_token_chooser
) in enumerate(iterator):
# assert input_ids is b=1
assert len(input_ids.shape) == 2
assert input_ids.shape[0] == 1
@ -206,7 +222,7 @@ class DeepSparseCausalLM:
# b) sample token and check stopping criteria
# TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
generated_token_id = self.sample_token(logits)
generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
generated_token = self.tokenizer.decode(generated_token_id)
stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
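Condensed, one decode step for a single request in generate_token now reads roughly as follows (a sketch; the forward-pass call is not shown in this hunk and its signature is assumed):

# sketch of one decode step (batch of 1)
logits, past_key_values = self.model(input_ids, past_key_values)  # assumed forward call
generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:, -1, :])
generated_token = self.tokenizer.decode(generated_token_id)
stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)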

View File

@ -1,7 +1,72 @@
import numpy as np
from dataclasses import dataclass
from pydantic import BaseModel, Field
from queue import Queue
from enum import Enum
from typing import List, Optional
from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
# TODO: sample for b > 1 with vectorized code
class Greedy:
def __call__(self, logits: np.ndarray):
# assert b=1 for now
# shape == (batch, vocabulary_size)
assert(logits.shape[0] == 1)
assert(len(logits.shape) == 2)
return np.argmax(logits[0,:])
# TODO: sample for b > 1 with vectorized code
# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
class Sampling:
def __init__(self, seed:int=42):
self.generator = np.random.default_rng(seed=seed)
def __call__(self, logits:np.ndarray):
# assert b=1 for now
# shape == (batch, vocabulary_size)
assert(logits.shape[0] == 1)
assert(len(logits.shape) == 2)
probs = softmax(logits, axis=1)
return self.generator.choice(probs.shape[1], p=probs[0,:])
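The batch-size-1 TODO above could later be addressed with vectorized inverse-CDF sampling along the lines of the linked Stack Overflow answer; a rough sketch (not part of this commit):

# sketch: vectorized categorical sampling over a batch of probability rows
def sample_rows(probs: np.ndarray, generator: np.random.Generator) -> np.ndarray:
    # one uniform draw per row, then pick the first index whose CDF reaches it
    cdf = probs.cumsum(axis=1)
    u = generator.random(size=(probs.shape[0], 1))
    return (cdf < u).sum(axis=1)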
class NextTokenChooser:
def __init__(
self,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
do_sample:bool = False,
seed: int = 42,
):
self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty and repetition_penalty != 1.0 else None
)
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
or (top_p is not None and top_p < 1.0)
)
if has_warpers:
self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
else:
self.warpers = None
self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids=input_ids, scores=scores)
if self.warpers is not None:
scores = self.warpers(scores=scores)
return self.choice(scores)
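To make the selection behavior concrete, a minimal usage sketch (not part of the diff; assumes utils is importable from the server directory, as in the notebook below):

import numpy as np
from utils import NextTokenChooser

logits = np.array([[-2.0, 1.0, 3.0, 2.0]])  # shape == (batch=1, vocabulary_size)
input_ids = np.array([[1, 2, 3]])

# no warpers and do_sample=False -> Greedy, returns the argmax index (2)
greedy = NextTokenChooser(do_sample=False)
print(greedy(input_ids=input_ids, scores=logits))

# setting any warper (or do_sample=True) switches the choice to Sampling
sampled = NextTokenChooser(do_sample=True, temperature=0.7, top_k=2, seed=0)
print(sampled(input_ids=input_ids, scores=logits))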
class FinishReason(Enum):
FINISH_REASON_LENGTH = 1
@ -28,11 +93,28 @@ class StoppingCriteria:
return False, None
class GenerationParameters(BaseModel):
max_new_tokens: int = Field(default=20)
repetition_penalty: float = Field(default=1)
do_sample: bool = Field(default=False)
temperature: float = Field(default=1.0)
top_k: Optional[int] = Field(default=None)
top_p: Optional[float] = Field(default=None)
seed: int = Field(default=42)
class GenerateRequestInputs(BaseModel):
inputs: str
generation_parameters: GenerationParameters
class GenerateRequestOutputs(BaseModel):
response_text: str = Field(default="")
finish_reason: Optional[FinishReason] = Field(default=None)
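Together these models map onto the JSON accepted and returned by the /generate endpoint exercised in the notebook below; roughly (a sketch assuming a running server on the host/port from that notebook):

import requests

url = "http://127.0.0.1:5543/generate"
payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {
        "max_new_tokens": 64,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 40,
        "top_p": 0.95,
        "seed": 0,
    },
}
resp = requests.post(url, json=payload)
print(resp.json()["response_text"])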
@dataclass
class Request:
id: int
inputs: str
max_new_tokens: int
generation_parameters: GenerationParameters
@dataclass
class Batch:
@ -58,5 +140,13 @@ class Generation:
@dataclass
class GenerateRequest:
inputs: str
max_new_tokens: int
generation_parameters: GenerationParameters
response_stream: Queue[Generation]
@classmethod
def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
return cls(
inputs=gr_inputs.inputs,
generation_parameters=gr_inputs.generation_parameters,
response_stream=Queue()
)
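A request handler can then turn the validated Pydantic inputs into a queue entry for the router; a sketch (the server code itself is not part of this diff):

gr_inputs = GenerateRequestInputs(
    inputs="def fib(n):",
    generation_parameters=GenerationParameters(max_new_tokens=32, do_sample=True, seed=0),
)
generate_request = GenerateRequest.from_gr_inputs(gr_inputs)
# the router's batching loop consumes this from its queue and pushes Generation
# objects back onto generate_request.response_stream (see the router diff above)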

View File

@ -11,6 +11,773 @@
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"45\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n<=1:\n",
" return n\n",
" else:\n",
" return fib(n-1)+fib(n-2)\n",
" \n",
"print(fib(15))\n",
"\n",
"\n",
"50\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n - 1) + fib(n - 2)\n",
"\n",
"n = int(\n",
"55\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" x = 1\n",
" y = 1\n",
" if n < 2:\n",
" return 1\n",
" else:\n",
" for i in range(2, n):\n",
" z = x + y\n",
" x = y\n",
" y = z\n",
"\n"
]
}
],
"source": [
"import requests\n",
"from threading import Thread\n",
"import json\n",
"\n",
"url = \"http://127.0.0.1:5543/generate\"\n",
"sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
"# sequence = \"def fib(n):\"\n",
"\n",
"def request_task(seed, max_new_tokens):\n",
" obj = {\n",
" \"inputs\": sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":max_new_tokens,\n",
" # \"repetition_penalty\": 1.1,\n",
" # \"do_sample\": True,\n",
" # \"temperature\": 1.1,\n",
" # \"top_k\": 3,\n",
" # \"top_p\": 0.9,\n",
" \"seed\": seed,\n",
" }\n",
" }\n",
" with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n",
" dct = json.loads(r.text)\n",
" print(f'{sequence}{dct[\"response_text\"]}')\n",
"\n",
"max_new_tokens_lst = [55, 50, 45]\n",
"seeds = [1,2,3]\n",
"# max_new_tokens_lst = [100, 200, 300]\n",
"\n",
"request_ts = [\n",
" Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n",
"]\n",
"\n",
"import time\n",
"for request_t in request_ts:\n",
" request_t.start()\n",
" time.sleep(0.1)\n",
"\n",
"for request_t in request_ts:\n",
" request_t.join()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "28964c0b-0822-48aa-bc4c-a519035fae9a",
"metadata": {},
"outputs": [],
"source": [
"obj = {\n",
" \"inputs\": sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":100,\n",
" \"repetion_penalty\": 2.0,\n",
" # \"do_sample\": False,\n",
" # \"temperature\": 1.0,\n",
" # \"top_k\": 1,\n",
" # \"top_p\": 0.9,\n",
" # \"seed\": 43,\n",
" }\n",
" }\n",
"\n",
"resp = requests.post(url, json=obj)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "bdb67325-1341-4da4-808f-c1b12fdbf607",
"metadata": {},
"outputs": [],
"source": [
"a = \"\"\n",
"token = '\\n'\n",
"for i in range(5):\n",
" a += token"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "96aebc29-2c43-456e-ab62-3c422feb57c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"print(a)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "85447c0d-7020-46af-826f-fec58c9097b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n"
]
}
],
"source": [
"print(resp.text)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ff94c668-5832-4096-a723-3fc673b90495",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/robertgshaw/deepsparse-continuous-batching/deepsparse\n"
]
}
],
"source": [
"%cd deepsparse"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d314c9f2-3fa7-44c4-8609-3bb24d008bb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 4. 3. 2. 1.]]\n"
]
}
],
"source": [
"import numpy as np\n",
"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
"\n",
"scores = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
"\n",
"sorted_indices = np.argsort(scores, axis=-1)\n",
"sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)\n",
"\n",
"# sort, grabbing all those outside top_p\n",
"cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)\n",
"sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)\n",
"\n",
"\n",
"# note: this relies on b=1\n",
"indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
"\n",
"# set removed indices logits to -Inf (never selected)\n",
"scores[:, indices_to_remove] = -float(\"Inf\")"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "cced6917-8017-405a-ae93-c43105fce82d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2. 1. 5. 3. 2. 1.]\n",
" [ 5. -2. 3. -1. 2. 1.]]\n",
"[[-inf -inf 5. 3. 2. -inf]\n",
" [ 5. -inf 3. -inf 2. -inf]]\n",
"[[0. 0. 0.84379473 0.1141952 0.04201007 0. ]\n",
" [0.84379473 0. 0.1141952 0. 0.04201007 0. ]]\n"
]
}
],
"source": [
"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
"\n",
"logits = np.array([[-2.,1.,5.,3.,2.,1.], [5.,-2.,3.,-1.,2.,1.]])\n",
"print(logits)\n",
"\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"new_logits = top_k(logits)\n",
"print(new_logits)\n",
"\n",
"print(softmax(new_logits, axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "44fb3e44-0888-49af-a5d6-2c908db324ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 4. 3. 2. 1.]]\n",
"[[0.00418629 0.03093274 0.62130062 0.22856372 0.0840839 0.03093274]]\n",
"[[-inf -inf 4. 3. 2. -inf]]\n",
"[[0. 0. 0.66524096 0.24472847 0.09003057 0. ]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
"print(logits)\n",
"\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.9)\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"print(softmax(new_logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "f20813aa-0362-4820-a3b2-c32531658718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 3. 3.5 2.5 2. 1. ]]\n",
"[[0.00468177 0.03459387 0.25561604 0.4214396 0.15503896 0.09403589\n",
" 0.03459387]]\n",
"[[-inf -inf 3. 3.5 2.5 2. -inf]]\n",
"[[-inf -inf 3. 3.5 2.5 -inf -inf]]\n",
"[[0. 0. 0.30719589 0.50648039 0.18632372 0.\n",
" 0. ]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,3.,3.5,2.5,2.,1.]])\n",
"print(logits)\n",
"\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.9)\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"new_new_logits = top_k(new_logits)\n",
"print(new_new_logits)\n",
"print(softmax(logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "6914e436-350c-4c6a-b5ca-4146f349b1f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 3. 5. 2.5 2. 1. ]]\n",
"[[0.00189751 0.01402082 0.10360061 0.76551075 0.06283695 0.03811254\n",
" 0.01402082]]\n",
"[[-inf 1. 3. 5. 2.5 2. 1. ]]\n",
"[[-inf -inf 3. 5. 2.5 -inf -inf]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,3.,5,2.5,2.,1.]])\n",
"print(logits)\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.99)\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"new_new_logits = top_k(new_logits)\n",
"print(new_new_logits)\n",
"# print(softmax(new_new_logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "19ccfb53-8e40-46b4-9987-3081add2ccdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[False, False, False, False, False]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cumulative_probs <= (1 - top_p)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "483cea72-9955-4742-b599-6f7ec7a6c6d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 4, 3, 1, 2]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted_indices"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "38abf685-b4ae-4dc3-ab45-acf6dfd4036f",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'utils'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[65], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NextTokenChooser\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# picks largest logits\u001b[39;00m\n\u001b[1;32m 4\u001b[0m ntc \u001b[38;5;241m=\u001b[39m NextTokenChooser(do_sample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'"
]
}
],
"source": [
"from utils import NextTokenChooser\n",
"\n",
"# picks largest logits\n",
"ntc = NextTokenChooser(do_sample=False)\n",
"logits = np.array([[-2,1,3,2]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"ntc(input_ids, logits)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "a599ccae-6457-4bb9-8240-703aaf64474f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.004 / 0.004\n",
"sample / actual 1: 0.081 / 0.082\n",
"sample / actual 2: 0.615 / 0.608\n",
"sample / actual 3: 0.219 / 0.224\n",
"sample / actual 4: 0.081 / 0.082\n"
]
}
],
"source": [
"# samples\n",
"ntc = NextTokenChooser(do_sample=True)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "76f4d55d-7314-46e8-91e8-4a237d9d2ec1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.066 / 0.004\n",
"sample / actual 1: 0.176 / 0.082\n",
"sample / actual 2: 0.342 / 0.608\n",
"sample / actual 3: 0.244 / 0.224\n",
"sample / actual 4: 0.172 / 0.082\n"
]
}
],
"source": [
"# should pick logits that are less likely\n",
"ntc = NextTokenChooser(do_sample=True, temperature=3.0)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "e172586d-e0f9-4429-9b73-52dccb2117ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.197 / 0.004\n",
"sample / actual 1: 0.201 / 0.082\n",
"sample / actual 2: 0.205 / 0.608\n",
"sample / actual 3: 0.203 / 0.224\n",
"sample / actual 4: 0.195 / 0.082\n"
]
}
],
"source": [
"# should approach uniform\n",
"ntc = NextTokenChooser(do_sample=True, temperature=100.0)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "606f2be7-4ae3-4898-8c52-ec71207a5ea3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.000 / 0.004\n",
"sample / actual 1: 0.015 / 0.082\n",
"sample / actual 2: 0.857 / 0.608\n",
"sample / actual 3: 0.111 / 0.224\n",
"sample / actual 4: 0.017 / 0.082\n"
]
}
],
"source": [
"# should pick logits that are more likely\n",
"ntc = NextTokenChooser(do_sample=True, temperature=0.5)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "ea0c96a5-fb8f-4f21-a30e-7443d17a440c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.000 / 0.004\n",
"sample / actual 1: 0.000 / 0.082\n",
"sample / actual 2: 1.000 / 0.608\n",
"sample / actual 3: 0.000 / 0.224\n",
"sample / actual 4: 0.000 / 0.082\n"
]
}
],
"source": [
"# should approximate greedy sampling\n",
"ntc = NextTokenChooser(do_sample=True, temperature=0.001)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 134,
"id": "c9dcebaa-fa27-4d32-ad95-6a566d808868",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n",
"4\n",
"2\n",
"4\n",
"3\n",
"1\n",
"sample / actual / original 0: 0.017 / 0.016 / 0.004\n",
"sample / actual / original 1: 0.119 / 0.122 / 0.032\n",
"sample / actual / original 2: 0.204 / 0.201 / 0.236\n",
"sample / actual / original 3: 0.329 / 0.331 / 0.087\n",
"sample / actual / original 4: 0.330 / 0.331 / 0.641\n"
]
}
],
"source": [
"# should approximate greedy sampling\n",
"ntc = NextTokenChooser(do_sample=False)\n",
"\n",
"# should select index 4\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[3,4]])\n",
"\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 4\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=1.0)\n",
"logits = np.array([[-2,1,3,2,5]])\n",
"input_ids = np.array([[3,4]])\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 2\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.0)\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[3,4]])\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 4\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 3\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=3.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 1\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,3,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=5.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"\n",
"# should make 2,4 less liekly\n",
"logits_og = np.array([[-1.,1.,3.,2.,4.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=True, repetition_penalty=2.)\n",
"\n",
"original_ps = ntc.choice.softmax(logits_og,axis=1)\n",
"penalty = np.array([[1.,1.,2.,1.,2.]])\n",
"penalty_ps = ntc.choice.softmax(logits_og/penalty,axis=1)\n",
"\n",
"iters = 20000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" logits = logits_og.copy()\n",
" # print(logits)\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual / original {i}: {counts[i] / iters: 0.3f} / {penalty_ps[0,i]: 0.3f} / {original_ps[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "afd3139e-924c-4ac0-a450-c462447a610a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-2. , 1. , 1.5, 2. , 2.5]])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits / penalty"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ea9f993e-12f8-4bd0-9f8c-885bf8e56a7b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.00446236 0.08962882 0.24363641 0.66227241]]\n",
"{0: 0.0041, 1: 0.0909, 2: 0.2418, 3: 0.6632}\n"
]
}
],
"source": [
"from utils import Sampling\n",
"import numpy as np\n",
"\n",
"sampling = Sampling(seed=1)\n",
"\n",
"logits = np.array([[-2,1,2,3]])\n",
"probs = sampling.softmax(logits, axis=1)\n",
"\n",
"print(probs)\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"iters = 10000\n",
"for i in range(iters):\n",
" counts[sampling(logits=logits)] +=1\n",
"\n",
"for key in counts:\n",
" counts[key] /= iters\n",
"\n",
"print(counts)"
]
},
{
"cell_type": "code",
"execution_count": 42,
@ -141,7 +908,7 @@
" }\n",
" with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n",
" print(r.text)"
" print(r.text)\n"
]
},
{