From ac3bc0e95eac7b5c75910d22f0de5a6fac507132 Mon Sep 17 00:00:00 2001
From: jkaniecki <153085639+jkaniecki@users.noreply.github.com>
Date: Fri, 19 Jan 2024 15:34:13 +0100
Subject: [PATCH] Removed kv_cache from HPU graph output (#19)

---
 .../models/causal_lm.py                       | 68 +++++++++++++++----
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 6755184f..038f340e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -34,6 +34,8 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
 from loguru import logger
+from functools import wraps
+
 
 tracer = trace.get_tracer(__name__)
 
@@ -114,6 +116,27 @@ def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
 
 
+def remove_kv_cache_from_output(module):
+    orig_fwd = module.forward
+
+    @wraps(orig_fwd)
+    def forward(*args, **kwargs):
+        if kwargs["past_key_values"] is not None:
+            kwargs["return_dict"] = False
+            output = orig_fwd(*args, **kwargs)
+            first_value, second_value, *_ = output
+            if first_value.nelement() < 2:
+                return second_value
+            else:
+                return first_value
+        else:
+            kwargs["return_dict"] = True
+            return orig_fwd(*args, **kwargs)
+
+    module.forward = forward
+    return module
+
+
 def pad_tensors(tensors, paddings, dim, value):
     for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
         if padding > 0:
@@ -510,7 +533,8 @@ class CausalLM(Model):
             # Initialize the model
             ds_inference_kwargs = {"dtype": dtype}
             ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
-            ds_inference_kwargs["enable_cuda_graph"] = self.enable_hpu_graph
+            ds_inference_kwargs["enable_cuda_graph"] = False
+
 
             if load_to_meta:
                 # model loaded to meta is managed differently
@@ -519,6 +543,10 @@ class CausalLM(Model):
                 ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = remove_kv_cache_from_output(model)
+            if self.enable_hpu_graph:
+                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
         else:
             get_repo_root(model_id)
             model = AutoModelForCausalLM.from_pretrained(
@@ -528,6 +556,7 @@ class CausalLM(Model):
             )
             model = model.eval().to(device)
             #wrap in hpu_graph only if self.enable_hpu_graph is set
+            model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
 
@@ -624,8 +653,11 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
 
         kwargs.update(self.kwargs)
-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if past_key_values is not None:
+            return self.model.forward(**kwargs)
+        else:
+            outputs = self.model.forward(**kwargs)
+            return outputs.logits, outputs.past_key_values
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
@@ -659,14 +691,24 @@ class CausalLM(Model):
         else:
             input_ids = batch.input_ids
 
-        logits, past = self.forward(
-            input_ids,
-            attention_mask,
-            batch.position_ids,
-            token_idx,
-            batch.past_key_values,
-            bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
-        )
+        if prefill:
+            logits, past = self.forward(
+                input_ids,
+                attention_mask,
+                batch.position_ids,
+                token_idx,
+                batch.past_key_values,
+                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+            )
+        else:
+            logits = self.forward(
+                input_ids,
+                attention_mask,
+                batch.position_ids,
+                token_idx,
+                batch.past_key_values,
+                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+            )
 
         # Results
         generations: List[Generation] = []
@@ -829,7 +871,9 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1
         # Update past key values
-        batch.past_key_values = past
+        if prefill:
+            batch.past_key_values = past
+
         if self.hb_profer_started == True:
             self.hb_profer.step()
         htorch.core.mark_step()
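
Note (reviewer sketch, not part of the patch): `remove_kv_cache_from_output` monkey-patches `module.forward` so that decode steps (calls made with `past_key_values` set) return a bare logits tensor instead of a full model output, which keeps the kv cache out of the HPU graph's outputs; prefill calls still return the full output so the cache can be captured once. The snippet below exercises the wrapper against a dummy module to show the two return shapes. `DummyLM` and `DummyOutput` are illustrative stand-ins, not types from this repo or from transformers.

from functools import wraps

import torch


def remove_kv_cache_from_output(module):
    # Wrapper from the patch above, repeated here so the sketch runs standalone.
    orig_fwd = module.forward

    @wraps(orig_fwd)
    def forward(*args, **kwargs):
        if kwargs["past_key_values"] is not None:
            # Decode step: request a plain tuple and return only the logits,
            # so the kv cache never becomes a graph output.
            kwargs["return_dict"] = False
            output = orig_fwd(*args, **kwargs)
            first_value, second_value, *_ = output
            if first_value.nelement() < 2:
                # First element is a scalar (e.g. a loss); logits come second.
                return second_value
            else:
                return first_value
        else:
            # Prefill step: keep the full output, kv cache included.
            kwargs["return_dict"] = True
            return orig_fwd(*args, **kwargs)

    module.forward = forward
    return module


class DummyOutput(dict):
    """Stand-in for a transformers ModelOutput (attribute-style access)."""

    def __getattr__(self, name):
        return self[name]


class DummyLM:
    """Hypothetical model exposing only the interface the wrapper relies on."""

    def forward(self, input_ids, past_key_values=None, return_dict=True, **kwargs):
        # Fake logits over a vocab of 32 and a fake kv cache.
        logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 32)
        cache = (torch.zeros(2, 2),)
        if return_dict:
            return DummyOutput(logits=logits, past_key_values=cache)
        return (logits, cache)


model = remove_kv_cache_from_output(DummyLM())
ids = torch.ones(1, 4, dtype=torch.long)

prefill_out = model.forward(input_ids=ids, past_key_values=None)
decode_out = model.forward(input_ids=ids, past_key_values=(torch.zeros(2, 2),))

print(type(prefill_out).__name__)   # DummyOutput -> logits and kv cache kept
print(torch.is_tensor(decode_out))  # True -> bare logits, cache dropped

Matching this asymmetry, `CausalLM.forward` and `generate_token` now unpack `(logits, past)` only during prefill; on decode steps the logits come back directly and `batch.past_key_values` is left untouched.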