Successfully implemented and integrated NextTokenChooser with support for repetition_penalty, do_sample, temperature, top_k, and top_p

This commit is contained in:
rsnm2 2023-08-28 02:05:43 +00:00
parent a875c05ccd
commit bb124e4029
4 changed files with 921 additions and 50 deletions

View File

@@ -2,7 +2,7 @@ from queue import Queue
 from typing import List, Dict, Optional, Tuple
 from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters

 class DeepSparseRouter:
     def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
         model_path: Optional[str] = None,
         tokenizer_path: Optional[str] = None
     ):
-        assert (
-            service is not None or
-            (model_path is not None and tokenizer_path is not None)
-        )
+        assert service is not None or (model_path is not None and tokenizer_path is not None)

         if service is not None:
             self.service = service
@@ -38,8 +35,8 @@ class DeepSparseRouter:
         # unblock the batching task with a dummy request if blocked
         self.queue.append(GenerateRequest(
-            inputs="dummy",
-            max_new_tokens=1,
+            inputs="stop",
+            generation_parameters=GenerationParameters(max_new_tokens=1),
             response_stream=Queue()
         ))
@@ -166,9 +163,10 @@ class DeepSparseQueue:
         # if block = True, this blocks until something ready
         # if block = False, the queue has data (if not an exception is raised)
-        # while queue.empty() == False does not guarentee data
-        # the queue is only subscribed to by one thread (this one)
-        # since batching_task is the only function that calls next_batch
+        # note: queue.empty() == False does not guarantee data on the next queue.get() in general,
+        # but this queue is only consumed by one thread (this one), since batching_task is the
+        # only function that calls next_batch, so it does in our case
         generate_request = self.queue.get(block=block)
         generate_requests = {self.next_request_id: generate_request}
@@ -176,7 +174,7 @@
         request = Request(
             id=self.next_request_id,
             inputs=generate_request.inputs,
-            max_new_tokens=generate_request.max_new_tokens
+            generation_parameters=generate_request.generation_parameters,
         )
         self.next_request_id += 1
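For reference, the queued objects now carry a GenerationParameters instance instead of a bare max_new_tokens. A minimal sketch of the new shapes (assuming the utils module from this commit is importable; the surrounding router plumbing is unchanged and not repeated here):

# sketch only: mirrors the "stop" sentinel and Request construction shown in the diff above
from queue import Queue

from utils import GenerateRequest, GenerationParameters, Request

# the unblock sentinel now wraps max_new_tokens in GenerationParameters
stop_request = GenerateRequest(
    inputs="stop",
    generation_parameters=GenerationParameters(max_new_tokens=1),
    response_stream=Queue(),
)

# next_batch converts a queued GenerateRequest into a Request, forwarding the parameters as-is
request = Request(
    id=0,
    inputs=stop_request.inputs,
    generation_parameters=stop_request.generation_parameters,
)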

View File

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser

 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -17,6 +17,7 @@ class DeepSparseCausalLMBatch:
     input_ids_list: List[np.ndarray]
     past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
     stopping_criteria_list: List[StoppingCriteria]
+    next_token_chooser_list: List[NextTokenChooser]

     @classmethod
     def from_batch(
@@ -29,6 +30,7 @@ class DeepSparseCausalLMBatch:
         requests_idx_mapping = {}
         input_ids_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []

         # loop through items in the batch
         for idx, r in enumerate(batch.requests):
@@ -51,7 +53,19 @@ class DeepSparseCausalLMBatch:
             stopping_criteria_list.append(
                 StoppingCriteria(
                     eos_token_id=tokenizer.eos_token_id,
-                    max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
+                    max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
                 )
             )
+
+            # get next token chooser based on input
+            next_token_chooser_list.append(
+                NextTokenChooser(
+                    repetition_penalty=r.generation_parameters.repetition_penalty,
+                    temperature=r.generation_parameters.temperature,
+                    top_k=r.generation_parameters.top_k,
+                    top_p=r.generation_parameters.top_p,
+                    do_sample=r.generation_parameters.do_sample,
+                    seed=r.generation_parameters.seed,
+                )
+            )
@@ -61,7 +75,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=[None] * len(batch.requests),
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list,
         )

     def to_cached_batch(self) -> CachedBatch:
@@ -78,11 +93,12 @@ class DeepSparseCausalLMBatch:
     def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
         assert(len(request_ids) > 0)

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []

         # loop through requests, keep ones that should remain
         for new_idx, request_id in enumerate(request_ids):
@@ -95,6 +111,7 @@ class DeepSparseCausalLMBatch:
             input_ids_list.append(self.input_ids_list[old_idx])
             past_key_values_list.append(self.past_key_values_list[old_idx])
             stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
+            next_token_chooser_list.append(self.next_token_chooser_list[old_idx])

         # update batch state
         self.requests = requests
@@ -102,6 +119,7 @@ class DeepSparseCausalLMBatch:
         self.input_ids_list = input_ids_list
         self.past_key_values_list = past_key_values_list
         self.stopping_criteria_list = stopping_criteria_list
+        self.next_token_chooser_list = next_token_chooser_list

         return self
@@ -110,11 +128,12 @@ class DeepSparseCausalLMBatch:
     def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
         assert len(batches) > 1, "must have more than 1 batch to concatenate"

         requests_idx_mapping = {}
         requests = []
         input_ids_list = []
         past_key_values_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []

         start_index = 0
         for i, batch in enumerate(batches):
@@ -125,6 +144,7 @@ class DeepSparseCausalLMBatch:
             input_ids_list.extend(batch.input_ids_list)
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
+            next_token_chooser_list.extend(batch.next_token_chooser_list)

             # merge the request_id to index mapping
             if i == 0:
@@ -141,7 +161,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=past_key_values_list,
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list
         )

 class DeepSparseCausalLM:
@@ -164,18 +185,6 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )

-    # TODO: switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        # assert b=1 for now
-        assert(logits.shape[0] == 1)
-
-        # grab logits for the last item in the sequence
-        # shape == (batch, seq, vocabulary_size)
-        return np.argmax(logits[0,-1,:])
-
     def generate_token(
         self,
         batch: DeepSparseCausalLMBatch,
@@ -189,14 +198,21 @@ class DeepSparseCausalLM:
         # b) sample and check stopping criteria
         # c) create generation
         # d) update batch
-        for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate(
-            zip(
-                batch.requests,
-                batch.input_ids_list,
-                batch.past_key_values_list,
-                batch.stopping_criteria_list,
-            )
-        ):
+        iterator = zip(
+            batch.requests,
+            batch.input_ids_list,
+            batch.past_key_values_list,
+            batch.stopping_criteria_list,
+            batch.next_token_chooser_list,
+        )
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+            stopping_criteria,
+            next_token_chooser
+        ) in enumerate(iterator):
             # assert input_ids is b=1
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1
@@ -206,7 +222,7 @@ class DeepSparseCausalLM:
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
-            generated_token_id = self.sample_token(logits)
+            generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
             generated_token = self.tokenizer.decode(generated_token_id)
             stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
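The new sampling call passes the (batch=1, vocab) slice of the final position's logits to the per-request NextTokenChooser and gets back a single token id. A small standalone sketch with toy shapes (assumes the utils and logits_process modules from this commit are on the path; the vocabulary size of 6 is arbitrary):

import numpy as np

from utils import NextTokenChooser

input_ids = np.array([[0, 3, 5, 2]])                 # (batch=1, seq), toy token ids
logits = np.random.randn(1, input_ids.shape[1], 6)   # (batch=1, seq, vocab=6), toy scores

# temperature != 1.0 enables the logits warpers and switches selection to Sampling
chooser = NextTokenChooser(do_sample=True, temperature=0.8, seed=0)

# matches generate_token: only the last position's scores are used
next_token_id = chooser(input_ids=input_ids, scores=logits[:, -1, :])
print(next_token_id)  # an int index into the toy vocabulary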

View File

@@ -1,7 +1,72 @@
+import numpy as np
 from dataclasses import dataclass
+from pydantic import BaseModel, Field
 from queue import Queue
 from enum import Enum
 from typing import List, Optional

+from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
+
+# TODO: sample for b > 1 with vectorized code
+class Greedy:
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        return np.argmax(logits[0,:])
+
+# TODO: sample for b > 1 with vectorized code
+# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
+class Sampling:
+    def __init__(self, seed:int=42):
+        self.generator = np.random.default_rng(seed=seed)
+
+    def __call__(self, logits:np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        probs = softmax(logits, axis=1)
+        return self.generator.choice(probs.shape[1], p=probs[0,:])
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        repetition_penalty: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        do_sample:bool = False,
+        seed: int = 42,
+    ):
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0 else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+        )
+        if has_warpers:
+            self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
+        else:
+            self.warpers = None
+
+        self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
+
+    def __call__(self, input_ids: np.ndarray, scores:np.ndarray) -> int:
+        if self.repetition_processor is not None:
+            scores = self.repetition_processor(input_ids=input_ids, scores=scores)
+
+        if self.warpers is not None:
+            scores = self.warpers(scores=scores)
+
+        return self.choice(scores)
+
 class FinishReason(Enum):
     FINISH_REASON_LENGTH = 1
@@ -28,11 +93,28 @@ class StoppingCriteria:
         return False, None

+class GenerationParameters(BaseModel):
+    max_new_tokens: int = Field(default=20)
+    repetition_penalty: float = Field(default=1)
+    do_sample: bool = Field(default=False)
+    temperature: float = Field(default=1.0)
+    top_k: Optional[int] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    seed: int = Field(default=42)
+
+class GenerateRequestInputs(BaseModel):
+    inputs: str
+    generation_parameters: GenerationParameters
+
+class GenerateRequestOutputs(BaseModel):
+    response_text: str = Field(default="")
+    finish_reason: Optional[FinishReason] = Field(default=None)
+
 @dataclass
 class Request:
     id: int
     inputs: str
-    max_new_tokens: int
+    generation_parameters: GenerationParameters

 @dataclass
 class Batch:
@@ -58,5 +140,13 @@ class Generation:
 @dataclass
 class GenerateRequest:
     inputs: str
-    max_new_tokens: int
+    generation_parameters: GenerationParameters
     response_stream: Queue[Generation]
+
+    @classmethod
+    def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
+        return cls(
+            inputs=gr_inputs.inputs,
+            generation_parameters=gr_inputs.generation_parameters,
+            response_stream=Queue()
+        )
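A quick sketch of how the new pydantic models fit together, parsing a JSON body like the ones in the notebook below into the internal request types (the HTTP endpoint wiring itself is not part of this diff, so only the model usage is shown):

from utils import GenerateRequest, GenerateRequestInputs, GenerateRequestOutputs

payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {
        "max_new_tokens": 64,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "top_p": 0.9,
        "seed": 1,
    },
}

# pydantic validates the payload and fills unset fields with defaults (temperature=1.0, top_k=None)
gr_inputs = GenerateRequestInputs(**payload)
generate_request = GenerateRequest.from_gr_inputs(gr_inputs)

# the response model starts empty and is filled as generations stream back
outputs = GenerateRequestOutputs()
print(gr_inputs.generation_parameters.max_new_tokens, outputs.response_text == "")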

View File

@@ -11,6 +11,773 @@
    "%autoreload 2"
   ]
  },
{
"cell_type": "code",
"execution_count": 89,
"id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"45\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n<=1:\n",
" return n\n",
" else:\n",
" return fib(n-1)+fib(n-2)\n",
" \n",
"print(fib(15))\n",
"\n",
"\n",
"50\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" if n == 0:\n",
" return 0\n",
" elif n == 1:\n",
" return 1\n",
" else:\n",
" return fib(n - 1) + fib(n - 2)\n",
"\n",
"n = int(\n",
"55\n",
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" x = 1\n",
" y = 1\n",
" if n < 2:\n",
" return 1\n",
" else:\n",
" for i in range(2, n):\n",
" z = x + y\n",
" x = y\n",
" y = z\n",
"\n"
]
}
],
"source": [
"import requests\n",
"from threading import Thread\n",
"import json\n",
"\n",
"url = \"http://127.0.0.1:5543/generate\"\n",
"sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
"# sequence = \"def fib(n):\"\n",
"\n",
"def request_task(seed, max_new_tokens):\n",
" obj = {\n",
" \"inputs\": sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":max_new_tokens,\n",
" # \"repetition_penalty\": 1.1,\n",
" # \"do_sample\": True,\n",
" # \"temperature\": 1.1,\n",
" # \"top_k\": 3,\n",
" # \"top_p\": 0.9,\n",
" \"seed\": seed,\n",
" }\n",
" }\n",
" with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n",
" dct = json.loads(r.text)\n",
" print(f'{sequence}{dct[\"response_text\"]}')\n",
"\n",
"max_new_tokens_lst = [55, 50, 45]\n",
"seeds = [1,2,3]\n",
"# max_new_tokens_lst = [100, 200, 300]\n",
"\n",
"request_ts = [\n",
" Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n",
"]\n",
"\n",
"import time\n",
"for request_t in request_ts:\n",
" request_t.start()\n",
" time.sleep(0.1)\n",
"\n",
"for request_t in request_ts:\n",
" request_t.join()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "28964c0b-0822-48aa-bc4c-a519035fae9a",
"metadata": {},
"outputs": [],
"source": [
"obj = {\n",
" \"inputs\": sequence,\n",
" \"generation_parameters\": {\n",
" \"max_new_tokens\":100,\n",
" \"repetion_penalty\": 2.0,\n",
" # \"do_sample\": False,\n",
" # \"temperature\": 1.0,\n",
" # \"top_k\": 1,\n",
" # \"top_p\": 0.9,\n",
" # \"seed\": 43,\n",
" }\n",
" }\n",
"\n",
"resp = requests.post(url, json=obj)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "bdb67325-1341-4da4-808f-c1b12fdbf607",
"metadata": {},
"outputs": [],
"source": [
"a = \"\"\n",
"token = '\\n'\n",
"for i in range(5):\n",
" a += token"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "96aebc29-2c43-456e-ab62-3c422feb57c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"print(a)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "85447c0d-7020-46af-826f-fec58c9097b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n"
]
}
],
"source": [
"print(resp.text)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ff94c668-5832-4096-a723-3fc673b90495",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/robertgshaw/deepsparse-continuous-batching/deepsparse\n"
]
}
],
"source": [
"%cd deepsparse"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d314c9f2-3fa7-44c4-8609-3bb24d008bb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 4. 3. 2. 1.]]\n"
]
}
],
"source": [
"import numpy as np\n",
"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
"\n",
"scores = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
"\n",
"sorted_indices = np.argsort(scores, axis=-1)\n",
"sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)\n",
"\n",
"# sort, grabbing all those outside top_p\n",
"cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)\n",
"sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)\n",
"\n",
"\n",
"# note: this relies on b=1\n",
"indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
"\n",
"# set removed indices logits to -Inf (never selected)\n",
"scores[:, indices_to_remove] = -float(\"Inf\")"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "cced6917-8017-405a-ae93-c43105fce82d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2. 1. 5. 3. 2. 1.]\n",
" [ 5. -2. 3. -1. 2. 1.]]\n",
"[[-inf -inf 5. 3. 2. -inf]\n",
" [ 5. -inf 3. -inf 2. -inf]]\n",
"[[0. 0. 0.84379473 0.1141952 0.04201007 0. ]\n",
" [0.84379473 0. 0.1141952 0. 0.04201007 0. ]]\n"
]
}
],
"source": [
"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
"\n",
"logits = np.array([[-2.,1.,5.,3.,2.,1.], [5.,-2.,3.,-1.,2.,1.]])\n",
"print(logits)\n",
"\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"new_logits = top_k(logits)\n",
"print(new_logits)\n",
"\n",
"print(softmax(new_logits, axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "44fb3e44-0888-49af-a5d6-2c908db324ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 4. 3. 2. 1.]]\n",
"[[0.00418629 0.03093274 0.62130062 0.22856372 0.0840839 0.03093274]]\n",
"[[-inf -inf 4. 3. 2. -inf]]\n",
"[[0. 0. 0.66524096 0.24472847 0.09003057 0. ]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
"print(logits)\n",
"\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.9)\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"print(softmax(new_logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "f20813aa-0362-4820-a3b2-c32531658718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 3. 3.5 2.5 2. 1. ]]\n",
"[[0.00468177 0.03459387 0.25561604 0.4214396 0.15503896 0.09403589\n",
" 0.03459387]]\n",
"[[-inf -inf 3. 3.5 2.5 2. -inf]]\n",
"[[-inf -inf 3. 3.5 2.5 -inf -inf]]\n",
"[[0. 0. 0.30719589 0.50648039 0.18632372 0.\n",
" 0. ]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,3.,3.5,2.5,2.,1.]])\n",
"print(logits)\n",
"\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.9)\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"new_new_logits = top_k(new_logits)\n",
"print(new_new_logits)\n",
"print(softmax(logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "6914e436-350c-4c6a-b5ca-4146f349b1f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1. 1. 3. 5. 2.5 2. 1. ]]\n",
"[[0.00189751 0.01402082 0.10360061 0.76551075 0.06283695 0.03811254\n",
" 0.01402082]]\n",
"[[-inf 1. 3. 5. 2.5 2. 1. ]]\n",
"[[-inf -inf 3. 5. 2.5 -inf -inf]]\n"
]
}
],
"source": [
"logits = np.array([[-1.,1.,3.,5,2.5,2.,1.]])\n",
"print(logits)\n",
"print(softmax(logits, axis=-1))\n",
"\n",
"top_p = TopPLogitsWarper(top_p=0.99)\n",
"top_k = TopKLogitsWarper(top_k=3)\n",
"\n",
"new_logits = top_p(logits)\n",
"print(new_logits)\n",
"\n",
"new_new_logits = top_k(new_logits)\n",
"print(new_new_logits)\n",
"# print(softmax(new_new_logits, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "19ccfb53-8e40-46b4-9987-3081add2ccdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[False, False, False, False, False]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cumulative_probs <= (1 - top_p)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "483cea72-9955-4742-b599-6f7ec7a6c6d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 4, 3, 1, 2]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted_indices"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "38abf685-b4ae-4dc3-ab45-acf6dfd4036f",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'utils'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[65], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NextTokenChooser\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# picks largest logits\u001b[39;00m\n\u001b[1;32m 4\u001b[0m ntc \u001b[38;5;241m=\u001b[39m NextTokenChooser(do_sample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'"
]
}
],
"source": [
"from utils import NextTokenChooser\n",
"\n",
"# picks largest logits\n",
"ntc = NextTokenChooser(do_sample=False)\n",
"logits = np.array([[-2,1,3,2]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"ntc(input_ids, logits)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "a599ccae-6457-4bb9-8240-703aaf64474f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.004 / 0.004\n",
"sample / actual 1: 0.081 / 0.082\n",
"sample / actual 2: 0.615 / 0.608\n",
"sample / actual 3: 0.219 / 0.224\n",
"sample / actual 4: 0.081 / 0.082\n"
]
}
],
"source": [
"# samples\n",
"ntc = NextTokenChooser(do_sample=True)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "76f4d55d-7314-46e8-91e8-4a237d9d2ec1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.066 / 0.004\n",
"sample / actual 1: 0.176 / 0.082\n",
"sample / actual 2: 0.342 / 0.608\n",
"sample / actual 3: 0.244 / 0.224\n",
"sample / actual 4: 0.172 / 0.082\n"
]
}
],
"source": [
"# should pick logits that are less likely\n",
"ntc = NextTokenChooser(do_sample=True, temperature=3.0)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "e172586d-e0f9-4429-9b73-52dccb2117ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.197 / 0.004\n",
"sample / actual 1: 0.201 / 0.082\n",
"sample / actual 2: 0.205 / 0.608\n",
"sample / actual 3: 0.203 / 0.224\n",
"sample / actual 4: 0.195 / 0.082\n"
]
}
],
"source": [
"# should approach uniform\n",
"ntc = NextTokenChooser(do_sample=True, temperature=100.0)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "606f2be7-4ae3-4898-8c52-ec71207a5ea3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.000 / 0.004\n",
"sample / actual 1: 0.015 / 0.082\n",
"sample / actual 2: 0.857 / 0.608\n",
"sample / actual 3: 0.111 / 0.224\n",
"sample / actual 4: 0.017 / 0.082\n"
]
}
],
"source": [
"# should pick logits that are more likely\n",
"ntc = NextTokenChooser(do_sample=True, temperature=0.5)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "ea0c96a5-fb8f-4f21-a30e-7443d17a440c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
"sample / actual 0: 0.000 / 0.004\n",
"sample / actual 1: 0.000 / 0.082\n",
"sample / actual 2: 1.000 / 0.608\n",
"sample / actual 3: 0.000 / 0.224\n",
"sample / actual 4: 0.000 / 0.082\n"
]
}
],
"source": [
"# should approximate greedy sampling\n",
"ntc = NextTokenChooser(do_sample=True, temperature=0.001)\n",
"logits = np.array([[-2,1,3,2,1]])\n",
"input_ids = np.array([[1,2,3,4]])\n",
"\n",
"probs = ntc.choice.softmax(logits)\n",
"print(probs)\n",
"\n",
"iters = 10000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 134,
"id": "c9dcebaa-fa27-4d32-ad95-6a566d808868",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n",
"4\n",
"2\n",
"4\n",
"3\n",
"1\n",
"sample / actual / original 0: 0.017 / 0.016 / 0.004\n",
"sample / actual / original 1: 0.119 / 0.122 / 0.032\n",
"sample / actual / original 2: 0.204 / 0.201 / 0.236\n",
"sample / actual / original 3: 0.329 / 0.331 / 0.087\n",
"sample / actual / original 4: 0.330 / 0.331 / 0.641\n"
]
}
],
"source": [
"# should approximate greedy sampling\n",
"ntc = NextTokenChooser(do_sample=False)\n",
"\n",
"# should select index 4\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[3,4]])\n",
"\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 4\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=1.0)\n",
"logits = np.array([[-2,1,3,2,5]])\n",
"input_ids = np.array([[3,4]])\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 2\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.0)\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[3,4]])\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 4\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 3\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=3.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"# should select index 1\n",
"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
"input_ids = np.array([[2,3,4]])\n",
"ntc = NextTokenChooser(do_sample=False, repetition_penalty=5.)\n",
"print(ntc(input_ids,logits))\n",
"\n",
"\n",
"# should make 2,4 less liekly\n",
"logits_og = np.array([[-1.,1.,3.,2.,4.]])\n",
"input_ids = np.array([[2,4]])\n",
"ntc = NextTokenChooser(do_sample=True, repetition_penalty=2.)\n",
"\n",
"original_ps = ntc.choice.softmax(logits_og,axis=1)\n",
"penalty = np.array([[1.,1.,2.,1.,2.]])\n",
"penalty_ps = ntc.choice.softmax(logits_og/penalty,axis=1)\n",
"\n",
"iters = 20000\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"\n",
"for _ in range(iters):\n",
" logits = logits_og.copy()\n",
" # print(logits)\n",
" pred = ntc(input_ids, logits)\n",
" counts[pred] += 1\n",
"\n",
"for i in range(logits.shape[1]):\n",
" print(f\"sample / actual / original {i}: {counts[i] / iters: 0.3f} / {penalty_ps[0,i]: 0.3f} / {original_ps[0,i]: 0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "afd3139e-924c-4ac0-a450-c462447a610a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-2. , 1. , 1.5, 2. , 2.5]])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits / penalty"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ea9f993e-12f8-4bd0-9f8c-885bf8e56a7b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.00446236 0.08962882 0.24363641 0.66227241]]\n",
"{0: 0.0041, 1: 0.0909, 2: 0.2418, 3: 0.6632}\n"
]
}
],
"source": [
"from utils import Sampling\n",
"import numpy as np\n",
"\n",
"sampling = Sampling(seed=1)\n",
"\n",
"logits = np.array([[-2,1,2,3]])\n",
"probs = sampling.softmax(logits, axis=1)\n",
"\n",
"print(probs)\n",
"counts = {a: 0 for a in range(logits.shape[1])}\n",
"iters = 10000\n",
"for i in range(iters):\n",
" counts[sampling(logits=logits)] +=1\n",
"\n",
"for key in counts:\n",
" counts[key] /= iters\n",
"\n",
"print(counts)"
]
},
 {
  "cell_type": "code",
  "execution_count": 42,
@@ -141,7 +908,7 @@
  " }\n",
  " with requests.post(url, json=obj) as r:\n",
  " print(max_new_tokens)\n",
- " print(r.text)"
+ " print(r.text)\n"
  ]
 },
 {