Merge pull request #5 from rsnm2/stopping-next-token-chooser

stash
This commit is contained in:
Robert Shaw 2023-08-27 20:14:34 -06:00 committed by GitHub
commit 59d3688ce6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,7 +13,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 89, "execution_count": 91,
"id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e", "id": "7513b2a1-0749-44b5-88a9-d91d5b175e8e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -21,19 +21,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"45\n", "55\n",
"Finish the following function for computing a fibonacci sequence: \n", "{'response_text': '\\n if n == 0:\\n return 0\\n elif n == 1:\\n return 1\\n else:\\n return fib(n-1) + fib(n-2)\\n\\n# Driver function to test above function\\nn', 'finish_reason': 1}\n",
"\n",
"def fib(n):\n",
" if n<=1:\n",
" return n\n",
" else:\n",
" return fib(n-1)+fib(n-2)\n",
" \n",
"print(fib(15))\n",
"\n",
"\n",
"50\n",
"Finish the following function for computing a fibonacci sequence: \n", "Finish the following function for computing a fibonacci sequence: \n",
"\n", "\n",
"def fib(n):\n", "def fib(n):\n",
@ -42,23 +31,10 @@
" elif n == 1:\n", " elif n == 1:\n",
" return 1\n", " return 1\n",
" else:\n", " else:\n",
" return fib(n - 1) + fib(n - 2)\n", " return fib(n-1) + fib(n-2)\n",
"\n", "\n",
"n = int(\n", "# Driver function to test above function\n",
"55\n", "n\n"
"Finish the following function for computing a fibonacci sequence: \n",
"\n",
"def fib(n):\n",
" x = 1\n",
" y = 1\n",
" if n < 2:\n",
" return 1\n",
" else:\n",
" for i in range(2, n):\n",
" z = x + y\n",
" x = y\n",
" y = z\n",
"\n"
] ]
} }
], ],
@ -87,14 +63,17 @@
" with requests.post(url, json=obj) as r:\n", " with requests.post(url, json=obj) as r:\n",
" print(max_new_tokens)\n", " print(max_new_tokens)\n",
" dct = json.loads(r.text)\n", " dct = json.loads(r.text)\n",
" print(dct)\n",
" print(f'{sequence}{dct[\"response_text\"]}')\n", " print(f'{sequence}{dct[\"response_text\"]}')\n",
"\n", "\n",
"max_new_tokens_lst = [55, 50, 45]\n", "max_new_tokens_lst = [55, 50, 45]\n",
"seeds = [1,2,3]\n", "seeds = [1,2,3]\n",
"# max_new_tokens_lst = [100, 200, 300]\n", "# max_new_tokens_lst = [100, 200, 300]\n",
"\n", "\n",
"cnt = 1\n",
"request_ts = [\n", "request_ts = [\n",
" Thread(target=request_task, args=[seed, max_new_tokens]) for seed, max_new_tokens in zip(seeds, max_new_tokens_lst)\n", " Thread(target=request_task, args=[seed, max_new_tokens]) \n",
" for seed, max_new_tokens in zip(seeds[:cnt], max_new_tokens_lst[:cnt])\n",
"]\n", "]\n",
"\n", "\n",
"import time\n", "import time\n",