hotfix for quantization

fxmarty 2024-05-17 17:18:40 +00:00
parent f82ae76dff
commit 7a5f5d9757
2 changed files with 8 additions and 0 deletions


@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],


@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
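
For context, both hunks make the same change: the MLP remembers config.quantize at construction time, and the ROCm fused SiLU fast path (only taken for batch size 1) is skipped whenever the model is quantized, so quantized layers fall through to the regular SwiGLU computation. The following is a minimal, illustrative sketch of that control flow, not the actual text-generation-inference code: SketchMLP, gate_up_proj, and down_proj are assumed placeholder names, and the fused-kernel branch is approximated here with plain PyTorch ops.

import torch
import torch.nn as nn
import torch.nn.functional as F

SYSTEM = "rocm"  # assumed platform flag, mirroring the check in the diff


class SketchMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, hidden_act="silu", quantize=None):
        super().__init__()
        self.hidden_act = hidden_act
        # The hotfix: record whether the model is quantized so the custom
        # ROCm path can be bypassed for quantized weights.
        self.quantize = quantize
        self.intermediate_size = intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
            and hidden_states.shape[0] == 1
            and not self.quantize  # new condition: quantized models skip this branch
        ):
            # Stand-in for the fused ROCm path: the real code fills a
            # preallocated `out` buffer via a custom kernel.
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )
            gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
            torch.mul(F.silu(gate), up, out=out)
            return self.down_proj(out)
        # Default SwiGLU path, now always used when the model is quantized.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)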