diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index c1e4bcf7..9fb2b596 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -41,9 +41,9 @@ class FastLinearROCm(torch.nn.Module):
         bias,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(weight)
+        self.weight = torch.nn.Parameter(weight)
         if bias is not None:
-            self.bias = nn.Parameter(bias)
+            self.bias = torch.nn.Parameter(bias)
         else:
             self.bias = None
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 654e12f7..df8d1f52 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -48,6 +48,7 @@ if SYSTEM == "rocm":
 
 
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
         return TensorParallelColumnLinear.load_multi(
             config,
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index dfddf2fb..3b06f737 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -47,7 +47,7 @@ from text_generation_server.models.custom_modeling.idefics_vision import (
 from text_generation_server.models.custom_modeling.idefics_perceiver import (
     IdeficsPerceiverResampler,
 )
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 301648e2..a7f3470d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1150,8 +1150,6 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0
 
-            logger.info(f"Accepted ids {n_accepted_ids}")
-
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index f63a75d3..3820e861 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -5,12 +5,15 @@ from loguru import logger
 import math
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.flash_attn_triton import triton_attention
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
+ROCM_USE_FLASH_ATTN_V2_CK = False
+ROCM_USE_FLASH_ATTN_V2_TRITON = False
 
 if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex