mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
implemented temperature, repetition penalty
This commit is contained in:
parent
96f8365996
commit
a875c05ccd
@ -11,6 +11,92 @@
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "d8cd5290-a55b-44c0-ab2c-b34299a1da7c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[ 0.01 -0.01 0.01 -0.01 0.01 -0.01 0.01 -0.01]]\n",
|
||||
"[[0 2 5 7]]\n",
|
||||
"[[ 0.005 -0.01 0.005 -0.01 0.01 -0.02 0.01 -0.02 ]]\n",
|
||||
"[[ 0.0025 -0.005 0.0025 -0.005 0.005 -0.01 0.005 -0.01 ]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"class RepetitionPenaltyLogitsProcessor:\n",
|
||||
" def __init__(self, penalty: float):\n",
|
||||
" if not isinstance(penalty, float) or not (penalty > 0):\n",
|
||||
" raise ValueError(f\"`penalty` has to be a strictly positive float, but is {penalty}\")\n",
|
||||
"\n",
|
||||
" self.penalty = penalty\n",
|
||||
"\n",
|
||||
" def __call__(self, scores: np.ndarray, input_ids: np.ndarray) -> np.ndarray:\n",
|
||||
" # assert shape is [1, vocab_size]\n",
|
||||
" assert len(scores.shape) == 2\n",
|
||||
" assert scores.shape[0] == 1\n",
|
||||
"\n",
|
||||
" # assert shape is [1, seq_len]\n",
|
||||
" assert len(input_ids.shape) == 2\n",
|
||||
" assert input_ids.shape[0] == 1\n",
|
||||
" \n",
|
||||
" # TODO: update logic to handle b > 1\n",
|
||||
" score = scores[:, input_ids[0]]\n",
|
||||
" score = np.where(score < 0, score * self.penalty, score / self.penalty)\n",
|
||||
" scores[:, input_ids[0]] = score\n",
|
||||
"\n",
|
||||
" return scores\n",
|
||||
"\n",
|
||||
"class TemperatureLogitsWarper:\n",
|
||||
" def __init__(self, temperature: float):\n",
|
||||
" if not isinstance(temperature, float) or not (temperature > 0):\n",
|
||||
" except_msg = (\n",
|
||||
" f\"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token \"\n",
|
||||
" \"scores will be invalid.\"\n",
|
||||
" )\n",
|
||||
" if isinstance(temperature, float) and temperature == 0.0:\n",
|
||||
" except_msg += \" If you're looking for greedy decoding strategies, set `do_sample=False`.\"\n",
|
||||
" raise ValueError(except_msg)\n",
|
||||
" \n",
|
||||
" self.temperature = temperature\n",
|
||||
"\n",
|
||||
" def __call__(self, scores: np.ndarray) -> np.ndarray:\n",
|
||||
" # assert shape is [1, vocab_size]\n",
|
||||
" assert len(scores.shape) == 2\n",
|
||||
" assert scores.shape[0] == 1\n",
|
||||
"\n",
|
||||
" return scores / self.temperature\n",
|
||||
"\n",
|
||||
"input_ids = np.array([[0,2,5,7]])\n",
|
||||
"logits = np.array([[0.01, -0.01]*4])\n",
|
||||
"\n",
|
||||
"print(logits)\n",
|
||||
"print(input_ids)\n",
|
||||
"\n",
|
||||
"processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)\n",
|
||||
"logits = processor(scores=logits, input_ids=input_ids)\n",
|
||||
"print(logits)\n",
|
||||
"\n",
|
||||
"warper = TemperatureLogitsWarper(temperature=2.0)\n",
|
||||
"logits = warper(scores=logits)\n",
|
||||
"print(logits)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e548af40-7b71-4f96-98b5-f33f03ef3f66",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# **Interacting with FastAPI Server**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
|
Loading…
Reference in New Issue
Block a user