diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 90978d6a..c1e4bcf7 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -34,7 +34,7 @@ class FastLinear(torch.nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
-class FastLinearROCm(nn.Module):
+class FastLinearROCm(torch.nn.Module):
     def __init__(
         self,
         weight,
@@ -60,7 +60,7 @@ class FastLinearROCm(nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
             batched = False
 
             if inp.dim() == 3:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 760bb408..654e12f7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -228,7 +228,7 @@ class LlamaMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 3143c941..1532757f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -41,7 +41,7 @@ from text_generation_server.layers.layernorm import (
 )
 
 
-if IS_ROCM_SYSTEM:
+if SYSTEM == "rocm":
     try:
         from vllm import _custom_C
     except Exception as e:
@@ -289,7 +289,7 @@ class MistralMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8c56ffc3..301648e2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -759,10 +759,8 @@ class FlashCausalLM(Model):
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()
+
        try:
             cache_manager = set_cache_manager(
                 batch.blocks,
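For context, the call sites changed above only rely on two names exported from `text_generation_server.utils.import_utils`: a `SYSTEM` string (compared against `"rocm"`) and a device-agnostic `empty_cache()` helper. The diff does not show how that module defines them, so the sketch below is an assumption about one plausible implementation; the `_detect_system` helper and the detection logic are illustrative, not taken from the repository.

```python
# Sketch only: SYSTEM and empty_cache() come from the diff's import sites, but
# the detection logic here is an assumed implementation, not the actual
# contents of text_generation_server/utils/import_utils.py.
import torch


def _detect_system() -> str:
    # ROCm builds of PyTorch report a HIP version; CUDA builds do not.
    if torch.version.hip is not None:
        return "rocm"
    if torch.cuda.is_available():
        return "cuda"
    # XPU support is optional, so guard the attribute lookup.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


SYSTEM = _detect_system()


def empty_cache() -> None:
    # Both CUDA and ROCm builds expose the cache through the torch.cuda namespace.
    if SYSTEM in ("cuda", "rocm"):
        torch.cuda.empty_cache()
    elif SYSTEM == "xpu":
        torch.xpu.empty_cache()
    # On CPU there is no device cache to release.
```

Centralizing the dispatch behind helpers like these is what lets `warmup` in `flash_causal_lm.py` replace its per-backend `IS_CUDA_SYSTEM` / `IS_ROCM_SYSTEM` / `IS_XPU_SYSTEM` branching with a single `empty_cache()` call.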