Removed kv_cache from HPU graph output (#19)

jkaniecki 2024-01-19 15:34:13 +01:00 committed by GitHub
parent da0f874d49
commit ac3bc0e95e


@@ -34,6 +34,8 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
 from loguru import logger
+from functools import wraps

 tracer = trace.get_tracer(__name__)
@@ -114,6 +116,27 @@ def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]

+def remove_kv_cache_from_output(module):
+    orig_fwd = module.forward
+
+    @wraps(orig_fwd)
+    def forward(*args, **kwargs):
+        if kwargs["past_key_values"] is not None:
+            kwargs["return_dict"] = False
+            output = orig_fwd(*args, **kwargs)
+            first_value, second_value, *_ = output
+            if first_value.nelement() < 2:
+                return second_value
+            else:
+                return first_value
+        else:
+            kwargs["return_dict"] = True
+            return orig_fwd(*args, **kwargs)
+
+    module.forward = forward
+    return module
+
+
 def pad_tensors(tensors, paddings, dim, value):
     for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
         if padding > 0:
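For reference, here is a minimal, self-contained sketch of the pattern introduced by `remove_kv_cache_from_output` above. `TinyModel` and its toy tensor shapes are hypothetical stand-ins for a transformers causal LM (not part of this repository); the point is only that prefill (`past_key_values=None`) still returns the full output, while decode returns just the logits tensor, so the KV cache never appears in the wrapped forward's output.

```python
import torch
from functools import wraps


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a transformers causal LM with an HF-style forward."""

    def forward(self, input_ids=None, past_key_values=None, return_dict=True, **kwargs):
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 8)
        new_past = (torch.zeros(2, 2),)  # placeholder KV cache
        if return_dict:
            return {"logits": logits, "past_key_values": new_past}
        return (logits, new_past)  # tuple output, logits first, as with return_dict=False


def remove_kv_cache_from_output(module):
    # Same pattern as the patch: on decode (past_key_values provided), request
    # tuple output and hand back only the logits tensor.
    orig_fwd = module.forward

    @wraps(orig_fwd)
    def forward(*args, **kwargs):
        if kwargs["past_key_values"] is not None:
            kwargs["return_dict"] = False
            output = orig_fwd(*args, **kwargs)
            first_value, second_value, *_ = output
            return second_value if first_value.nelement() < 2 else first_value
        kwargs["return_dict"] = True
        return orig_fwd(*args, **kwargs)

    module.forward = forward
    return module


model = remove_kv_cache_from_output(TinyModel())
ids = torch.ones(1, 4, dtype=torch.long)

prefill_out = model.forward(input_ids=ids, past_key_values=None)
decode_out = model.forward(input_ids=ids[:, -1:], past_key_values=prefill_out["past_key_values"])

print(sorted(prefill_out.keys()))  # ['logits', 'past_key_values'] -- full output at prefill
print(decode_out.shape)            # torch.Size([1, 1, 8])         -- logits only at decode
```

Keeping the cache out of the wrapped output is what the commit title refers to: when the wrapped module is later captured by `wrap_in_hpu_graph`, the KV cache is no longer part of the graph's output.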
@@ -510,7 +533,8 @@ class CausalLM(Model):
             # Initialize the model
             ds_inference_kwargs = {"dtype": dtype}
             ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
-            ds_inference_kwargs["enable_cuda_graph"] = self.enable_hpu_graph
+            ds_inference_kwargs["enable_cuda_graph"] = False

             if load_to_meta:
                 # model loaded to meta is managed differently
@@ -519,6 +543,10 @@ class CausalLM(Model):
                 ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = remove_kv_cache_from_output(model)
+            if self.enable_hpu_graph:
+                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
         else:
             get_repo_root(model_id)
             model = AutoModelForCausalLM.from_pretrained(
@@ -528,6 +556,7 @@ class CausalLM(Model):
             )
             model = model.eval().to(device)
             #wrap in hpu_graph only if self.enable_hpu_graph is set
+            model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -624,6 +653,9 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

         kwargs.update(self.kwargs)
-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if past_key_values is not None:
+            return self.model.forward(**kwargs)
+        else:
+            outputs = self.model.forward(**kwargs)
+            return outputs.logits, outputs.past_key_values
@@ -659,6 +691,7 @@ class CausalLM(Model):
         else:
             input_ids = batch.input_ids

+        if prefill:
             logits, past = self.forward(
                 input_ids,
                 attention_mask,
@@ -667,6 +700,15 @@ class CausalLM(Model):
                 batch.past_key_values,
                 bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
+        else:
+            logits = self.forward(
+                input_ids,
+                attention_mask,
+                batch.position_ids,
+                token_idx,
+                batch.past_key_values,
+                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+            )

         # Results
         generations: List[Generation] = []
@@ -829,7 +871,9 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1

         # Update past key values
-        batch.past_key_values = past
+        if prefill:
+            batch.past_key_values = past

         if self.hb_profer_started == True:
             self.hb_profer.step()
         htorch.core.mark_step()
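Taken together, the changes above give `CausalLM.forward` two return shapes: the prefill call still yields `(logits, past_key_values)`, while decode calls return only logits and reuse the cache captured at prefill, which is why `batch.past_key_values` is updated only when `prefill` is true. The sketch below illustrates that calling convention with a hypothetical `WrappedStub` and `model_forward` helper (names assumed for illustration; this is not code from this repository).

```python
import torch


# Hypothetical stand-in for a model already wrapped by remove_kv_cache_from_output():
# prefill returns the full output, decode returns only logits.
class WrappedStub:
    def forward(self, input_ids, past_key_values=None, **kwargs):
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 8)
        if past_key_values is None:  # prefill
            return {"logits": logits, "past_key_values": ("dummy-kv",)}
        return logits  # decode: logits only, no KV cache in the output


def model_forward(model, input_ids, past_key_values):
    # Mirrors the two-branch dispatch added to CausalLM.forward above.
    if past_key_values is not None:
        return model.forward(input_ids=input_ids, past_key_values=past_key_values)
    outputs = model.forward(input_ids=input_ids, past_key_values=None)
    return outputs["logits"], outputs["past_key_values"]


model = WrappedStub()
ids = torch.ones(1, 4, dtype=torch.long)

# Prefill: keep both the logits and the KV cache handle.
logits, past = model_forward(model, ids, None)

# Decode: only logits come back; the cache captured at prefill is reused,
# matching the "if prefill: batch.past_key_values = past" update above.
for _ in range(3):
    logits = model_forward(model, ids[:, -1:], past)

print(logits.shape)  # torch.Size([1, 1, 8])
```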