Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-28 21:42:06 +00:00)
Removed kv_cache from HPU graph output (#19)
parent da0f874d49
commit ac3bc0e95e
@@ -34,6 +34,8 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
 from loguru import logger
+from functools import wraps
+
 
 tracer = trace.get_tracer(__name__)
 
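A side note on the new import: functools.wraps copies the wrapped callable's metadata (name, docstring, module) onto its replacement, so the forward defined in the next hunk still looks like the original model's forward under introspection. A minimal illustration with made-up names (traced is not from this repository):

from functools import wraps

def traced(fn):
    @wraps(fn)                      # keeps fn.__name__ / __doc__ on the wrapper
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

@traced
def forward(x):
    "compute logits"
    return x

print(forward.__name__, forward.__doc__)   # forward compute logits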
@@ -114,6 +116,27 @@ def shift_all(srcs, dim, offsets):
     return [shift(src, dim, offset) for src, offset in zip(srcs, offsets)]
 
 
+
+def remove_kv_cache_from_output(module):
+    orig_fwd = module.forward
+
+    @wraps(orig_fwd)
+    def forward(*args, **kwargs):
+        if kwargs["past_key_values"] is not None:
+            kwargs["return_dict"] = False
+            output = orig_fwd(*args, **kwargs)
+            first_value, second_value, *_ = output
+            if first_value.nelement() < 2:
+                return second_value
+            else:
+                return first_value
+        else:
+            kwargs["return_dict"] = True
+            return orig_fwd(*args, **kwargs)
+
+    module.forward = forward
+    return module
+
+
 def pad_tensors(tensors, paddings, dim, value):
     for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
         if padding > 0:
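The wrapper changes only what forward() returns: with a cache present it forces a tuple return and hands back just the logits tensor (falling back to the second tuple element when the first is a scalar), and with no cache it leaves the full return_dict output intact. A minimal sketch of that behaviour with a toy stand-in model; ToyCausalLM, the shapes, and the dict-style output are illustrative assumptions, and the sketch assumes the remove_kv_cache_from_output definition above is in scope:

import torch

class ToyCausalLM(torch.nn.Module):
    """Toy stand-in for a transformers causal LM: returns logits plus a fake kv cache."""
    def forward(self, input_ids=None, past_key_values=None, return_dict=True, **kwargs):
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 8)
        cache = (torch.zeros(1, 2, 4),)  # pretend kv cache
        if return_dict:
            return {"logits": logits, "past_key_values": cache}
        return (logits, cache)

model = remove_kv_cache_from_output(ToyCausalLM())

# Decode step (a cache already exists): only the logits tensor comes back.
decode_out = model.forward(input_ids=torch.ones(1, 1, dtype=torch.long),
                           past_key_values=(torch.zeros(1, 2, 4),))
print(type(decode_out))            # <class 'torch.Tensor'>

# Prefill (past_key_values=None): the full dict-style output, cache included.
prefill_out = model.forward(input_ids=torch.ones(1, 4, dtype=torch.long),
                            past_key_values=None)
print(sorted(prefill_out.keys()))  # ['logits', 'past_key_values']

Because the decode-path return value is a single tensor, an HPU graph recorded around forward() no longer lists the kv-cache tensors among its outputs, which appears to be the point of the change.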
@@ -510,7 +533,8 @@ class CausalLM(Model):
             # Initialize the model
             ds_inference_kwargs = {"dtype": dtype}
             ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
-            ds_inference_kwargs["enable_cuda_graph"] = self.enable_hpu_graph
+            ds_inference_kwargs["enable_cuda_graph"] = False
+
 
             if load_to_meta:
                 # model loaded to meta is managed differently
@@ -519,6 +543,10 @@ class CausalLM(Model):
                 ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = remove_kv_cache_from_output(model)
+            if self.enable_hpu_graph:
+                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
         else:
             get_repo_root(model_id)
             model = AutoModelForCausalLM.from_pretrained(
@@ -528,6 +556,7 @@ class CausalLM(Model):
             )
             model = model.eval().to(device)
             #wrap in hpu_graph only if self.enable_hpu_graph is set
+            model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
 
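Note the ordering in both load paths: remove_kv_cache_from_output is applied to the bare model before wrap_in_hpu_graph, so the recorded HPU graph wraps a forward() that already returns logits only and the per-layer cache tensors are no longer part of the graph's outputs, which appears to be what the commit title refers to. In the DeepSpeed path, enable_cuda_graph is also forced to False, leaving graph capture entirely to wrap_in_hpu_graph.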
@@ -624,6 +653,9 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
 
         kwargs.update(self.kwargs)
+        if past_key_values is not None:
+            return self.model.forward(**kwargs)
+        else:
             outputs = self.model.forward(**kwargs)
             return outputs.logits, outputs.past_key_values
 
@@ -659,6 +691,7 @@ class CausalLM(Model):
         else:
             input_ids = batch.input_ids
 
+        if prefill:
             logits, past = self.forward(
                 input_ids,
                 attention_mask,
@@ -667,6 +700,15 @@ class CausalLM(Model):
                 batch.past_key_values,
                 bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
+        else:
+            logits = self.forward(
+                input_ids,
+                attention_mask,
+                batch.position_ids,
+                token_idx,
+                batch.past_key_values,
+                bypass_hpu_graph = prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+            )
 
         # Results
         generations: List[Generation] = []
@@ -829,7 +871,9 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1
         # Update past key values
+        if prefill:
             batch.past_key_values = past
+
         if self.hb_profer_started == True:
             self.hb_profer.step()
         htorch.core.mark_step()
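Putting the pieces together, the hedged sketch below loosely mimics the prefill/decode flow that generate_token now follows: the cache is captured once at prefill and kept by the caller, and every subsequent decode call gets back only a logits tensor. It reuses ToyCausalLM and remove_kv_cache_from_output from the sketches above; none of it is code from the repository.

import torch

model = remove_kv_cache_from_output(ToyCausalLM())
past = None

for step in range(3):
    prefill = past is None
    input_ids = torch.ones(1, 4 if prefill else 1, dtype=torch.long)
    if prefill:
        # prefill: keep the full output so the caller can store the cache once
        outputs = model.forward(input_ids=input_ids, past_key_values=None)
        logits, past = outputs["logits"], outputs["past_key_values"]
    else:
        # decode: only logits cross the forward() boundary; the cache stays with the caller
        logits = model.forward(input_ids=input_ids, past_key_values=past)
    print(step, "prefill" if prefill else "decode", tuple(logits.shape))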