mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
hotfix for quantization
This commit is contained in:
parent
f82ae76dff
commit
7a5f5d9757
@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
|
|||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: This is a hotfix to be removed & properly refactored.
|
||||||
|
self.quantize = config.quantize
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
if (
|
if (
|
||||||
SYSTEM == "rocm"
|
SYSTEM == "rocm"
|
||||||
and self.hidden_act == "silu"
|
and self.hidden_act == "silu"
|
||||||
and hidden_states.shape[0] == 1
|
and hidden_states.shape[0] == 1
|
||||||
|
and not self.quantize
|
||||||
):
|
):
|
||||||
out = torch.empty(
|
out = torch.empty(
|
||||||
hidden_states.shape[0],
|
hidden_states.shape[0],
|
||||||
|
@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
|
|||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: This is a hotfix to be removed & properly refactored.
|
||||||
|
self.quantize = config.quantize
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
if (
|
if (
|
||||||
SYSTEM == "rocm"
|
SYSTEM == "rocm"
|
||||||
and self.hidden_act == "silu"
|
and self.hidden_act == "silu"
|
||||||
and hidden_states.shape[0] == 1
|
and hidden_states.shape[0] == 1
|
||||||
|
and not self.quantize
|
||||||
):
|
):
|
||||||
out = torch.empty(
|
out = torch.empty(
|
||||||
hidden_states.shape[0],
|
hidden_states.shape[0],
|
||||||
|
Loading…
Reference in New Issue
Block a user