fix: refactor and reduce lora math

drbh 2024-06-04 05:01:52 +00:00
parent 0a6ea7fb57
commit a046c303f7

@@ -209,41 +209,20 @@ class FlashLlamaAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
         batch_size = query.size(0)
 
-        if not torch.all(lora_indices, -1):
-            lora_mask = lora_indices[lora_indices != -1]
-
-            q_pre_multiplied_batch = torch.ones(
-                (batch_size, 4096, 4096),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
-                lora_mask, 0
-            ]
-
-            v_pre_multiplied_batch = torch.ones(
-                (batch_size, 4096, 4096),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
-                lora_mask, 1
-            ]
-
-            query_adapted = (
-                torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch)
-                .squeeze(1)
-                .view(batch_size, self.num_heads, self.head_size)
-            )
-            value_adapted = (
-                torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch)
-                .squeeze(1)
-                .view(batch_size, self.num_key_value_heads, self.head_size)
-            )
-
-            query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
-            kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
+        query_adapted = (
+            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0])
+            .squeeze(0)
+            .view(batch_size, self.num_heads, self.head_size)
+        )
+        value_adapted = (
+            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1])
+            .squeeze(0)
+            .view(batch_size, self.num_key_value_heads, self.head_size)
+        )
+
+        query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
+        kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
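
The refactor leans on the LoRA factors already being collapsed into a single pre-multiplied matrix per projection, so the adapter delta reduces to one batched matmul followed by a masked add onto query/kv. Below is a minimal sketch of that identity; the shapes and names (lora_a, lora_b, pre_multiplied) are illustrative assumptions, not the model's actual layout.

import torch

# Illustrative shapes only; the real module uses its own hidden size,
# head counts, and per-adapter matrix layout.
batch_size, hidden_size, rank = 4, 64, 8

hidden_states = torch.randn(batch_size, hidden_size)

# Standard LoRA keeps two low-rank factors: A (hidden -> rank) and
# B (rank -> hidden). The adapter delta for a projection is x @ A @ B.
lora_a = torch.randn(hidden_size, rank)
lora_b = torch.randn(rank, hidden_size)

# "Pre-multiplied" form: collapse A @ B once into a (hidden, hidden)
# matrix so the forward pass does a single matmul instead of two.
pre_multiplied = lora_a @ lora_b

delta_two_step = hidden_states @ lora_a @ lora_b
delta_one_step = hidden_states @ pre_multiplied
assert torch.allclose(delta_two_step, delta_one_step, atol=1e-4)

# Only rows whose request carries an adapter receive the delta, mirroring
# the batch_lora_adapter_mask indexing in the diff above.
batch_lora_adapter_mask = torch.tensor([True, False, True, False])
query = torch.zeros(batch_size, hidden_size)
query[batch_lora_adapter_mask] += delta_one_step[batch_lora_adapter_mask]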