Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
finished deepsparse_model.py implementation
This commit is contained in:
parent 03fda99ee1
commit f8565cd915
@@ -2095,6 +2095,33 @@
    "print(f\"{sequence}{pipeline(sequences=sequence).sequences[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n",
      "2023-08-20 13:44:57 deepsparse.transformers.pipelines.text_generation INFO Compiling an auxiliary engine to process a prompt with a larger processing length. This improves performance, but may result in additional memory consumption.\n",
      "2023-08-20 13:44:58 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n",
      "2023-08-20 13:45:23 deepsparse.transformers.utils.helpers INFO Overwriting in-place the input shapes of the transformer model at /home/robertgshaw/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx/model.onnx\n"
     ]
    }
   ],
   "source": [
    "pipeline2 = deepsparse.Pipeline.create(\n",
    "    task=\"text-generation\", \n",
    "    model_path=\"zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none\",\n",
    "    use_deepsparse_cache=False,\n",
    "    prompt_processing_sequence_length=4,\n",
    "    max_generated_tokens=64,\n",
    "    sequence_length=128\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
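For reference, pipeline2 would be invoked the same way as the pipeline in the earlier cell, e.g. (a minimal sketch; sequence is assumed to be a prompt string defined earlier in the notebook):

sequence = "def fib(n):"  # placeholder prompt; the notebook defines its own
print(f"{sequence}{pipeline2(sequences=sequence).sequences[0]}")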
server/text_generation_server/models/deepsparse_causal_lm.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import numpy, torch
from dataclasses import dataclass
from typing import Optional, Tuple, List, Type, Dict

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)

from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

# DeepSparsePastKeyValues is defined in deepsparse_model.py, added in this same commit
from text_generation_server.models.deepsparse_model import DeepSparsePastKeyValues

DEEPSPARSE_SEQUENCE_LENGTH = 128


@dataclass
class DeepSparseCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]

    # TODO: update to handle calculating max_tokens --- needed for CachedBatch
    # (to_pb below already references self.max_tokens)

    # Decoder values
    input_ids_list: List[numpy.ndarray]
    past_key_values_list: Optional[List[DeepSparsePastKeyValues]]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "DeepSparseCausalLMBatch":
        # parse batch
        input_ids_list = []
        next_token_choosers = []
        stopping_criterias = []
        requests_idx_mapping = {}

        # set up tokenizer for deepsparse left padding
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        padding, truncation = "longest", False

        # loop through items in the batch
        for i, r in enumerate(pb.requests):
            # get mapping
            requests_idx_mapping[r.id] = i

            # set up inputs
            tokenized_inputs = tokenizer(
                r.inputs,
                return_tensors="np",
                padding=padding,
                truncation=truncation,
                return_token_type_ids=False,
                max_length=DEEPSPARSE_SEQUENCE_LENGTH,
            )
            input_ids_list.append(tokenized_inputs["input_ids"])

            # set up sequence generation helpers,
            # capping stopping parameters at the DeepSparse sequence length
            input_len = tokenized_inputs["input_ids"].shape[1]
            assert DEEPSPARSE_SEQUENCE_LENGTH - input_len > 0
            r.stopping_parameters.max_new_tokens = min(
                r.stopping_parameters.max_new_tokens,
                DEEPSPARSE_SEQUENCE_LENGTH - input_len,
            )
            stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            input_ids_list=input_ids_list,
            past_key_values_list=None,
        )

    def __len__(self):
        return len(self.requests)

    def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
        pass

    @classmethod
    def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
        pass


class DeepSparseCausalLM:
    def __init__(
        self,
        deployment_path: str,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(deployment_path)
        self.tokenizer.padding_side = "left"
        if not self.tokenizer.pad_token:
            assert self.tokenizer.eos_token
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @property
    def batch_type(self) -> Type[DeepSparseCausalLMBatch]:
        return DeepSparseCausalLMBatch
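A minimal usage sketch for the batch class above, assuming a tokenizer loaded from a deployment directory; the path is a placeholder, and the request list is left empty since real requests are populated elsewhere in the server:

import torch
from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.deepsparse_causal_lm import DeepSparseCausalLMBatch

tokenizer = AutoTokenizer.from_pretrained("/path/to/deployment")  # placeholder path
# each generate_pb2.Request carries id, inputs, parameters, stopping_parameters (see from_pb)
pb_batch = generate_pb2.Batch(id=0, requests=[])
batch = DeepSparseCausalLMBatch.from_pb(
    pb_batch, tokenizer, dtype=torch.float32, device=torch.device("cpu")
)
assert len(batch) == 0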
server/text_generation_server/models/deepsparse_model.py (new file, 212 lines)
@@ -0,0 +1,212 @@
import numpy as np
from typing import Optional, Tuple, List, Dict

from deepsparse import Context
from deepsparse.engine import LIB
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs, create_causal_mask

PAST_KEY_VALUES_NAME = "past_key_values"


class DeepSparsePastKeyValues:
    def __init__(self):
        prev_num_tokens = 0
        num_frozen_tokens = 1
        self.internal_past_key_values = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)


class DeepSparseDecoderEngine:
    def __init__(
        self,
        onnx_file_path: str,
        sequence_length: int = 1024,
        input_ids_length: int = 1,
        engine_context: Optional[Context] = None,
    ):
        # update ONNX graph
        onnx_file_path, cached_outputs, data_type = overwrite_onnx_model_inputs(
            onnx_file_path=onnx_file_path,
            batch_size=1,
            sequence_length=sequence_length,
            input_ids_length=input_ids_length,
        )

        # compile engine
        self.engine = create_engine(
            onnx_file_path=onnx_file_path,
            engine_type=DEEPSPARSE_ENGINE,
            engine_args={"cached_outputs": cached_outputs},
            context=engine_context,
        )
        print(self.engine)

        # save utilities
        self.past_key_value_dtype = data_type
        self.onnx_inputs = self.engine.input_names
        self.empty_past_key_values = self.make_empty_past_key_values()

    # forward function
    def __call__(
        self,
        engine_inputs: Dict[str, np.ndarray],
        past_key_values: DeepSparsePastKeyValues,
        val_inputs: bool = True,
    ):
        # format inputs into a list (we pass empty past key values)
        inputs = [
            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
            else engine_inputs[name] for name in self.engine.input_names
        ]

        # validate inputs formatted correctly
        if val_inputs:
            self.engine._validate_inputs(inputs)

        # run inference, updates past_key_values internally
        output = self.engine._eng_net.execute_list_out(
            inputs,
            past_key_values.internal_past_key_values
        )
        logits = output[0]
        return logits, past_key_values

    # empty past kvs (dummy values to be passed around)
    def make_empty_past_key_values(self):
        past_key_values = {}
        for idx, name in enumerate(self.onnx_inputs):
            if name.startswith(PAST_KEY_VALUES_NAME):
                past_key_values[name] = np.zeros(
                    self.engine.input_shapes[idx],
                    dtype=self.past_key_value_dtype
                )

        return past_key_values


class DeepSparseDecoderModel:
    def __init__(
        self,
        onnx_file_path: str,
        sequence_length: int = 1024,
        multitoken_length: int = 16,
        engine_context: Optional[Context] = None,
        singletoken_engine=None,
        multitoken_engine=None,
    ):
        self.sequence_length = sequence_length
        self.multitoken_length = multitoken_length

        if singletoken_engine is not None and multitoken_engine is not None:
            self.singletoken_engine = singletoken_engine
            self.multitoken_engine = multitoken_engine
        else:
            self.singletoken_engine = DeepSparseDecoderEngine(
                onnx_file_path=onnx_file_path,
                engine_context=engine_context,
                sequence_length=sequence_length,
                input_ids_length=1,
            )
            self.multitoken_engine = DeepSparseDecoderEngine(
                onnx_file_path=onnx_file_path,
                engine_context=engine_context,
                sequence_length=sequence_length,
                input_ids_length=self.multitoken_length,
            )

        assert "input_ids" in self.singletoken_engine.onnx_inputs
        assert "attention_mask" in self.singletoken_engine.onnx_inputs
        assert "causal_mask" in self.singletoken_engine.onnx_inputs
        assert "positions" in self.singletoken_engine.onnx_inputs

    def engine_inputs_for_prefill(
        self,
        tokens: List[int],
    ):
        assert len(tokens) < self.sequence_length

        # split the prompt into N batches of multitoken_length tokens
        num_batches = len(tokens) // self.multitoken_length
        token_batches = [
            tokens[i * self.multitoken_length : (i + 1) * self.multitoken_length]
            for i in range(num_batches)
        ]

        # format inputs for each of the N token_batches
        for idx, token_batch in enumerate(token_batches):
            num_processed_tokens = self.multitoken_length * idx

            engine_inputs = {}
            engine_inputs["input_ids"] = np.array([token_batch])

            # make attention mask from the right
            engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
            engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1

            # make positions (building from the right)
            # TODO: handle the case when the multitoken engine's input length is 1
            assert self.multitoken_length > 1
            engine_inputs["positions"] = np.arange(
                num_processed_tokens, num_processed_tokens + self.multitoken_length
            ).reshape(1, -1).astype(np.int64)

            # make causal mask (building from the right)
            engine_inputs["causal_mask"] = create_causal_mask(
                input_ids=engine_inputs["input_ids"],
                attention_mask=engine_inputs["attention_mask"]
            )
            yield engine_inputs

    def engine_inputs_for_decode(
        self,
        tokens: List[int],
    ):
        assert len(tokens) < self.sequence_length

        engine_inputs = {}
        engine_inputs["input_ids"] = np.array([[tokens[-1]]])
        engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
        engine_inputs["attention_mask"][:, -len(tokens):] = 1

        engine_inputs["causal_mask"] = create_causal_mask(
            engine_inputs["input_ids"],
            engine_inputs["attention_mask"]
        )
        engine_inputs["positions"] = np.array([[len(tokens) - 1]], dtype=np.int64)

        return engine_inputs

    def decode(
        self,
        tokens: List[int],
        past_key_values: DeepSparsePastKeyValues,
    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
        engine_inputs = self.engine_inputs_for_decode(tokens)
        logits, past_key_values = self.singletoken_engine(
            engine_inputs,
            past_key_values
        )
        return logits, past_key_values

    def prefill(
        self,
        tokens: List[int],
        past_key_values: DeepSparsePastKeyValues,
    ) -> Tuple[np.ndarray, DeepSparsePastKeyValues]:
        tokens_processed = 0

        # run the multitoken engine over full batches of the prompt; keep the
        # logits so they are defined even when the prompt length is an exact
        # multiple of multitoken_length and the while loop below never runs
        logits = None
        for engine_inputs in self.engine_inputs_for_prefill(tokens):
            logits, past_key_values = self.multitoken_engine(
                engine_inputs,
                past_key_values
            )
            tokens_processed += self.multitoken_length

        # process the remainder one token at a time with the singletoken engine
        while tokens_processed < len(tokens):
            logits, past_key_values = self.decode(
                tokens=tokens[:tokens_processed + 1],
                past_key_values=past_key_values
            )
            tokens_processed += 1

        return logits, past_key_values
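To make the prefill/decode split concrete, here is a hypothetical greedy generation loop over DeepSparseDecoderModel (not part of the commit; the model path and prompt token ids are placeholders, and argmax stands in for the server's NextTokenChooser):

import numpy as np

model = DeepSparseDecoderModel(onnx_file_path="/path/to/model.onnx")  # placeholder path
tokens = [1, 2, 3, 4, 5]  # prompt token ids produced by a tokenizer
past_key_values = DeepSparsePastKeyValues()

# prefill consumes the prompt (multitoken batches + singletoken remainder)
logits, past_key_values = model.prefill(tokens, past_key_values)

for _ in range(16):
    tokens.append(int(np.argmax(logits[0, -1, :])))  # greedy next-token choice
    logits, past_key_values = model.decode(tokens, past_key_values)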