from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections

_TYPE_CACHE = {}


@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention."""

    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]


def subtuple(
    obj: object,
    typename: str,
    to_copy: List[str],
    to_override: Optional[Dict[str, object]] = None,
):
    """Build a namedtuple of type `typename` holding only the requested fields.

    Fields are read from `obj` (a dict or an attribute-bearing object), with
    `to_override` taking precedence. The namedtuple type is cached per
    `typename` in `_TYPE_CACHE` so repeated calls reuse the same class.
    """
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if isinstance(obj, dict):
        values = {key: obj[key] for key in fields if key in obj}
    else:
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


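# Illustrative usage sketch (not part of the original module): subtuple accepts
# either a dict or a dataclass-like object, and only the named fields survive.
#
#   row = {"block_list": torch.zeros(4), "block_usage": torch.ones(4), "extra": 1}
#   trimmed = subtuple(row, "Example", ["block_list", "block_usage"])
#   trimmed.block_usage  # -> tensor([1., 1., 1., 1.]); "extra" was dropped

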
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedAttentionMetadata",
        [
            "block_list",
            "block_mapping",
            "block_usage",
            "block_groups",
            "attn_bias",
        ],
    )
    return attention_metadata


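# Illustrative usage sketch (not part of the original module): trim_attn_metadata
# keeps only the listed tensor fields, so every input handed to an HPUGraph is
# hashed by tensor metadata rather than by value (see the note above).
#
#   meta = HPUPagedAttentionMetadata(
#       block_list=torch.zeros(8, dtype=torch.int32),
#       block_mapping=torch.zeros(8, dtype=torch.int32),
#       block_usage=torch.ones(8, dtype=torch.int32),
#       block_groups=torch.zeros(8, dtype=torch.int32),
#       attn_bias=None,
#   )
#   trimmed = trim_attn_metadata(meta)  # namedtuple "TrimmedAttentionMetadata"

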
@dataclass
class Seqlen:
    input_lengths: torch.Tensor

    def __init__(
        self,
        input_lengths,
    ):
        self.input_lengths = input_lengths

    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self


def _async_h2d_tensor_copy(source, device="hpu"):
    """Asynchronously copy a host (CPU) tensor to `device` and return the new tensor."""
    if source is None:
        return None
    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
    target = torch.empty(source.shape, dtype=source.dtype, device=device)
    target.copy_(source, non_blocking=True)
    return target


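# Illustrative usage sketch (not part of the original module), assuming input
# lengths are assembled on the CPU first and then moved to the device:
#
#   cpu_lengths = torch.tensor([7, 3], dtype=torch.int32)
#   hpu_lengths = _async_h2d_tensor_copy(cpu_lengths)  # copied with non_blocking=True

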
def trim_seqlen_metadata(metadata: Seqlen) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedSeqlen",
        [
            "input_lengths",
        ],
    )
    return attention_metadata