mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Implemented and integrated NextTokenChooser, with support for repetition_penalty, do_sample, temperature, top_k, and top_p.

This commit is contained in:
parent a875c05ccd
commit bb124e4029
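The new knobs are exercised end-to-end in the server-dev.ipynb changes below. As a minimal sketch of the client side (assuming the server from this branch is running locally on port 5543, as in the notebook; the parameter values here are illustrative only):

```python
import requests

url = "http://127.0.0.1:5543/generate"
payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {
        "max_new_tokens": 64,       # capped server-side by the model's max
        "repetition_penalty": 1.1,  # > 1.0 discourages already-seen tokens
        "do_sample": True,          # greedy decoding when False
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.9,
        "seed": 42,                 # seeds the numpy RNG used for sampling
    },
}

resp = requests.post(url, json=payload)
print(resp.json()["response_text"])
```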
@@ -2,7 +2,7 @@ from queue import Queue
 from typing import List, Dict, Optional, Tuple
 from service.service import DeepSparseService
 from service.causal_lm import DeepSparseCausalLM
-from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, StoppingCriteria
+from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters
 
 class DeepSparseRouter:
     def __init__(
@@ -11,10 +11,7 @@ class DeepSparseRouter:
         model_path: Optional[str] = None,
         tokenizer_path: Optional[str] = None
     ):
-        assert (
-            service is not None or
-            (model_path is not None and tokenizer_path is not None)
-        )
+        assert service is not None or (model_path is not None and tokenizer_path is not None)
 
         if service is not None:
             self.service = service
@@ -38,8 +35,8 @@ class DeepSparseRouter:
 
         # unblock the batching task with a dummy request if blocked
         self.queue.append(GenerateRequest(
-            inputs="dummy",
-            max_new_tokens=1,
+            inputs="stop",
+            generation_parameters=GenerationParameters(max_new_tokens=1),
             response_stream=Queue()
         ))
 
@@ -166,9 +163,10 @@ class DeepSparseQueue:
 
         # if block = True, this blocks until something ready
        # if block = False, the queue has data (if not an exception is raised)
-        # while queue.empty() == False does not guarentee data
-        # the queue is only subscribed to by one thread (this one)
-        # since batching_task is the only function that calls next_batch
+        # while queue.empty() == False typically does not guarantee data on the
+        # next queue.get(), it does in our case: this queue is only subscribed
+        # to by one thread (this one), since batching_task is the only function
+        # that calls next_batch
 
         generate_request = self.queue.get(block=block)
         generate_requests = {self.next_request_id: generate_request}
@@ -176,7 +174,7 @@
         request = Request(
             id=self.next_request_id,
             inputs=generate_request.inputs,
-            max_new_tokens=generate_request.max_new_tokens
+            generation_parameters=generate_request.generation_parameters,
         )
         self.next_request_id += 1
 
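The stop path above leans on the queue's single-consumer discipline: batching_task is the only thread that ever calls queue.get(), so a sentinel request is enough to wake a blocked consumer and shut it down. A minimal sketch of the pattern, with a plain queue.Queue and strings standing in for DeepSparseQueue and GenerateRequest:

```python
import queue
import threading

q: "queue.Queue[str]" = queue.Queue()

def batching_task():
    # single consumer: only this thread calls q.get()
    while True:
        item = q.get(block=True)  # blocks until a request (or the sentinel) arrives
        if item == "stop":
            break                 # the sentinel unblocks the task so it can exit
        print("processing", item)

t = threading.Thread(target=batching_task)
t.start()
q.put("request-1")
q.put("stop")  # the router does this with a dummy GenerateRequest(inputs="stop")
t.join()
```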
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 import numpy as np
 
 from service.model import DeepSparsePastKeyValues, DeepSparseDecoderModel
-from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria
+from utils import Request, Batch, CachedBatch, Generation, StoppingCriteria, NextTokenChooser
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
 DEEPSPARSE_MULTITOKEN_LENGTH = 4
@@ -17,6 +17,7 @@ class DeepSparseCausalLMBatch:
     input_ids_list: List[np.ndarray]
     past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
     stopping_criteria_list: List[StoppingCriteria]
+    next_token_chooser_list: List[NextTokenChooser]
 
     @classmethod
     def from_batch(
@@ -29,6 +30,7 @@ class DeepSparseCausalLMBatch:
         requests_idx_mapping = {}
         input_ids_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []
 
         # loop through items in the batch
         for idx, r in enumerate(batch.requests):
@@ -51,7 +53,19 @@ class DeepSparseCausalLMBatch:
             stopping_criteria_list.append(
                 StoppingCriteria(
                     eos_token_id=tokenizer.eos_token_id,
-                    max_new_tokens=min(r.max_new_tokens, model_max_new_tokens)
+                    max_new_tokens=min(r.generation_parameters.max_new_tokens, model_max_new_tokens)
+                )
+            )
+
+            # get next token chooser based on input
+            next_token_chooser_list.append(
+                NextTokenChooser(
+                    repetition_penalty=r.generation_parameters.repetition_penalty,
+                    temperature=r.generation_parameters.temperature,
+                    top_k=r.generation_parameters.top_k,
+                    top_p=r.generation_parameters.top_p,
+                    do_sample=r.generation_parameters.do_sample,
+                    seed=r.generation_parameters.seed,
                 )
             )
 
@@ -61,7 +75,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=[None] * len(batch.requests),
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list,
         )
 
     def to_cached_batch(self) -> CachedBatch:
@@ -83,6 +98,7 @@ class DeepSparseCausalLMBatch:
         input_ids_list = []
         past_key_values_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []
 
         # loop through requests, keep ones that should remain
         for new_idx, request_id in enumerate(request_ids):
@@ -95,6 +111,7 @@ class DeepSparseCausalLMBatch:
             input_ids_list.append(self.input_ids_list[old_idx])
             past_key_values_list.append(self.past_key_values_list[old_idx])
             stopping_criteria_list.append(self.stopping_criteria_list[old_idx])
+            next_token_chooser_list.append(self.next_token_chooser_list[old_idx])
 
         # update batch state
         self.requests = requests
@@ -102,6 +119,7 @@ class DeepSparseCausalLMBatch:
         self.input_ids_list = input_ids_list
         self.past_key_values_list = past_key_values_list
         self.stopping_criteria_list = stopping_criteria_list
+        self.next_token_chooser_list = next_token_chooser_list
 
         return self
 
@@ -115,6 +133,7 @@ class DeepSparseCausalLMBatch:
         input_ids_list = []
         past_key_values_list = []
         stopping_criteria_list = []
+        next_token_chooser_list = []
 
         start_index = 0
         for i, batch in enumerate(batches):
@@ -125,6 +144,7 @@ class DeepSparseCausalLMBatch:
             input_ids_list.extend(batch.input_ids_list)
             past_key_values_list.extend(batch.past_key_values_list)
             stopping_criteria_list.extend(batch.stopping_criteria_list)
+            next_token_chooser_list.extend(batch.next_token_chooser_list)
 
             # merge the request_id to index mapping
             if i == 0:
@@ -141,7 +161,8 @@ class DeepSparseCausalLMBatch:
             requests_idx_mapping=requests_idx_mapping,
             input_ids_list=input_ids_list,
             past_key_values_list=past_key_values_list,
-            stopping_criteria_list=stopping_criteria_list
+            stopping_criteria_list=stopping_criteria_list,
+            next_token_chooser_list=next_token_chooser_list
         )
 
 class DeepSparseCausalLM:
@@ -164,18 +185,6 @@ class DeepSparseCausalLM:
             multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
         )
 
-    # TODO: switch to NextTokenChooser
-    def sample_token(
-        self,
-        logits: np.ndarray
-    ):
-        # assert b=1 for now
-        assert(logits.shape[0] == 1)
-
-        # grab logits for the last item in the sequence
-        # shape == (batch, seq, vocabulary_size)
-        return np.argmax(logits[0,-1,:])
-
     def generate_token(
         self,
         batch: DeepSparseCausalLMBatch,
@@ -189,14 +198,21 @@ class DeepSparseCausalLM:
         # b) sample and check stopping criteria
         # c) create generation
         # d) update batch
-        for i, (request, input_ids, past_key_values,stopping_criteria,) in enumerate(
-            zip(
+        iterator = zip(
             batch.requests,
             batch.input_ids_list,
             batch.past_key_values_list,
             batch.stopping_criteria_list,
+            batch.next_token_chooser_list,
         )
-        ):
+
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+            stopping_criteria,
+            next_token_chooser
+        ) in enumerate(iterator):
             # assert input_ids is b=1
             assert len(input_ids.shape) == 2
             assert input_ids.shape[0] == 1
@@ -206,7 +222,7 @@ class DeepSparseCausalLM:
 
             # b) sample token and check stopping criteria
             # TODO: should use NextTokenChooser/StoppingCriteria (simple for now)
-            generated_token_id = self.sample_token(logits)
+            generated_token_id = next_token_chooser(input_ids=input_ids, scores=logits[:,-1,:])
             generated_token = self.tokenizer.decode(generated_token_id)
 
             stop, finish_reason = stopping_criteria(generated_token_id=generated_token_id)
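With every parameter left at its default, NextTokenChooser falls back to Greedy, so this call site preserves the behavior of the removed sample_token: argmax over the logits at the last sequence position. A quick check of that equivalence (the shapes here are assumed for illustration, not taken from the diff):

```python
import numpy as np

logits = np.random.rand(1, 7, 32)  # (batch=1, seq, vocab)

old = np.argmax(logits[0, -1, :])        # removed sample_token
new = np.argmax(logits[:, -1, :][0, :])  # Greedy() applied to scores=logits[:, -1, :]
assert old == new
```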
@@ -1,7 +1,72 @@
+import numpy as np
 from dataclasses import dataclass
+from pydantic import BaseModel, Field
 from queue import Queue
 from enum import Enum
 from typing import List, Optional
+from logits_process import (RepetitionPenaltyLogitsProcessor, LogitsWarpers, softmax)
+
+# TODO: sample for b > 1 with vectorized code
+class Greedy:
+    def __call__(self, logits: np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        return np.argmax(logits[0,:])
+
+# TODO: sample for b > 1 with vectorized code
+# https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a
+class Sampling:
+    def __init__(self, seed:int=42):
+        self.generator = np.random.default_rng(seed=seed)
+
+    def __call__(self, logits:np.ndarray):
+        # assert b=1 for now
+        # shape == (batch, vocabulary_size)
+        assert(logits.shape[0] == 1)
+        assert(len(logits.shape) == 2)
+
+        probs = softmax(logits, axis=1)
+        return self.generator.choice(probs.shape[1], p=probs[0,:])
+
+class NextTokenChooser:
+    def __init__(
+        self,
+        repetition_penalty: Optional[float] = 1.0,
+        temperature: Optional[float] = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 42,
+    ):
+        self.repetition_processor = (
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0 else None
+        )
+
+        has_warpers = (
+            (temperature is not None and temperature != 1.0)
+            or (top_k is not None and top_k != 0)
+            or (top_p is not None and top_p < 1.0)
+        )
+
+        if has_warpers:
+            self.warpers = LogitsWarpers(temperature=temperature, top_k=top_k, top_p=top_p)
+        else:
+            self.warpers = None
+
+        self.choice = Sampling(seed=seed) if do_sample or has_warpers else Greedy()
+
+    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> int:
+        if self.repetition_processor is not None:
+            scores = self.repetition_processor(input_ids=input_ids, scores=scores)
+
+        if self.warpers is not None:
+            scores = self.warpers(scores=scores)
+
+        return self.choice(scores)
+
 class FinishReason(Enum):
     FINISH_REASON_LENGTH = 1
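Greedy, Sampling, and NextTokenChooser lean on a logits_process module that this commit imports but does not include in the diff. The following is only a sketch of what those helpers plausibly look like, reverse-engineered from how the notebook below calls them (HuggingFace-style semantics assumed, batch-size-1 only, like the rest of this commit; none of these definitions come from the source):

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # shift by the max for numerical stability
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

class RepetitionPenaltyLogitsProcessor:
    # HF-style: divide positive scores (multiply negative ones) of already-seen tokens
    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        seen = input_ids[0]  # relies on b=1
        picked = scores[0, seen]
        scores[0, seen] = np.where(picked < 0, picked * self.penalty, picked / self.penalty)
        return scores

class TopKLogitsWarper:
    # keep the k largest logits, mask the rest to -inf
    def __init__(self, top_k: int):
        self.top_k = top_k

    def __call__(self, scores: np.ndarray) -> np.ndarray:
        kth_largest = np.sort(scores, axis=-1)[:, -self.top_k][:, None]
        return np.where(scores < kth_largest, -float("inf"), scores)

class TopPLogitsWarper:
    # drop the low-probability tail outside the cumulative-probability nucleus
    def __init__(self, top_p: float):
        self.top_p = top_p

    def __call__(self, scores: np.ndarray) -> np.ndarray:
        sorted_indices = np.argsort(scores, axis=-1)
        sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)
        cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        indices_to_remove = sorted_indices[sorted_indices_to_remove]  # relies on b=1
        scores[:, indices_to_remove] = -float("inf")
        return scores

class LogitsWarpers:
    # temperature scaling, then top-k, then top-p, applied in sequence
    def __init__(self, temperature=None, top_k=None, top_p=None):
        self.warpers = []
        if temperature is not None and temperature != 1.0:
            self.warpers.append(lambda scores: scores / temperature)
        if top_k is not None and top_k != 0:
            self.warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            self.warpers.append(TopPLogitsWarper(top_p=top_p))

    def __call__(self, scores: np.ndarray) -> np.ndarray:
        for warper in self.warpers:
            scores = warper(scores)
        return scores
```

The TopPLogitsWarper body mirrors the scratch cell in the notebook below almost line for line, which is the main reason to believe this shape is close to the real module.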
@@ -28,11 +93,28 @@ class StoppingCriteria:
 
         return False, None
 
+class GenerationParameters(BaseModel):
+    max_new_tokens: int = Field(default=20)
+    repetition_penalty: float = Field(default=1)
+    do_sample: bool = Field(default=False)
+    temperature: float = Field(default=1.0)
+    top_k: Optional[int] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    seed: int = Field(default=42)
+
+class GenerateRequestInputs(BaseModel):
+    inputs: str
+    generation_parameters: GenerationParameters
+
+class GenerateRequestOutputs(BaseModel):
+    response_text: str = Field(default="")
+    finish_reason: Optional[FinishReason] = Field(default=None)
+
 @dataclass
 class Request:
     id: int
     inputs: str
-    max_new_tokens: int
+    generation_parameters: GenerationParameters
 
 @dataclass
 class Batch:
@@ -58,5 +140,13 @@ class Generation:
 @dataclass
 class GenerateRequest:
     inputs: str
-    max_new_tokens: int
+    generation_parameters: GenerationParameters
     response_stream: Queue[Generation]
+
+    @classmethod
+    def from_gr_inputs(cls, gr_inputs: GenerateRequestInputs):
+        return cls(
+            inputs=gr_inputs.inputs,
+            generation_parameters=gr_inputs.generation_parameters,
+            response_stream=Queue()
+        )
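Putting the new models together: the server can validate a raw JSON body with pydantic and hand the router a GenerateRequest that carries its own response queue. A minimal sketch (assumes the utils module from this diff is importable):

```python
from utils import GenerateRequest, GenerateRequestInputs

payload = {
    "inputs": "def fib(n):",
    "generation_parameters": {"max_new_tokens": 50, "do_sample": True, "seed": 1},
}

# pydantic validates the body and fills in defaults (temperature=1.0, top_k=None, ...)
gr_inputs = GenerateRequestInputs(**payload)

# the router-side request gets a fresh Queue to stream Generation objects back
generate_request = GenerateRequest.from_gr_inputs(gr_inputs)
```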
server-dev.ipynb (769 lines changed)

@@ -11,6 +11,773 @@
"%autoreload 2"
|
"%autoreload 2"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 89,
|
||||||
|
"id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"45\n",
|
||||||
|
"Finish the following function for computing a fibonacci sequence: \n",
|
||||||
|
"\n",
|
||||||
|
"def fib(n):\n",
|
||||||
|
" if n<=1:\n",
|
||||||
|
" return n\n",
|
||||||
|
" else:\n",
|
||||||
|
" return fib(n-1)+fib(n-2)\n",
|
||||||
|
" \n",
|
||||||
|
"print(fib(15))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"50\n",
|
||||||
|
"Finish the following function for computing a fibonacci sequence: \n",
|
||||||
|
"\n",
|
||||||
|
"def fib(n):\n",
|
||||||
|
" if n == 0:\n",
|
||||||
|
" return 0\n",
|
||||||
|
" elif n == 1:\n",
|
||||||
|
" return 1\n",
|
||||||
|
" else:\n",
|
||||||
|
" return fib(n - 1) + fib(n - 2)\n",
|
||||||
|
"\n",
|
||||||
|
"n = int(\n",
|
||||||
|
"55\n",
|
||||||
|
"Finish the following function for computing a fibonacci sequence: \n",
|
||||||
|
"\n",
|
||||||
|
"def fib(n):\n",
|
||||||
|
" x = 1\n",
|
||||||
|
" y = 1\n",
|
||||||
|
" if n < 2:\n",
|
||||||
|
" return 1\n",
|
||||||
|
" else:\n",
|
||||||
|
" for i in range(2, n):\n",
|
||||||
|
" z = x + y\n",
|
||||||
|
" x = y\n",
|
||||||
|
" y = z\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import requests\n",
|
||||||
|
"from threading import Thread\n",
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"url = \"http://127.0.0.1:5543/generate\"\n",
|
||||||
|
"sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\ndef fib(n):\"\n",
|
||||||
|
"# sequence = \"def fib(n):\"\n",
|
||||||
|
"\n",
|
||||||
|
"def request_task(seed, max_new_tokens):\n",
|
||||||
|
" obj = {\n",
|
||||||
|
" \"inputs\": sequence,\n",
|
||||||
|
" \"generation_parameters\": {\n",
|
||||||
|
" \"max_new_tokens\":max_new_tokens,\n",
|
||||||
|
" # \"repetition_penalty\": 1.1,\n",
|
||||||
|
" # \"do_sample\": True,\n",
|
||||||
|
" # \"temperature\": 1.1,\n",
|
||||||
|
" # \"top_k\": 3,\n",
|
||||||
|
" # \"top_p\": 0.9,\n",
|
||||||
|
" \"seed\": seed,\n",
|
||||||
|
" }\n",
|
||||||
|
" }\n",
|
||||||
|
" with requests.post(url, json=obj) as r:\n",
|
||||||
|
" print(max_new_tokens)\n",
|
||||||
|
" dct = json.loads(r.text)\n",
|
||||||
|
" print(f'{sequence}{dct[\"response_text\"]}')\n",
|
||||||
|
"\n",
|
||||||
|
"max_new_tokens_lst = [55, 50, 45]\n",
|
||||||
|
"seeds = [1,2,3]\n",
|
||||||
|
"# max_new_tokens_lst = [100, 200, 300]\n",
|
||||||
|
"\n",
|
||||||
|
"request_ts = [\n",
|
||||||
|
" Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"import time\n",
|
||||||
|
"for request_t in request_ts:\n",
|
||||||
|
" request_t.start()\n",
|
||||||
|
" time.sleep(0.1)\n",
|
||||||
|
"\n",
|
||||||
|
"for request_t in request_ts:\n",
|
||||||
|
" request_t.join()"
|
||||||
|
]
|
||||||
|
},
|
||||||
+{
+"cell_type": "code",
+"execution_count": 36,
+"id": "28964c0b-0822-48aa-bc4c-a519035fae9a",
+"metadata": {},
+"outputs": [],
+"source": [
+"obj = {\n",
+"    \"inputs\": sequence,\n",
+"    \"generation_parameters\": {\n",
+"        \"max_new_tokens\":100,\n",
+"        \"repetion_penalty\": 2.0,\n",
+"        # \"do_sample\": False,\n",
+"        # \"temperature\": 1.0,\n",
+"        # \"top_k\": 1,\n",
+"        # \"top_p\": 0.9,\n",
+"        # \"seed\": 43,\n",
+"    }\n",
+"}\n",
+"\n",
+"resp = requests.post(url, json=obj)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 47,
+"id": "bdb67325-1341-4da4-808f-c1b12fdbf607",
+"metadata": {},
+"outputs": [],
+"source": [
+"a = \"\"\n",
+"token = '\\n'\n",
+"for i in range(5):\n",
+"    a += token"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 48,
+"id": "96aebc29-2c43-456e-ab62-3c422feb57c6",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"\n",
+"\n",
+"\n",
+"\n"
+]
+}
+],
+"source": [
+"print(a)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 44,
+"id": "85447c0d-7020-46af-826f-fec58c9097b8",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\"\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib(n-1) + fib(n-2)\\n\\ndef fib2(n):\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib2(n-1) + fib2(n-2)\\n\\ndef\"\n"
+]
+}
+],
+"source": [
+"print(resp.text)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 1,
+"id": "ff94c668-5832-4096-a723-3fc673b90495",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"/home/robertgshaw/deepsparse-continuous-batching/deepsparse\n"
+]
+}
+],
+"source": [
+"%cd deepsparse"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"id": "d314c9f2-3fa7-44c4-8609-3bb24d008bb0",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[-1. 1. 4. 3. 2. 1.]]\n"
+]
+}
+],
+"source": [
+"import numpy as np\n",
+"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
+"\n",
+"scores = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
+"\n",
+"sorted_indices = np.argsort(scores, axis=-1)\n",
+"sorted_scores = np.take_along_axis(scores, sorted_indices, axis=-1)\n",
+"\n",
+"# sort, grabbing all those outside top_p\n",
+"cumulative_probs = softmax(sorted_scores, axis=-1).cumsum(axis=-1)\n",
+"sorted_indices_to_remove = cumulative_probs <= (1 - 0.9)\n",
+"\n",
+"\n",
+"# note: this relies on b=1\n",
+"indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
+"\n",
+"# set removed indices logits to -Inf (never selected)\n",
+"scores[:, indices_to_remove] = -float(\"Inf\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 93,
+"id": "cced6917-8017-405a-ae93-c43105fce82d",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[-2. 1. 5. 3. 2. 1.]\n",
+" [ 5. -2. 3. -1. 2. 1.]]\n",
+"[[-inf -inf 5. 3. 2. -inf]\n",
+" [ 5. -inf 3. -inf 2. -inf]]\n",
+"[[0. 0. 0.84379473 0.1141952 0.04201007 0. ]\n",
+" [0.84379473 0. 0.1141952 0. 0.04201007 0. ]]\n"
+]
+}
+],
+"source": [
+"from logits_process import TopKLogitsWarper, TopPLogitsWarper, softmax\n",
+"\n",
+"logits = np.array([[-2.,1.,5.,3.,2.,1.], [5.,-2.,3.,-1.,2.,1.]])\n",
+"print(logits)\n",
+"\n",
+"top_k = TopKLogitsWarper(top_k=3)\n",
+"new_logits = top_k(logits)\n",
+"print(new_logits)\n",
+"\n",
+"print(softmax(new_logits, axis=1))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 97,
+"id": "44fb3e44-0888-49af-a5d6-2c908db324ec",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[-1. 1. 4. 3. 2. 1.]]\n",
+"[[0.00418629 0.03093274 0.62130062 0.22856372 0.0840839 0.03093274]]\n",
+"[[-inf -inf 4. 3. 2. -inf]]\n",
+"[[0. 0. 0.66524096 0.24472847 0.09003057 0. ]]\n"
+]
+}
+],
+"source": [
+"logits = np.array([[-1.,1.,4.,3.,2.,1.]])\n",
+"print(logits)\n",
+"\n",
+"print(softmax(logits, axis=-1))\n",
+"\n",
+"top_p = TopPLogitsWarper(top_p=0.9)\n",
+"new_logits = top_p(logits)\n",
+"print(new_logits)\n",
+"\n",
+"print(softmax(new_logits, axis=-1))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 102,
+"id": "f20813aa-0362-4820-a3b2-c32531658718",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[-1. 1. 3. 3.5 2.5 2. 1. ]]\n",
+"[[0.00468177 0.03459387 0.25561604 0.4214396 0.15503896 0.09403589\n",
+" 0.03459387]]\n",
+"[[-inf -inf 3. 3.5 2.5 2. -inf]]\n",
+"[[-inf -inf 3. 3.5 2.5 -inf -inf]]\n",
+"[[0. 0. 0.30719589 0.50648039 0.18632372 0.\n",
+" 0. ]]\n"
+]
+}
+],
+"source": [
+"logits = np.array([[-1.,1.,3.,3.5,2.5,2.,1.]])\n",
+"print(logits)\n",
+"\n",
+"print(softmax(logits, axis=-1))\n",
+"\n",
+"top_p = TopPLogitsWarper(top_p=0.9)\n",
+"top_k = TopKLogitsWarper(top_k=3)\n",
+"\n",
+"new_logits = top_p(logits)\n",
+"print(new_logits)\n",
+"\n",
+"new_new_logits = top_k(new_logits)\n",
+"print(new_new_logits)\n",
+"print(softmax(logits, axis=-1))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 113,
+"id": "6914e436-350c-4c6a-b5ca-4146f349b1f3",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[-1. 1. 3. 5. 2.5 2. 1. ]]\n",
+"[[0.00189751 0.01402082 0.10360061 0.76551075 0.06283695 0.03811254\n",
+" 0.01402082]]\n",
+"[[-inf 1. 3. 5. 2.5 2. 1. ]]\n",
+"[[-inf -inf 3. 5. 2.5 -inf -inf]]\n"
+]
+}
+],
+"source": [
+"logits = np.array([[-1.,1.,3.,5,2.5,2.,1.]])\n",
+"print(logits)\n",
+"print(softmax(logits, axis=-1))\n",
+"\n",
+"top_p = TopPLogitsWarper(top_p=0.99)\n",
+"top_k = TopKLogitsWarper(top_k=3)\n",
+"\n",
+"new_logits = top_p(logits)\n",
+"print(new_logits)\n",
+"\n",
+"new_new_logits = top_k(new_logits)\n",
+"print(new_new_logits)\n",
+"# print(softmax(new_new_logits, axis=-1))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"id": "19ccfb53-8e40-46b4-9987-3081add2ccdb",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"array([[False, False, False, False, False]])"
+]
+},
+"execution_count": 12,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"cumulative_probs <= (1 - top_p)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "483cea72-9955-4742-b599-6f7ec7a6c6d9",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"array([[0, 4, 3, 1, 2]])"
+]
+},
+"execution_count": 6,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"sorted_indices"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 65,
+"id": "38abf685-b4ae-4dc3-ab45-acf6dfd4036f",
+"metadata": {},
+"outputs": [
+{
+"ename": "ModuleNotFoundError",
+"evalue": "No module named 'utils'",
+"output_type": "error",
+"traceback": [
+"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+"Cell \u001b[0;32mIn[65], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NextTokenChooser\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# picks largest logits\u001b[39;00m\n\u001b[1;32m 4\u001b[0m ntc \u001b[38;5;241m=\u001b[39m NextTokenChooser(do_sample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
+"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'"
+]
+}
+],
+"source": [
+"from utils import NextTokenChooser\n",
+"\n",
+"# picks largest logits\n",
+"ntc = NextTokenChooser(do_sample=False)\n",
+"logits = np.array([[-2,1,3,2]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"ntc(input_ids, logits)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 53,
+"id": "a599ccae-6457-4bb9-8240-703aaf64474f",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+"sample / actual 0:  0.004 /  0.004\n",
+"sample / actual 1:  0.081 /  0.082\n",
+"sample / actual 2:  0.615 /  0.608\n",
+"sample / actual 3:  0.219 /  0.224\n",
+"sample / actual 4:  0.081 /  0.082\n"
+]
+}
+],
+"source": [
+"# samples\n",
+"ntc = NextTokenChooser(do_sample=True)\n",
+"logits = np.array([[-2,1,3,2,1]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"\n",
+"probs = ntc.choice.softmax(logits)\n",
+"print(probs)\n",
+"\n",
+"iters = 10000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 57,
+"id": "76f4d55d-7314-46e8-91e8-4a237d9d2ec1",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+"sample / actual 0:  0.066 /  0.004\n",
+"sample / actual 1:  0.176 /  0.082\n",
+"sample / actual 2:  0.342 /  0.608\n",
+"sample / actual 3:  0.244 /  0.224\n",
+"sample / actual 4:  0.172 /  0.082\n"
+]
+}
+],
+"source": [
+"# should pick logits that are less likely\n",
+"ntc = NextTokenChooser(do_sample=True, temperature=3.0)\n",
+"logits = np.array([[-2,1,3,2,1]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"\n",
+"probs = ntc.choice.softmax(logits)\n",
+"print(probs)\n",
+"\n",
+"iters = 10000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 61,
+"id": "e172586d-e0f9-4429-9b73-52dccb2117ca",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+"sample / actual 0:  0.197 /  0.004\n",
+"sample / actual 1:  0.201 /  0.082\n",
+"sample / actual 2:  0.205 /  0.608\n",
+"sample / actual 3:  0.203 /  0.224\n",
+"sample / actual 4:  0.195 /  0.082\n"
+]
+}
+],
+"source": [
+"# should approach uniform\n",
+"ntc = NextTokenChooser(do_sample=True, temperature=100.0)\n",
+"logits = np.array([[-2,1,3,2,1]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"\n",
+"probs = ntc.choice.softmax(logits)\n",
+"print(probs)\n",
+"\n",
+"iters = 10000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 64,
+"id": "606f2be7-4ae3-4898-8c52-ec71207a5ea3",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+"sample / actual 0:  0.000 /  0.004\n",
+"sample / actual 1:  0.015 /  0.082\n",
+"sample / actual 2:  0.857 /  0.608\n",
+"sample / actual 3:  0.111 /  0.224\n",
+"sample / actual 4:  0.017 /  0.082\n"
+]
+}
+],
+"source": [
+"# should pick logits that are more likely\n",
+"ntc = NextTokenChooser(do_sample=True, temperature=0.5)\n",
+"logits = np.array([[-2,1,3,2,1]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"\n",
+"probs = ntc.choice.softmax(logits)\n",
+"print(probs)\n",
+"\n",
+"iters = 10000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 65,
+"id": "ea0c96a5-fb8f-4f21-a30e-7443d17a440c",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.0040953 0.08225629 0.60779634 0.22359578 0.08225629]]\n",
+"sample / actual 0:  0.000 /  0.004\n",
+"sample / actual 1:  0.000 /  0.082\n",
+"sample / actual 2:  1.000 /  0.608\n",
+"sample / actual 3:  0.000 /  0.224\n",
+"sample / actual 4:  0.000 /  0.082\n"
+]
+}
+],
+"source": [
+"# should approximate greedy sampling\n",
+"ntc = NextTokenChooser(do_sample=True, temperature=0.001)\n",
+"logits = np.array([[-2,1,3,2,1]])\n",
+"input_ids = np.array([[1,2,3,4]])\n",
+"\n",
+"probs = ntc.choice.softmax(logits)\n",
+"print(probs)\n",
+"\n",
+"iters = 10000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual {i}: {counts[i] / iters: 0.3f} / {probs[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 134,
+"id": "c9dcebaa-fa27-4d32-ad95-6a566d808868",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"4\n",
+"4\n",
+"2\n",
+"4\n",
+"3\n",
+"1\n",
+"sample / actual / original 0:  0.017 /  0.016 /  0.004\n",
+"sample / actual / original 1:  0.119 /  0.122 /  0.032\n",
+"sample / actual / original 2:  0.204 /  0.201 /  0.236\n",
+"sample / actual / original 3:  0.329 /  0.331 /  0.087\n",
+"sample / actual / original 4:  0.330 /  0.331 /  0.641\n"
+]
+}
+],
+"source": [
+"# should approximate greedy sampling\n",
+"ntc = NextTokenChooser(do_sample=False)\n",
+"\n",
+"# should select index 4\n",
+"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+"input_ids = np.array([[3,4]])\n",
+"\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"# should select index 4\n",
+"ntc = NextTokenChooser(do_sample=False, repetition_penalty=1.0)\n",
+"logits = np.array([[-2,1,3,2,5]])\n",
+"input_ids = np.array([[3,4]])\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"# should select index 2\n",
+"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.0)\n",
+"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+"input_ids = np.array([[3,4]])\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"# should select index 4\n",
+"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+"input_ids = np.array([[2,4]])\n",
+"ntc = NextTokenChooser(do_sample=False, repetition_penalty=2.)\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"# should select index 3\n",
+"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+"input_ids = np.array([[2,4]])\n",
+"ntc = NextTokenChooser(do_sample=False, repetition_penalty=3.)\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"# should select index 1\n",
+"logits = np.array([[-2.,1.,3.,2.,5.]])\n",
+"input_ids = np.array([[2,3,4]])\n",
+"ntc = NextTokenChooser(do_sample=False, repetition_penalty=5.)\n",
+"print(ntc(input_ids,logits))\n",
+"\n",
+"\n",
+"# should make 2,4 less likely\n",
+"logits_og = np.array([[-1.,1.,3.,2.,4.]])\n",
+"input_ids = np.array([[2,4]])\n",
+"ntc = NextTokenChooser(do_sample=True, repetition_penalty=2.)\n",
+"\n",
+"original_ps = ntc.choice.softmax(logits_og,axis=1)\n",
+"penalty = np.array([[1.,1.,2.,1.,2.]])\n",
+"penalty_ps = ntc.choice.softmax(logits_og/penalty,axis=1)\n",
+"\n",
+"iters = 20000\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"\n",
+"for _ in range(iters):\n",
+"    logits = logits_og.copy()\n",
+"    # print(logits)\n",
+"    pred = ntc(input_ids, logits)\n",
+"    counts[pred] += 1\n",
+"\n",
+"for i in range(logits.shape[1]):\n",
+"    print(f\"sample / actual / original {i}: {counts[i] / iters: 0.3f} / {penalty_ps[0,i]: 0.3f} / {original_ps[0,i]: 0.3f}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 111,
+"id": "afd3139e-924c-4ac0-a450-c462447a610a",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"array([[-2. ,  1. ,  1.5,  2. ,  2.5]])"
+]
+},
+"execution_count": 111,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"logits / penalty"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "ea9f993e-12f8-4bd0-9f8c-885bf8e56a7b",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[[0.00446236 0.08962882 0.24363641 0.66227241]]\n",
+"{0: 0.0041, 1: 0.0909, 2: 0.2418, 3: 0.6632}\n"
+]
+}
+],
+"source": [
+"from utils import Sampling\n",
+"import numpy as np\n",
+"\n",
+"sampling = Sampling(seed=1)\n",
+"\n",
+"logits = np.array([[-2,1,2,3]])\n",
+"probs = sampling.softmax(logits, axis=1)\n",
+"\n",
+"print(probs)\n",
+"counts = {a: 0 for a in range(logits.shape[1])}\n",
+"iters = 10000\n",
+"for i in range(iters):\n",
+"    counts[sampling(logits=logits)] +=1\n",
+"\n",
+"for key in counts:\n",
+"    counts[key] /= iters\n",
+"\n",
+"print(counts)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 42,
@@ -141,7 +908,7 @@
 "    }\n",
 "    with requests.post(url, json=obj) as r:\n",
 "        print(max_new_tokens)\n",
-"        print(r.text)"
+"        print(r.text)\n"
 ]
 },
 {
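A few of the notebook's spot checks, worked through with the printed numbers:

- Top-p: softmax of [-1, 1, 4, 3, 2, 1] is roughly [0.004, 0.031, 0.621, 0.229, 0.084, 0.031]. Sorting the scores ascending and cumulatively summing gives [0.004, 0.035, 0.066, 0.150, 0.379, 1.0], so with top_p = 0.9 the mask cumulative_probs <= 0.1 drops the three smallest scores, and the warped row comes back as [[-inf -inf 4. 3. 2. -inf]], exactly as printed.
- Temperature: dividing logits by T flattens the distribution as T grows (T = 100 samples nearly uniformly, about 0.2 per token) and sharpens it as T shrinks (T = 0.001 collapses onto the argmax, matching greedy), with T = 3.0 and T = 0.5 landing in between.
- Repetition penalty: with penalty 2.0 and seen tokens [3, 4], the positive logits at those indices are halved, turning [-2, 1, 3, 2, 5] into [-2, 1, 3, 1, 2.5], so greedy decoding moves from index 4 to index 2, matching the printed 2; with penalty 3.0 and seen tokens [2, 4], the scores become [-2, 1, 1, 2, 5/3] and index 3 wins, matching the printed 3.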