Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-21 16:40:20 +00:00)
Fix crash issue

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent 8c182415c2
commit b32b78e74e
@@ -46,9 +46,12 @@ class Qwen3Attention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
-        self.num_key_value_groups = (
-            config.num_attention_heads // config.num_key_value_heads
-        )
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
+        # self.num_key_value_groups = (
+        #     config.num_attention_heads // config.num_key_value_heads
+        # )
         self.num_heads = config.num_attention_heads
         self.attention_dropout = config.attention_dropout
         self.softmax_scale = self.head_dim**-0.5
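This hunk replaces the eager `num_key_value_groups` computation with a `getattr` fallback on the config: a config that omits `num_key_value_heads` now defaults to one KV head per query head (plain multi-head attention) instead of failing at load time, which is presumably the crash the commit title refers to. A minimal sketch of the fallback behavior, using a made-up config class rather than a real TGI one:

# Sketch only: `Config` is a stand-in, not a TGI class.
class Config:
    hidden_size = 4096
    num_attention_heads = 32
    # num_key_value_heads and head_dim are deliberately absent

config = Config()

head_dim = getattr(
    config, "head_dim", config.hidden_size // config.num_attention_heads
)
# Fall back to MHA (one KV head per query head) when the field is missing.
config.num_key_value_heads = getattr(
    config, "num_key_value_heads", config.num_attention_heads
)

print(head_dim, config.num_key_value_heads)  # 128 32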
@@ -66,9 +69,13 @@ class Qwen3Attention(nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         # self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_heads = (
-            config.num_key_value_heads // weights.process_group.size()
-        )
+        if config.num_key_value_heads > weights.process_group.size():
+            self.num_key_value_heads = (
+                config.num_key_value_heads // weights.process_group.size()
+            )
+        else:
+            self.num_key_value_heads = config.num_key_value_heads
+
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
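This hunk guards the tensor-parallel split of KV heads: heads are divided across ranks only when there are more KV heads than ranks; otherwise each rank keeps the full set rather than integer-dividing down to zero. A standalone sketch of that rule (the helper name is illustrative, not TGI API):

# Illustrative helper, not TGI API.
def shard_kv_heads(num_key_value_heads: int, world_size: int) -> int:
    if num_key_value_heads > world_size:
        return num_key_value_heads // world_size
    # Too few KV heads to split (common with GQA): keep the full set
    # on every rank instead of dividing down to zero.
    return num_key_value_heads

assert shard_kv_heads(32, 8) == 4  # enough KV heads: shard them
assert shard_kv_heads(4, 8) == 4   # fewer KV heads than ranks: keep all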
@@ -86,10 +93,10 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )

-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+        ).repeat_interleave(self.num_key_value_groups)

         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
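This hunk renames `self.num_groups` to `self.num_key_value_groups`, keeping the GQA head mapping consistent with the constructor changes above. The mapping assigns each query head to the KV head it shares; a small sketch with assumed head counts:

import torch

num_heads = 8               # assumed per-rank query heads
num_key_value_heads = 2     # assumed per-rank KV heads
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

# Query head i reads from KV head i // num_key_value_groups.
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32)
kv_head_mapping = kv_head_mapping.repeat_interleave(num_key_value_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)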
@@ -127,6 +134,10 @@ class Qwen3Attention(nn.Module):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states)
+        print(f"qkv shape: {qkv.shape}")
+        print(f"self.head_dim: {self.head_dim}")
+        print(f"self.num_heads: {self.num_heads}")
+        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
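This hunk only adds debug `print` statements around the fused QKV projection; the split sizes they help verify follow directly from the head counts. A sketch of the split with assumed dimensions:

import torch

# Assumed sizes for illustration; TGI derives them from the config.
num_heads, num_key_value_heads, head_dim = 8, 2, 64
qkv = torch.randn(4, (num_heads + 2 * num_key_value_heads) * head_dim)

query_states, key_states, value_states = qkv.split(
    [
        head_dim * num_heads,            # 512 query features
        head_dim * num_key_value_heads,  # 128 key features
        head_dim * num_key_value_heads,  # 128 value features
    ],
    dim=-1,
)
print(query_states.shape, key_states.shape, value_states.shape)
# torch.Size([4, 512]) torch.Size([4, 128]) torch.Size([4, 128])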
@@ -153,7 +153,7 @@ def prepare_for_decode(
        block_groups_device, num_classes=batch_size
    )
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return trim_attn_metadata(
        HPUPagedAttentionMetadata(
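The last hunk changes the mask comparison in `prepare_for_decode` to use `block_usage_device` instead of `block_usage`, presumably the device-resident copy of the per-block usage counts. The surrounding logic builds a -inf attention bias for unused slots in each KV-cache block; a self-contained sketch with toy values:

import math
import torch

BLOCK_SIZE = 8  # toy value; the real constant comes from the HPU backend
dtype = torch.float32
block_usage_device = torch.tensor([3, 8], dtype=torch.int32)  # slots used per block

mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage_device.unsqueeze(-1)  # True where a slot is unused
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
print(attn_bias[0])  # tensor([0., 0., 0., -inf, -inf, -inf, -inf, -inf])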