diff --git a/server/text_generation_server/models/deepsparse_causal_lm.py b/server/text_generation_server/models/deepsparse_causal_lm.py
index 1037263b..1e88b2fe 100644
--- a/server/text_generation_server/models/deepsparse_causal_lm.py
+++ b/server/text_generation_server/models/deepsparse_causal_lm.py
@@ -1,61 +1,37 @@
-import numpy, torch
-from dataclasses import dataclass
-from typing import Optional, Tuple, List, Type, Dict
-
+import numpy as np
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from text_generation_server.models import Model
-from text_generation_server.models.types import (
-    Batch,
-    PrefillTokens,
-    Generation,
-    GeneratedText,
-)
+from dataclasses import dataclass
+from typing import List, Dict, Optional, Type
+from text_generation_server.models.deepsparse_model import (
+    DeepSparsePastKeyValues,
+    DeepSparseDecoderModel
+)
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 DEEPSPARSE_SEQUENCE_LENGTH = 128
+DEEPSPARSE_MULTITOKEN_LENGTH = 4
 
 @dataclass
-class DeepSparseCausalLMBatch(Batch):
+class DeepSparseCausalLMBatch:
     batch_id: int
     requests: List[generate_pb2.Request]
-    requests_idx_mapping: Dict[int, int]
-
-    # TODO: update to handle calculating max_tokens --- needed for CachedBatch
-
-    # Decoder values
-    input_ids_list: List[numpy.ndarray]
-    past_key_values_list: Optional[List[DeepSparsePastKeyValues]]
-
-    # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
-    stopping_criterias: List[StoppingCriteria]
-
-    def to_pb(self) -> generate_pb2.CachedBatch:
-        return generate_pb2.CachedBatch(
-            id=self.batch_id,
-            request_ids=[r.id for r in self.requests],
-            size=len(self),
-            max_tokens=self.max_tokens,
-        )
+    requests_idx_mapping: Dict[int, int]
+    input_ids_list: List[np.ndarray]
+    past_key_values_list: List[Optional[DeepSparsePastKeyValues]]
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
     ) -> "DeepSparseCausalLMBatch":
 
         # parse batch
-        input_ids_list = []
-        next_token_choosers = []
-        stopping_criterias = []
         requests_idx_mapping = {}
-
+        input_ids_list = []
+
         # setup tokenizer for deepsparse left padding
         tokenizer.padding_side = "left"
         if not tokenizer.pad_token:
@@ -63,11 +39,10 @@ class DeepSparseCausalLMBatch(Batch):
         padding, truncation = "longest", False
 
         # loop through items in the batch
-        for i, r in enumerate(pb.requests):
-            # get mapping
-            requests_idx_mapping[r.id] = i
+        for idx, r in enumerate(pb.requests):
+            requests_idx_mapping[r.id] = idx
 
-            # setup inputs
+            # setup inputs_ids, past_key_values
            tokenized_inputs = tokenizer(
                 r.inputs,
                 return_tensors="np",
@@ -77,51 +52,189 @@ class DeepSparseCausalLMBatch(Batch):
                 max_length=DEEPSPARSE_SEQUENCE_LENGTH
             )
             input_ids_list.append(tokenized_inputs["input_ids"])
-
-            # setup sequence generation helpers, capping at DEEPSPARSE_SEQUENCE_LENGTH
-            # cap stopping parameters to DeepSparse sequence length
-            input_len = tokenized_inputs["input_ids"].shape[1]
-            assert DEEPSPARSE_SEQUENCE_LENGTH - input_len > 0
-            r.stopping_parameters.max_new_tokens = min(
-                r.stopping_parameters.max_new_tokens,
-                DEEPSPARSE_SEQUENCE_LENGTH - input_len
-            )
-            stopping_criterias.append(StoppingCriteria.from_pb(r.stopping_parameters, tokenizer))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
             input_ids_list=input_ids_list,
-            past_key_values_list=None
+            past_key_values_list=[None] * len(pb.requests),
         )
-
+
+    # length of the batch
     def __len__(self):
         return len(self.requests)
 
+    # pass list of request ids, returns batch with only those request ids
     def filter(self, request_ids: List[int]) -> Optional["DeepSparseCausalLMBatch"]:
-        pass
+        assert len(request_ids) > 0
+        requests_idx_mapping = {}
+        requests = []
+        input_ids_list = []
+        past_key_values_list = []
+
+        # loop through requests, keep ones that should remain
+        for new_idx, request_id in enumerate(request_ids):
+            assert request_id in self.requests_idx_mapping.keys(), "all request ids must be in the batch"
+
+            requests_idx_mapping[request_id] = new_idx
+
+            old_idx = self.requests_idx_mapping[request_id]
+            requests.append(self.requests[old_idx])
+            input_ids_list.append(self.input_ids_list[old_idx])
+            past_key_values_list.append(self.past_key_values_list[old_idx])
+
+        # update batch state
+        self.requests = requests
+        self.requests_idx_mapping = requests_idx_mapping
+        self.input_ids_list = input_ids_list
+        self.past_key_values_list = past_key_values_list
+
+        return self
+
+    # combine two batches into one
+    @classmethod
     def concatenate(cls, batches: List["DeepSparseCausalLMBatch"]) -> "DeepSparseCausalLMBatch":
-        pass
+        assert len(batches) > 1, "must have more than 1 batch to concatenate"
+        requests_idx_mapping = {}
+        requests = []
+        input_ids_list = []
+        past_key_values_list = []
+
+        start_index = 0
+        for i, batch in enumerate(batches):
+            assert batch.past_key_values_list is not None, "only concatenate prefilled batches"
+
+            # concatenate request, input_ids, and past_key_values lists
+            requests.extend(batch.requests)
+            input_ids_list.extend(batch.input_ids_list)
+            past_key_values_list.extend(batch.past_key_values_list)
+
+            # merge the request_id to index mapping
+            if i == 0:
+                requests_idx_mapping = batch.requests_idx_mapping
+            else:
+                for k, v in batch.requests_idx_mapping.items():
+                    requests_idx_mapping[k] = v + start_index
+
+            start_index += len(batch)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids_list=input_ids_list,
+            past_key_values_list=past_key_values_list
+        )
 
 class DeepSparseCausalLM:
     def __init__(
         self,
-        deployment_path: str
-    ):
-        self.tokenizer = AutoTokenizer.from_pretrained(deployment_path)
+        model_path: str,
+        tokenizer_path: str,
+    ):
+        # setup tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.tokenizer.padding_side = "left"
         if not self.tokenizer.pad_token:
             assert self.tokenizer.eos_token
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
+
+        # setup model
+        self.model = DeepSparseDecoderModel(
+            onnx_file_path = model_path,
+            sequence_length = DEEPSPARSE_SEQUENCE_LENGTH,
+            multitoken_length = DEEPSPARSE_MULTITOKEN_LENGTH,
+        )
 
-    @property
-    def batch_type(self) -> Type[DeepSparseCausalLMBatch]:
-        return DeepSparseCausalLMBatch
\ No newline at end of file
+    # TODO (@rsnm2): switch to NextTokenChooser
+    def sample_token(
+        self,
+        logits: np.ndarray
+    ):
+        assert logits.shape[0] == 1           # assert b=1 for now
+        return np.argmax(logits[0, -1, :])    # grab logits for the last item in the sequence
+
+    # TODO (@rsnm2): switch to StoppingCriteria
+    def should_stop(
+        self,
+        num_tokens_processed: int,
+        generated_token_id: int
+    ):
+        if num_tokens_processed >= self.model.sequence_length:
+            return True
+        if generated_token_id == self.tokenizer.eos_token_id:
+            return True
+        return False
+
+    def generate_token(
+        self,
+        batch: DeepSparseCausalLMBatch,
+    ) -> (Dict[int, str], Optional[DeepSparseCausalLMBatch]):
+
+        generations: Dict[int, str] = {}
+        all_stopped = True
+
+        # if we supported continuous batching, we would do batched inference here
+        # logits, past_key_values = self.model(batch)
+
+        # for each member of the batch:
+        #   a) run inference
+        #   b) sample and check stopping criteria
+        #   c) create generation + update batch
+        for i, (
+            request,
+            input_ids,
+            past_key_values,
+        ) in enumerate(zip(
+            batch.requests,
+            batch.input_ids_list,
+            batch.past_key_values_list
+        )):
+
+            # run inference
+            logits, past_key_values = self.model(input_ids, past_key_values)
+
+            # sample token
+            # simple for now --- should use NextTokenChooser
+            generated_token_id = self.sample_token(logits)
+
+            # check stopping criteria
+            # simple for now --- should use StoppingCriteria
+            stop = self.should_stop(
+                num_tokens_processed=input_ids.shape[1] + 1,
+                generated_token_id=generated_token_id
+            )
+
+            # if not stopped, convert token id to text
+            generated_text = None
+            if not stop:
+                all_stopped = False
+                generated_text = self.tokenizer.decode(
                    generated_token_id,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=False
+                )
+            generations[request.id] = generated_text
+
+            # update values in the batch
+            assert len(batch.input_ids_list[i].shape) == 2
+            assert batch.input_ids_list[i].shape[0] == 1
+
+            # bad --- this does not occur in place
+            # print(batch.input_ids_list[i])
+            batch.input_ids_list[i] = np.append(
+                batch.input_ids_list[i],
+                np.array([[generated_token_id]]),
+                axis=1
+            )
+            batch.past_key_values_list[i] = past_key_values
+
+        # if all elements of the batch are done, return generation + null for batch
+        if all_stopped:
+            return generations, None
+
+        # return generation + updated batch
+        return generations, batch
\ No newline at end of file
diff --git a/server/text_generation_server/models/deepsparse_model.py b/server/text_generation_server/models/deepsparse_model.py
index 03f73bad..cdd0fb52 100644
--- a/server/text_generation_server/models/deepsparse_model.py
+++ b/server/text_generation_server/models/deepsparse_model.py
@@ -1,3 +1,7 @@
+import os
+
+os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"
+
 import numpy as np
 
 from typing import Optional, List, Dict
@@ -53,8 +57,10 @@ class DeepSparseDecoderEngine:
         val_inputs: bool = True
     ):
         # format input into lists (we pass empty past key values)
-        inputs = [self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
-                  else engine_inputs[name] for name in self.engine.input_names]
+        inputs = [
+            self.empty_past_key_values[name] if name.startswith(PAST_KEY_VALUES_NAME)
+            else engine_inputs[name] for name in self.engine.input_names
+        ]
 
         # validate inputs formatted correctly
         if val_inputs:
@@ -87,29 +93,25 @@ class DeepSparseDecoderModel:
         sequence_length: int = 1024,
         multitoken_length: int = 16,
         engine_context: Optional[Context] = None,
-        singletoken_engine = None,
-        multitoken_engine = None,
     ):
         self.sequence_length = sequence_length
         self.multitoken_length = multitoken_length
 
-        if singletoken_engine is not None and multitoken_engine is not None:
-            self.singletoken_engine = singletoken_engine
-            self.multitoken_engine = multitoken_engine
-        else:
-            self.singletoken_engine = DeepSparseDecoderEngine(
-                onnx_file_path=onnx_file_path,
-                engine_context=engine_context,
-                sequence_length=sequence_length,
-                input_ids_length=1,
-            )
-
-            self.multitoken_engine = DeepSparseDecoderEngine(
-                onnx_file_path=onnx_file_path,
-                engine_context=engine_context,
-                sequence_length=sequence_length,
-                input_ids_length=self.multitoken_length,
-            )
+        # compile decode engine
+        self.singletoken_engine = DeepSparseDecoderEngine(
+            onnx_file_path=onnx_file_path,
+            engine_context=engine_context,
+            sequence_length=sequence_length,
+            input_ids_length=1,
+        )
+
+        # compile prefill engine
+        self.multitoken_engine = DeepSparseDecoderEngine(
+            onnx_file_path=onnx_file_path,
+            engine_context=engine_context,
+            sequence_length=sequence_length,
+            input_ids_length=self.multitoken_length,
+        )
 
         assert "input_ids" in self.singletoken_engine.onnx_inputs
         assert "attention_mask" in self.singletoken_engine.onnx_inputs
@@ -118,14 +120,12 @@ class DeepSparseDecoderModel:
 
     def engine_inputs_for_prefill(
         self,
-        tokens: List[int],
+        input_ids: np.ndarray,
     ):
-        assert len(tokens) < self.sequence_length
-
         # split batch into N token_batches
-        num_batches = len(tokens) // self.multitoken_length
+        num_batches = input_ids.shape[1] // self.multitoken_length
         token_batches = [
-            tokens[i * self.multitoken_length : (i+1) * self.multitoken_length]
+            input_ids[:, i*self.multitoken_length : (i+1)*self.multitoken_length]
             for i in range(0, num_batches)
         ]
 
@@ -133,9 +133,9 @@ class DeepSparseDecoderModel:
         for idx, token_batch in enumerate(token_batches):
             num_processed_tokens = self.multitoken_length * idx
 
-            engine_inputs = {}
-            engine_inputs["input_ids"] = np.array([token_batch])
-
+            engine_inputs = {}
+            engine_inputs["input_ids"] = token_batch
+
+            # make attention mask from the right
             engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
             engine_inputs["attention_mask"][:, -(self.multitoken_length + num_processed_tokens):] = 1
 
@@ -156,30 +156,33 @@ class DeepSparseDecoderModel:
 
     def engine_inputs_for_decode(
        self,
-        tokens: List[int],
+        input_ids: np.ndarray,
     ):
-        assert len(tokens) < self.sequence_length
-
         engine_inputs = {}
-        engine_inputs["input_ids"] = np.array([[tokens[-1]]])
+        engine_inputs["input_ids"] = input_ids[:,-1:]
         engine_inputs["attention_mask"] = np.zeros((1, self.sequence_length), dtype=np.int64)
-        engine_inputs["attention_mask"][:, -len(tokens):] = 1
-
+        engine_inputs["attention_mask"][:, -input_ids.shape[1]:] = 1
+
         engine_inputs["causal_mask"] = create_causal_mask(
             engine_inputs["input_ids"],
             engine_inputs["attention_mask"]
         )
-        engine_inputs["positions"] = np.array([[len(tokens) - 1]], dtype=np.int64)
+        engine_inputs["positions"] = np.array([[input_ids.shape[1] - 1]], dtype=np.int64)
 
         return engine_inputs
 
     def decode(
         self,
-        tokens: List[int],
+        input_ids: np.ndarray,
         past_key_values: DeepSparsePastKeyValues
     ) -> (np.ndarray, DeepSparsePastKeyValues):
 
-        engine_inputs = self.engine_inputs_for_decode(tokens)
+        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
+        assert input_ids.shape[1] < self.sequence_length
+
+        engine_inputs = self.engine_inputs_for_decode(input_ids)
 
         logits, past_key_values = self.singletoken_engine(
             engine_inputs,
             past_key_values
@@ -189,24 +192,51 @@ class DeepSparseDecoderModel:
 
     def prefill(
         self,
-        tokens: List[int],
-        past_key_values: DeepSparsePastKeyValues
+        input_ids: np.ndarray,
     ) -> (np.ndarray, DeepSparsePastKeyValues):
-
+
+        # assert input is of shape [1,seq_len] w/ seq_len < self.sequence_len
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == 1
+        assert input_ids.shape[1] < self.sequence_length
+
         tokens_processed = 0
+
+        # setup empty past key values
+        past_key_values = DeepSparsePastKeyValues()
 
-        for engine_inputs in self.engine_inputs_for_prefill(tokens):
-            _, past_key_values = self.multitoken_engine(
+        # loop through chunks, run inference w/ multitoken engine
+        for engine_inputs in self.engine_inputs_for_prefill(input_ids):
+            logits, past_key_values = self.multitoken_engine(
                 engine_inputs,
                 past_key_values
             )
             tokens_processed += self.multitoken_length
 
-        while tokens_processed < len(tokens):
+        # if anything left over, run inference w/ singletoken engine
+        while tokens_processed < input_ids.shape[1]:
             logits, past_key_values = self.decode(
-                tokens= tokens[:tokens_processed + 1],
+                input_ids=input_ids[:,:tokens_processed+1],
                 past_key_values=past_key_values
             )
             tokens_processed += 1
+        # print(logits[:,-1:,:])
 
-        return logits, past_key_values
\ No newline at end of file
+        return logits, past_key_values
+
+    def forward(
+        self,
+        input_ids: np.ndarray,
+        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+    ):
+        if past_key_values is None:
+            return self.prefill(input_ids)
+        else:
+            return self.decode(input_ids, past_key_values)
+
+    def __call__(
+        self,
+        input_ids: np.ndarray,
+        past_key_values: Optional[DeepSparsePastKeyValues] = None,
+    ):
+        return self.forward(input_ids, past_key_values)
\ No newline at end of file
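
Note (not part of the patch): the sketch below is a minimal, illustrative driver for the classes added in this diff, showing the intended flow of from_pb -> generate_token, where the first call runs prefill (past key values are None) and later calls run single-token decode. It assumes the text_generation_server package and its generate_pb2 bindings are importable; MODEL_PATH and TOKENIZER_PATH are hypothetical placeholders, and the protobuf construction is schematic rather than exactly what the router sends.

# illustrative usage sketch -- assumptions as noted above
from text_generation_server.models.deepsparse_causal_lm import (
    DeepSparseCausalLM,
    DeepSparseCausalLMBatch,
)
from text_generation_server.pb import generate_pb2

MODEL_PATH = "/path/to/decoder_model.onnx"   # hypothetical placeholder
TOKENIZER_PATH = "/path/to/tokenizer"        # hypothetical placeholder

model = DeepSparseCausalLM(model_path=MODEL_PATH, tokenizer_path=TOKENIZER_PATH)

# wrap a single prompt in a protobuf Batch, roughly as the router would
pb_batch = generate_pb2.Batch(
    id=0,
    requests=[generate_pb2.Request(id=0, inputs="def fib(n):")],
    size=1,
)
batch = DeepSparseCausalLMBatch.from_pb(pb_batch, model.tokenizer)

# first generate_token call triggers prefill; subsequent calls decode one
# token per request until every request hits its stopping condition
output = ""
while batch is not None:
    generations, batch = model.generate_token(batch)
    text = generations.get(0)
    if text is not None:
        output += text

print(output)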