Fix crash issue: handle configs without num_key_value_heads, skip KV-head sharding when ranks outnumber KV heads, and use block_usage_device in prepare_for_decode

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-19 01:39:48 +00:00
parent 8c182415c2
commit b32b78e74e
2 changed files with 19 additions and 8 deletions


@@ -46,9 +46,12 @@ class Qwen3Attention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
-        self.num_key_value_groups = (
-            config.num_attention_heads // config.num_key_value_heads
-        )
+        config.num_key_value_heads = getattr(
+            config, "num_key_value_heads", config.num_attention_heads
+        )
+        # self.num_key_value_groups = (
+        #     config.num_attention_heads // config.num_key_value_heads
+        # )
         self.num_heads = config.num_attention_heads
         self.attention_dropout = config.attention_dropout
         self.softmax_scale = self.head_dim**-0.5
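
Note on the hunk above: the replacement defaults config.num_key_value_heads to num_attention_heads when the checkpoint config omits it, i.e. a plain multi-head-attention model is treated as GQA with one KV head per query head. A minimal sketch of that fallback, using a made-up SimpleNamespace config instead of the real transformers config object:

# Sketch only: SimpleNamespace and the numbers below are illustrative assumptions,
# not values from the repository.
from types import SimpleNamespace

config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)  # no num_key_value_heads

# Same pattern as the hunk: fall back to plain MHA when the field is absent.
config.num_key_value_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

print(config.num_key_value_heads, head_dim)  # 32 128
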
@@ -66,9 +69,13 @@ class Qwen3Attention(nn.Module):
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         # self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_heads = (
-            config.num_key_value_heads // weights.process_group.size()
-        )
+        if config.num_key_value_heads > weights.process_group.size():
+            self.num_key_value_heads = (
+                config.num_key_value_heads // weights.process_group.size()
+            )
+        else:
+            self.num_key_value_heads = config.num_key_value_heads
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
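
The guard above divides the KV-head count by the tensor-parallel world size only when there are more KV heads than ranks; otherwise each rank keeps the full count, which avoids ending up with zero KV heads per rank (the likely source of the crash). A standalone sketch of that rule, with made-up head counts and world sizes:

# Sketch of the KV-head sharding rule introduced above; world_size and the head
# counts are illustrative assumptions, not values from the repository.
def shard_kv_heads(num_key_value_heads: int, world_size: int) -> int:
    # Shard KV heads across tensor-parallel ranks only when there are enough of
    # them; otherwise keep the full count on every rank.
    if num_key_value_heads > world_size:
        return num_key_value_heads // world_size
    return num_key_value_heads

print(shard_kv_heads(8, 2))  # 4: KV heads are split across ranks
print(shard_kv_heads(4, 8))  # 4: fewer KV heads than ranks, keep them all
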
@@ -86,10 +93,10 @@ class Qwen3Attention(nn.Module):
             bias=False,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+        ).repeat_interleave(self.num_key_value_groups)
         self.max_past = (
             config.sliding_window if config.sliding_window is not None else -1
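
This hunk only renames num_groups to num_key_value_groups; the mapping logic is unchanged: repeat_interleave assigns each block of num_key_value_groups consecutive query heads to one KV head. A small illustration with assumed counts (8 query heads, 2 KV heads), not the real model's dimensions:

# Sketch of how kv_head_mapping pairs query heads with KV heads under GQA.
import torch

num_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(
    num_key_value_groups
)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
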
@@ -127,6 +134,10 @@ class Qwen3Attention(nn.Module):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.query_key_value(hidden_states)
+        print(f"qkv shape: {qkv.shape}")
+        print(f"self.head_dim: {self.head_dim}")
+        print(f"self.num_heads: {self.num_heads}")
+        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
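
The added print calls expose the per-rank sizes that the following qkv.split has to match. A sketch of that split with illustrative sizes (head_dim=64, 8 query heads, 2 KV heads on this rank), not the actual model dimensions:

# Sketch of the fused-QKV split; shapes are made-up examples.
import torch

head_dim, num_heads, num_key_value_heads = 64, 8, 2
qkv = torch.randn(3, num_heads * head_dim + 2 * num_key_value_heads * head_dim)

query_states, key_states, value_states = qkv.split(
    [
        head_dim * num_heads,            # 512 columns for Q
        head_dim * num_key_value_heads,  # 128 columns for K
        head_dim * num_key_value_heads,  # 128 columns for V
    ],
    dim=-1,
)
print(query_states.shape, key_states.shape, value_states.shape)
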


@@ -153,7 +153,7 @@ def prepare_for_decode(
         block_groups_device, num_classes=batch_size
     )
     mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage.unsqueeze(-1)
+    mask = mask >= block_usage_device.unsqueeze(-1)
     attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
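
In the second file, the comparison now uses block_usage_device, the tensor that lives on the same device as mask, so the padded-slot test is device-consistent (the old block_usage name is not what the surrounding code defines at this point). A runnable CPU sketch of the mask/bias construction, with made-up block usage values standing in for the real prepare_for_decode inputs:

# Sketch of the padded-block attention bias; BLOCK_SIZE, device, and the
# block_usage_device values are illustrative assumptions.
import math
import torch

BLOCK_SIZE = 8
device, dtype = "cpu", torch.float32
block_usage_device = torch.tensor([8, 3, 5], device=device)  # valid slots per block

mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage_device.unsqueeze(-1)  # True for padded slots
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
print(attn_bias)
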