From b2cd1b66edbb5dfccbfc0141ec92f4e93c300621 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 27 Sep 2024 15:52:43 +0000 Subject: [PATCH] fix imports after rebase --- .../models/custom_modeling/flash_cohere_modeling.py | 2 +- .../models/custom_modeling/flash_dbrx_modeling.py | 2 +- .../models/custom_modeling/flash_deepseek_v2_modeling.py | 7 +++---- .../models/custom_modeling/flash_gemma2_modeling.py | 2 +- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- .../models/custom_modeling/flash_gptj_modeling.py | 2 +- .../models/custom_modeling/flash_llama_modeling.py | 5 ++--- .../models/custom_modeling/flash_mistral_modeling.py | 3 +-- .../models/custom_modeling/flash_mixtral_modeling.py | 2 +- .../models/custom_modeling/flash_neox_modeling.py | 2 +- .../models/custom_modeling/flash_phi_modeling.py | 2 +- .../models/custom_modeling/flash_qwen2_modeling.py | 2 +- .../models/custom_modeling/flash_rw_modeling.py | 2 +- .../models/custom_modeling/flash_santacoder_modeling.py | 2 +- .../models/custom_modeling/flash_starcoder2_modeling.py | 2 +- 15 files changed, 18 insertions(+), 21 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 44db0290..30656038 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -40,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 852e52d8..1137a453 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -31,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, Seqlen, + PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( FastLinear, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 54a334dd..88c2cf80 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -15,9 +15,6 @@ from typing import List, Optional, Tuple, Type -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE -from text_generation_server.utils.import_utils import SYSTEM - import torch import torch.distributed from torch import nn @@ -38,9 +35,11 @@ from text_generation_server.layers.attention import ( paged_attention, reshape_and_cache, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import Weights if SYSTEM == "rocm": @@ -390,8 +389,8 @@ class DeepseekV2MLP(nn.Module): def forward(self, hidden_states: torch.Tensor, reduce: bool = True): if ( SYSTEM == "rocm" - and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 and hidden_states.shape[0] == 1 and not self.quantize ): diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 09c058f0..7a3d60c9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -41,6 +40,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 3ddcba8a..4c1be6f6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -31,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, Seqlen, + PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 200735c6..aca97004 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -39,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 75e43d88..df48c6f7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -321,12 +321,12 @@ class LlamaMLP(nn.Module): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" - and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 and hidden_states.shape[0] == 1 + and not self.quantize and self.hidden_size != 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed. - and not self.quantize ): out = torch.empty( hidden_states.shape[0], @@ -561,7 +561,6 @@ class FlashLlamaForCausalLM(torch.nn.Module): adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) - hidden_states = self.model( inputs_embeds, position_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d0503277..3e16d371 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -42,6 +41,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -302,7 +302,6 @@ class MistralMLP(nn.Module): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" - and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and not self.quantize diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index c0ffe036..5836d30a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from typing import List, Optional, Tuple, Type import torch @@ -40,6 +39,7 @@ from text_generation_server.layers.attention import ( paged_attention, reshape_and_cache, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 471abca3..ad4e382f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -40,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 4a18090a..2a0dc606 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -1,4 +1,3 @@ -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -20,6 +19,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 00e63a6c..02c788d3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -1,4 +1,3 @@ -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -18,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 2cf243e8..6671d85e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,6 +1,5 @@ from typing import List, Optional, Tuple -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed from torch import nn @@ -13,6 +12,7 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 0c1518e7..43eb9687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,4 +1,3 @@ -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -19,6 +18,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 22ac0240..4975cf22 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -18,7 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -40,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm,