Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00
fix: refactor and reduce lora math
parent 0a6ea7fb57
commit a046c303f7
```diff
@@ -209,39 +209,18 @@ class FlashLlamaAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
         batch_size = query.size(0)
-        if not torch.all(lora_indices, -1):
-            lora_mask = lora_indices[lora_indices != -1]
-
-            q_pre_multiplied_batch = torch.ones(
-                (batch_size, 4096, 4096),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-            q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
-                lora_mask, 0
-            ]
-
-            v_pre_multiplied_batch = torch.ones(
-                (batch_size, 4096, 4096),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-            v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
-                lora_mask, 1
-            ]
-
-            query_adapted = (
-                torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch)
-                .squeeze(1)
-                .view(batch_size, self.num_heads, self.head_size)
-            )
-
-            value_adapted = (
-                torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch)
-                .squeeze(1)
-                .view(batch_size, self.num_key_value_heads, self.head_size)
-            )
-
-            query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
-            kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
+        query_adapted = (
+            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0])
+            .squeeze(0)
+            .view(batch_size, self.num_heads, self.head_size)
+        )
+
+        value_adapted = (
+            torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1])
+            .squeeze(0)
+            .view(batch_size, self.num_key_value_heads, self.head_size)
+        )
+
+        query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
+        kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
```
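For context, a minimal sketch of the pre-multiplied LoRA idea the refactor leans on: the low-rank factors are collapsed into a dense delta once, so applying an adapter at runtime is a single matmul plus a masked add. The names hidden_states, query, pre_multiplied_lora_matrix, and batch_lora_adapter_mask mirror the diff; lora_a, lora_b, rank, and all shapes are illustrative assumptions, and the diff itself stacks the query and value deltas and indexes them with [:, 0] / [:, 1] rather than using a single matrix as here.

```python
# Hedged sketch only: illustrates pre-multiplied LoRA with made-up shapes and
# factor names; it is not the repository's implementation.
import torch

hidden_size, rank, batch_size = 4096, 8, 4

# One adapter's low-rank factors. Collapsing ("pre-multiplying") them once
# gives a dense (hidden_size, hidden_size) delta, so the runtime path needs a
# single matmul per projection instead of two low-rank matmuls per token.
lora_a = torch.randn(rank, hidden_size)
lora_b = torch.randn(hidden_size, rank)
pre_multiplied_lora_matrix = lora_b @ lora_a  # (hidden_size, hidden_size)

hidden_states = torch.randn(batch_size, hidden_size)
query = torch.randn(batch_size, hidden_size)

# Apply the collapsed delta to the whole batch in one matmul ...
query_adapted = hidden_states @ pre_multiplied_lora_matrix  # (batch_size, hidden_size)

# ... then add it only to the rows whose request uses this adapter, mirroring
# the masked += at the end of the diff.
batch_lora_adapter_mask = torch.tensor([True, False, True, False])
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
```

Collapsing lora_b @ lora_a once trades memory (a full hidden_size x hidden_size matrix per adapter, matching the hard-coded 4096 x 4096 shapes in the removed code) for fewer operations per token, which is presumably the reduction the commit title refers to.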