diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 47758d30..6e23aa2b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 21edc79e..ef3777da 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
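
For context, a minimal sketch of the control flow both hunks produce. This is not TGI's actual module: FusedMLP, the layer sizes, and the inlined fallback math are hypothetical stand-ins, and the real fast path calls a fused ROCm silu-GEMM kernel (replaced here with equivalent unfused ops so the sketch runs anywhere). The guard matters because the fused kernel presumably consumes the dense weight tensor directly, which quantized linear layers do not expose, so any quantized configuration must fall through to the generic path.

import torch
import torch.nn as nn

SYSTEM = "rocm"  # assumed platform flag, mirroring the modeling code


class FusedMLP(nn.Module):
    """Hypothetical stand-in for LlamaMLP / MistralMLP."""

    def __init__(self, hidden_size, intermediate_size, quantize=None):
        super().__init__()
        self.hidden_act = "silu"
        self.quantize = quantize  # same hotfix attribute the diff introduces
        self.intermediate_size = intermediate_size
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, hidden_states):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
            and hidden_states.shape[0] == 1
            and not self.quantize  # the guard both hunks add
        ):
            # In TGI this branch launches the fused ROCm kernel; the unfused
            # equivalent is shown here so the sketch stays self-contained.
            gate_up = self.gate_up_proj(hidden_states)
            gate_up = gate_up.view(-1, 2, self.intermediate_size)
            return self.down_proj(self.act(gate_up[:, 0]) * gate_up[:, 1])
        # Generic path, valid for quantized linears as well.
        gate_up = self.gate_up_proj(hidden_states)
        gate_up = gate_up.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up[:, 0]) * gate_up[:, 1])


# A quantized config (e.g. config.quantize == "gptq") now skips the fused path
# even for single-token batches:
mlp = FusedMLP(hidden_size=16, intermediate_size=32, quantize="gptq")
y = mlp(torch.randn(1, 16))  # takes the generic branch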