Fixing baichuan override.

This commit is contained in:
Nicolas Patry 2024-07-01 21:05:19 +00:00
parent d0225b1015
commit 83f61d6d7d

View File

@ -117,6 +117,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads self.head_size = self.hidden_size // self.num_heads
# Setting defaults for baichuan custom config which doesn't apply them.
config.rope_theta = getattr(config, "rope_theta", 10000)
config.num_key_value_heads = getattr(
config, "num_key_value_heads", config.num_attention_heads
)
self.rotary_emb = PositionRotaryEmbedding.static( self.rotary_emb = PositionRotaryEmbedding.static(
config=config, config=config,
dim=self.head_size, dim=self.head_size,