import numpy as np


def _require_2d(name: str, array: np.ndarray) -> None:
    """Raise ``ValueError`` unless *array* has exactly two dimensions."""
    if array.ndim != 2:
        raise ValueError(
            f"`{name}` must be 2-dimensional, but has shape {array.shape}"
        )


class RepetitionPenaltyLogitsProcessor:
    """Rescale the logits of tokens that already occur in ``input_ids``.

    For every token id present in a row of ``input_ids``, the matching logit
    in the same row of ``scores`` is divided by ``penalty`` when it is
    non-negative and multiplied by ``penalty`` when it is negative. With
    ``penalty > 1`` both branches lower the logit.
    """

    def __init__(self, penalty: float):
        """Store the penalty factor.

        Raises:
            ValueError: if ``penalty`` is not a ``float`` or is not
                strictly positive.
        """
        # `not (penalty > 0)` (rather than `penalty <= 0`) also rejects NaN.
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, scores: np.ndarray, input_ids: np.ndarray) -> np.ndarray:
        """Return a penalized copy of ``scores``.

        Args:
            scores: logits of shape ``[batch, vocab_size]``.
            input_ids: integer token ids of shape ``[batch, seq_len]``; row
                ``i`` selects the columns of ``scores[i]`` to penalize.

        Returns:
            A new array with the same shape and dtype as ``scores``. The
            input array is left untouched.

        Raises:
            ValueError: if either argument is not 2-D or the batch sizes
                differ.
        """
        # Explicit raises instead of `assert`: asserts are removed under -O.
        _require_2d("scores", scores)
        _require_2d("input_ids", input_ids)
        if scores.shape[0] != input_ids.shape[0]:
            raise ValueError(
                "`scores` and `input_ids` must share the same batch size, but got "
                f"{scores.shape[0]} and {input_ids.shape[0]}"
            )

        # Row-wise gather, so each batch row is penalized by its own ids.
        seen = np.take_along_axis(scores, input_ids, axis=1)
        penalized = np.where(seen < 0, seen * self.penalty, seen / self.penalty)

        # Write into a copy so the caller's logits are not modified. Duplicate
        # ids are harmless: they all scatter the same penalized value.
        result = scores.copy()
        np.put_along_axis(result, input_ids, penalized, axis=1)
        return result


class TemperatureLogitsWarper:
    """Divide logits by a fixed, strictly positive temperature."""

    def __init__(self, temperature: float):
        """Store the temperature.

        Raises:
            ValueError: if ``temperature`` is not a ``float`` or is not
                strictly positive. A value of exactly ``0.0`` gets an extra
                hint about greedy decoding in the message.
        """
        # `not (temperature > 0)` also rejects NaN.
        if not isinstance(temperature, float) or not (temperature > 0):
            except_msg = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            if isinstance(temperature, float) and temperature == 0.0:
                except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
            raise ValueError(except_msg)

        self.temperature = temperature

    def __call__(self, scores: np.ndarray) -> np.ndarray:
        """Return ``scores / temperature`` as a new array.

        Args:
            scores: logits of shape ``[batch, vocab_size]``.

        Raises:
            ValueError: if ``scores`` is not 2-D.
        """
        # Explicit raise instead of `assert`: asserts are removed under -O.
        _require_2d("scores", scores)
        return scores / self.temperature


# --- Demo: apply the repetition penalty, then the temperature -------------
input_ids = np.array([[0, 2, 5, 7]])
logits = np.array([[0.01, -0.01] * 4])

print(logits)
print(input_ids)

processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)
logits = processor(scores=logits, input_ids=input_ids)
print(logits)

warper = TemperatureLogitsWarper(temperature=2.0)
logits = warper(scores=logits)
print(logits)