implemented temperature, repetition penalty

This commit is contained in:
rsnm2 2023-08-25 16:11:09 +00:00
parent 96f8365996
commit a875c05ccd

View File

@@ -11,6 +11,92 @@
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "d8cd5290-a55b-44c0-ab2c-b34299a1da7c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.01 -0.01 0.01 -0.01 0.01 -0.01 0.01 -0.01]]\n",
"[[0 2 5 7]]\n",
"[[ 0.005 -0.01 0.005 -0.01 0.01 -0.02 0.01 -0.02 ]]\n",
"[[ 0.0025 -0.005 0.0025 -0.005 0.005 -0.01 0.005 -0.01 ]]\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"class RepetitionPenaltyLogitsProcessor:\n",
" def __init__(self, penalty: float):\n",
" if not isinstance(penalty, float) or not (penalty > 0):\n",
" raise ValueError(f\"`penalty` has to be a strictly positive float, but is {penalty}\")\n",
"\n",
" self.penalty = penalty\n",
"\n",
" def __call__(self, scores: np.ndarray, input_ids: np.ndarray) -> np.ndarray:\n",
" # assert shape is [1, vocab_size]\n",
" assert len(scores.shape) == 2\n",
" assert scores.shape[0] == 1\n",
"\n",
" # assert shape is [1, seq_len]\n",
" assert len(input_ids.shape) == 2\n",
" assert input_ids.shape[0] == 1\n",
" \n",
" # TODO: update logic to handle b > 1\n",
" score = scores[:, input_ids[0]]\n",
" score = np.where(score < 0, score * self.penalty, score / self.penalty)\n",
" scores[:, input_ids[0]] = score\n",
"\n",
" return scores\n",
"\n",
"class TemperatureLogitsWarper:\n",
" def __init__(self, temperature: float):\n",
" if not isinstance(temperature, float) or not (temperature > 0):\n",
" except_msg = (\n",
" f\"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token \"\n",
" \"scores will be invalid.\"\n",
" )\n",
" if isinstance(temperature, float) and temperature == 0.0:\n",
" except_msg += \" If you're looking for greedy decoding strategies, set `do_sample=False`.\"\n",
" raise ValueError(except_msg)\n",
" \n",
" self.temperature = temperature\n",
"\n",
" def __call__(self, scores: np.ndarray) -> np.ndarray:\n",
" # assert shape is [1, vocab_size]\n",
" assert len(scores.shape) == 2\n",
" assert scores.shape[0] == 1\n",
"\n",
" return scores / self.temperature\n",
"\n",
"input_ids = np.array([[0,2,5,7]])\n",
"logits = np.array([[0.01, -0.01]*4])\n",
"\n",
"print(logits)\n",
"print(input_ids)\n",
"\n",
"processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)\n",
"logits = processor(scores=logits, input_ids=input_ids)\n",
"print(logits)\n",
"\n",
"warper = TemperatureLogitsWarper(temperature=2.0)\n",
"logits = warper(scores=logits)\n",
"print(logits)"
]
},
{
"cell_type": "markdown",
"id": "e548af40-7b71-4f96-98b5-f33f03ef3f66",
"metadata": {},
"source": [
"# **Interacting with FastAPI Server**"
]
},
{
"cell_type": "code",
"execution_count": 17,