From 03fda99ee1457f1c811b2617b4af9f013fa8007a Mon Sep 17 00:00:00 2001
From: rsnm2
Date: Sun, 20 Aug 2023 13:50:18 +0000
Subject: [PATCH] updated to include internally managed kv-cache

---
 interaction.ipynb | 436 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 436 insertions(+)

diff --git a/interaction.ipynb b/interaction.ipynb
index dd5fc582..94c8b85b 100644
--- a/interaction.ipynb
+++ b/interaction.ipynb
@@ -1659,6 +1659,442 @@
    ],
    "source": []
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Internally Managed KV Cache\n",
+    "\n",
+    "This section repeats the generation experiment with the engine-internal kv-cache (`use_deepsparse_cache=True`)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import deepsparse\n",
+    "from deepsparse.transformers.utils.helpers import create_causal_mask"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2023-08-20 11:07:42 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`\n",
+      "Using pad_token, but it is not set yet.\n",
+      "2023-08-20 11:07:54 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
+      "2023-08-20 11:07:56 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
+      "DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230815 COMMUNITY | (134dba40) (release) (optimized) (system=avx2, binary=avx2)\n",
+      "2023-08-20 11:08:20 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
+     ]
+    }
+   ],
+   "source": [
+    "pipeline = deepsparse.Pipeline.create(\n",
+    "    task=\"text-generation\",\n",
+    "    model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
+    "    use_deepsparse_cache=True,\n",
+    "    prompt_processing_sequence_length=4,\n",
+    "    max_generated_tokens=64,\n",
+    "    sequence_length=128\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sequence = \"Finish the following function for computing a fibonacci sequence: \\n\\n    fib(n):\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TextGenerationOutput(sequences=['\\n\\n    if n == 0:\\n        return 0\\n    elif n == 1:\\n        return 1\\n    else:\\n        return fib(n-1) + fib(n-2)\\n\\n# Call the function.\\nprint(fib(5))\\n\\n# This code'], logits=None, session_id=None)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pipeline(sequences=sequence)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "singletoken_engine = pipeline.engine\n",
+    "multitoken_engine = pipeline.multitoken_engine\n",
+    "assert singletoken_engine.kv_cache == multitoken_engine.kv_cache\n",
+    "kv_cache = singletoken_engine.kv_cache"
+   ]
+  },
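+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick, hedged aside (not part of the original walkthrough): the engines expose\n",
+    "`engine.engine.input_names` and `engine.engine.input_shapes`, which the cells below\n",
+    "rely on. This unexecuted sketch just lists the `past_key_values.*` inputs each\n",
+    "engine expects."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sketch: list the cached-attention inputs and their shapes for each engine;\n",
+    "# uses only attributes that the cells below also reference\n",
+    "for label, engine in [(\"single-token\", singletoken_engine), (\"multi-token\", multitoken_engine)]:\n",
+    "    print(label)\n",
+    "    for idx, name in enumerate(engine.engine.input_names):\n",
+    "        if name.startswith(\"past_key_values\"):\n",
+    "            print(f\"  {name}: {engine.engine.input_shapes[idx]}\")"
+   ]
+  },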
"source": [ + "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 271, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "without maintaining\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + " fib(n):\n", + "\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n-1) + fib(n-2)\n", + "\n", + "# Call the function.\n", + "print(fib(5))\n", + "\n", + "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", + "<|endoftext|>\n", + "\n", + "\n", + "maintaining\n", + "Finish the following function for computing a fibonacci sequence: \n", + "\n", + " fib(n):\n", + "\n", + " if n == 0:\n", + " return 0\n", + " elif n == 1:\n", + " return 1\n", + " else:\n", + " return fib(n-1) + fib(n-2)\n", + "\n", + "# Call the function.\n", + "print(fib(5))\n", + "\n", + "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n", + "<|endoftext|>\n" + ] + } + ], + "source": [ + "import numpy\n", + "\n", + "multitoken_length = pipeline.prompt_processing_sequence_length\n", + "sequence_length = pipeline.sequence_length\n", + "\n", + "def empty_past_key_values(engine):\n", + " past_key_values = {}\n", + " for idx, name in enumerate(engine.engine.input_names):\n", + " if name.startswith(\"past_key_values\"):\n", + " shape = engine.engine.input_shapes[idx]\n", + " past_key_values[name] = numpy.zeros(shape, dtype=engine.kv_cache_data_type)\n", + "\n", + " return past_key_values\n", + "\n", + "def engine_inputs_for_decode(tokens):\n", + " assert(len(tokens) < sequence_length)\n", + " \n", + " engine_inputs = {}\n", + " engine_inputs[\"input_ids\"] = numpy.array([[tokens[-1]]])\n", + " engine_inputs[\"attention_mask\"] = numpy.zeros((1, sequence_length), dtype=numpy.int64)\n", + " engine_inputs[\"attention_mask\"][:, -len(tokens):] = 1\n", + " \n", + " engine_inputs[\"causal_mask\"] = create_causal_mask(\n", + " engine_inputs[\"input_ids\"],\n", + " engine_inputs[\"attention_mask\"]\n", + " )\n", + " engine_inputs[\"positions\"] = numpy.array([[len(tokens) - 1]], dtype=numpy.int64)\n", + " \n", + " return engine_inputs\n", + "\n", + "def engine_inputs_for_prefill(tokens):\n", + " num_batches = len(tokens) // multitoken_length\n", + " token_batches = [tokens[i * multitoken_length : (i+1) * multitoken_length] for i in range(0, num_batches)]\n", + "\n", + " for idx, token_batch in enumerate(token_batches):\n", + " num_processed_tokens = multitoken_length * idx\n", + " \n", + " engine_inputs = {}\n", + " engine_inputs[\"input_ids\"] = numpy.array([token_batch])\n", + "\n", + " # make attention mask from the right\n", + " engine_inputs[\"attention_mask\"] = numpy.zeros((1, sequence_length), dtype=numpy.int64)\n", + " engine_inputs[\"attention_mask\"][:, -(num_processed_tokens + multitoken_length):] = 1\n", + "\n", + " # make positions (building from the right)\n", + " assert multitoken_length > 1\n", + " engine_inputs[\"positions\"] = numpy.arange(\n", + " num_processed_tokens, num_processed_tokens + multitoken_length\n", + " ).reshape(1, -1).astype(numpy.int64)\n", + "\n", + " # make causal mask (building from the right)\n", + " engine_inputs[\"causal_mask\"] = create_causal_mask(\n", + " input_ids=engine_inputs[\"input_ids\"], \n", + " attention_mask=engine_inputs[\"attention_mask\"]\n", + " )\n", + "\n", + " yield engine_inputs\n", + "\n", + "def call_engine(engine, 
+    "def call_engine(engine, engine_inputs, past_key_values):\n",
+    "    # format the inputs as a list in the order the engine expects\n",
+    "    inputs = [\n",
+    "        past_key_values[name] if name.startswith(\"past_key_values\")\n",
+    "        else engine_inputs[name] for name in engine.engine.input_names\n",
+    "    ]\n",
+    "\n",
+    "    # run inference\n",
+    "    logits, *kvs = engine.engine._eng_net.execute_list_out(inputs, engine.kv_cache._kv_cache)\n",
+    "\n",
+    "    # format the output as a dict\n",
+    "    past_names = [name for name in engine.engine.input_names if name.startswith(\"past_key_values\")]\n",
+    "    past_key_values = {name: arr for name, arr in zip(past_names, kvs)}\n",
+    "\n",
+    "    return logits, past_key_values\n",
+    "\n",
+    "# BAD -- numpy.insert returns a full copy (the update does NOT happen in place)\n",
+    "def insert_past_key_values(past_key_values, num_items=1, padding_value=0):\n",
+    "    dtype = next(iter(past_key_values.values())).dtype\n",
+    "\n",
+    "    for name in past_key_values:\n",
+    "        padding_value = numpy.array(padding_value, dtype=dtype)\n",
+    "        past_key_values[name] = numpy.insert(past_key_values[name], [0]*num_items, padding_value, axis=2)\n",
+    "    return past_key_values\n",
+    "\n",
+    "# BAD -- numpy.ascontiguousarray copies the sliced cache\n",
+    "def slice_past_key_values(past_key_values, slice_idx):\n",
+    "    for name in past_key_values:\n",
+    "        past_key_values[name] = numpy.ascontiguousarray(past_key_values[name][:,:,slice_idx:,:])\n",
+    "    return past_key_values\n",
+    "\n",
+    "# maintains the kv-cache state at the pipeline level\n",
+    "def decode_maintain(tokens, past_key_values):\n",
+    "    engine_inputs = engine_inputs_for_decode(tokens)\n",
+    "\n",
+    "    logits, past_key_values = call_engine(\n",
+    "        singletoken_engine,\n",
+    "        engine_inputs=engine_inputs,\n",
+    "        past_key_values=past_key_values\n",
+    "    )\n",
+    "\n",
+    "    # clean up state (BAD -- calls numpy.ascontiguousarray)\n",
+    "    past_key_values = slice_past_key_values(past_key_values, 1)\n",
+    "\n",
+    "    assert logits.shape[0] == 1  # batch size 1 only for now\n",
+    "    assert logits.shape[1] == 1  # exactly one new token\n",
+    "    return logits, past_key_values\n",
+    "\n",
+    "# maintains the kv-cache state at the pipeline level\n",
+    "def prefill_maintain(tokens):\n",
+    "    tokens_processed = 0\n",
+    "    past_key_values = empty_past_key_values(multitoken_engine)\n",
+    "\n",
+    "    for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+    "        logits, past_key_values = call_engine(\n",
+    "            multitoken_engine,\n",
+    "            engine_inputs=engine_inputs,\n",
+    "            past_key_values=past_key_values\n",
+    "        )\n",
+    "        tokens_processed += multitoken_length\n",
+    "\n",
+    "        # BAD -- calls numpy.ascontiguousarray; trims because the engine returns a past of prior_seq_len + input_ids_len\n",
+    "        past_key_values = slice_past_key_values(past_key_values, multitoken_length)\n",
+    "\n",
+    "    # BAD -- returns a copy; expand the kv cache for the single-token engine\n",
+    "    past_key_values = insert_past_key_values(past_key_values, num_items=(multitoken_length-1))\n",
+    "\n",
+    "    # loop the single-token engine over any leftover tokens\n",
+    "    while tokens_processed < len(tokens):\n",
+    "        logits, past_key_values = decode_maintain(\n",
+    "            tokens=tokens[:tokens_processed+1],\n",
+    "            past_key_values=past_key_values\n",
+    "        )\n",
+    "        tokens_processed += 1\n",
+    "\n",
+    "    return logits, past_key_values\n",
+    "\n",
+    "empty_past_key_values_multi = empty_past_key_values(multitoken_engine)\n",
+    "empty_past_key_values_single = empty_past_key_values(singletoken_engine)\n",
+    "\n",
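+    "# hedged sanity check (silent, toy shapes; not in the original notebook):\n",
+    "# insert_past_key_values left-pads the sequence axis (axis=2) and allocates a\n",
+    "# fresh array rather than mutating in place -- the copy the comments flag as BAD\n",
+    "_toy = {\"past_key_values.0.key\": numpy.ones((1, 2, 3, 4), dtype=numpy.float32)}\n",
+    "_padded = insert_past_key_values(dict(_toy), num_items=2)\n",
+    "assert _padded[\"past_key_values.0.key\"].shape == (1, 2, 5, 4)\n",
+    "assert _padded[\"past_key_values.0.key\"][0, 0, 0, 0] == 0  # new slots are padding\n",
+    "assert _toy[\"past_key_values.0.key\"].shape == (1, 2, 3, 4)  # original untouched\n",
+    "\n",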
+    "# does not maintain the kv-cache state at the pipeline level\n",
+    "def decode(tokens):\n",
+    "    engine_inputs = engine_inputs_for_decode(tokens)\n",
+    "\n",
+    "    logits, past_key_values = call_engine(\n",
+    "        singletoken_engine,\n",
+    "        engine_inputs=engine_inputs,\n",
+    "        past_key_values=empty_past_key_values_single\n",
+    "    )\n",
+    "\n",
+    "    return logits\n",
+    "\n",
+    "# does not maintain the kv-cache state at the pipeline level\n",
+    "def prefill(tokens):\n",
+    "    tokens_processed = 0\n",
+    "\n",
+    "    for engine_inputs in engine_inputs_for_prefill(tokens):\n",
+    "        logits, _ = call_engine(\n",
+    "            multitoken_engine,\n",
+    "            engine_inputs=engine_inputs,\n",
+    "            past_key_values=empty_past_key_values_multi\n",
+    "        )\n",
+    "        tokens_processed += multitoken_length\n",
+    "\n",
+    "    # loop the single-token engine over any leftover tokens\n",
+    "    while tokens_processed < len(tokens):\n",
+    "        logits = decode(tokens[:tokens_processed+1])\n",
+    "        tokens_processed += 1\n",
+    "\n",
+    "    return logits\n",
+    "\n",
+    "def sample_token(logits):\n",
+    "    assert logits.shape[0] == 1\n",
+    "    return numpy.argmax(logits[0,-1,:])\n",
+    "\n",
+    "eos_token = pipeline.tokenizer.eos_token_id\n",
+    "\n",
+    "print(\"without maintaining\")\n",
+    "pipeline._reset_engines_cache()\n",
+    "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+    "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+    "\n",
+    "logits = prefill(tokens)\n",
+    "tokens.append(sample_token(logits))\n",
+    "while len(tokens) < sequence_length and tokens[-1] != eos_token:\n",
+    "    logits = decode(tokens)\n",
+    "    tokens.append(sample_token(logits))\n",
+    "\n",
+    "print(pipeline.tokenizer.decode(tokens))\n",
+    "\n",
+    "print(\"\\n\\nmaintaining\")\n",
+    "pipeline._reset_engines_cache()\n",
+    "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+    "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+    "\n",
+    "logits, past_key_values = prefill_maintain(tokens)\n",
+    "tokens.append(sample_token(logits))\n",
+    "while len(tokens) < sequence_length and tokens[-1] != eos_token:\n",
+    "    logits, past_key_values = decode_maintain(tokens, past_key_values)\n",
+    "    tokens.append(sample_token(logits))\n",
+    "\n",
+    "print(pipeline.tokenizer.decode(tokens))"
+   ]
+  },
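+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Aside (hedged, unexecuted; not part of the original walkthrough): greedy argmax\n",
+    "sampling is deterministic, so the maintaining and non-maintaining paths above\n",
+    "should agree token for token, as the two printouts confirm. This sketch makes the\n",
+    "check explicit for the first sampled token, resetting the engine cache before each\n",
+    "path just as the cell above does."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipeline._reset_engines_cache()\n",
+    "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+    "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+    "logits_no_maintain = prefill(tokens)\n",
+    "\n",
+    "pipeline._reset_engines_cache()\n",
+    "logits_maintain, _ = prefill_maintain(tokens)\n",
+    "\n",
+    "# both prefill paths should pick the same next token under greedy sampling\n",
+    "assert sample_token(logits_no_maintain) == sample_token(logits_maintain)"
+   ]
+  },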
+  {
+   "cell_type": "code",
+   "execution_count": 278,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finish the following function for computing a fibonacci sequence: \n",
+      "\n",
+      "    fib(n):\n",
+      "\n",
+      "    if n == 0:\n",
+      "        return 0\n",
+      "    elif n == 1:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        return fib(n-1) + fib(n-2)\n",
+      "\n",
+      "# Call the function.\n",
+      "print(fib(5))\n",
+      "\n",
+      "# This code is contributed by Nikhil Kumar Singh(nickzuck_007)\n",
+      "<|endoftext|>\n"
+     ]
+    }
+   ],
+   "source": [
+    "def prefill_pipeline(pipeline, tokens):\n",
+    "    num_tokens_processed = 0\n",
+    "    for engine_inputs in pipeline.engine_inputs_for_prefill(tokens):\n",
+    "        _, logits = pipeline.multitoken_engine(engine_inputs)\n",
+    "        num_tokens_processed += multitoken_length\n",
+    "\n",
+    "    if num_tokens_processed > 0:\n",
+    "        pipeline.engine.transfer_cache_state(cache=pipeline.multitoken_engine.kv_cache)\n",
+    "\n",
+    "    run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]\n",
+    "    for token in tokens[num_tokens_processed:]:\n",
+    "        run_tokens.append(token)\n",
+    "        new_token, logits = pipeline.autoregressive_inference(run_tokens)\n",
+    "    return logits\n",
+    "\n",
+    "pipeline._reset_engines_cache()\n",
+    "engine_inputs = pipeline.process_inputs(pipeline.parse_inputs(sequences=sequence))[0]\n",
+    "tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()\n",
+    "\n",
+    "logits = prefill_pipeline(pipeline, tokens)\n",
+    "tokens.append(sample_token(logits))\n",
+    "\n",
+    "while len(tokens) < pipeline.sequence_length and tokens[-1] != eos_token:\n",
+    "    _, logits = pipeline.autoregressive_inference(tokens)\n",
+    "    tokens.append(sample_token(logits))\n",
+    "\n",
+    "print(pipeline.tokenizer.decode(tokens))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 277,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finish the following function for computing a fibonacci sequence: \n",
+      "\n",
+      "    fib(n):\n",
+      "\n",
+      "    if n == 0:\n",
+      "        return 0\n",
+      "    elif n == 1:\n",
+      "        return 1\n",
+      "    else:\n",
+      "        return fib(n-1) + fib(n-2)\n",
+      "\n",
+      "# Call the function.\n",
+      "print(fib(5))\n",
+      "\n",
+      "# This code\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"{sequence}{pipeline(sequences=sequence).sequences[0]}\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,