From 88153796e0afd5f0cdd94f5276f8f8420e01c773 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Mon, 15 Apr 2024 22:38:22 -0700
Subject: [PATCH] re-enable xpu

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/flash_dbrx_modeling.py    | 4 +++-
 .../models/custom_modeling/flash_mixtral_modeling.py | 5 ++++-
 server/text_generation_server/models/globals.py      | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index d04ce39e..d0978bef 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -21,8 +21,10 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
-from vllm.model_executor.layers.fused_moe import fused_moe
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index be8cb965..3f6c8e03 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -24,7 +24,10 @@ import torch.distributed
 import numpy as np
 
 from torch import nn
-from vllm.model_executor.layers.fused_moe import fused_moe
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 49d617b5..b92aa65b 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,7 +4,7 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None and cuda_graphs != "0":
+if torch.cuda.is_available() and cuda_graphs is not None and cuda_graphs != "0":
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
     except Exception as e:
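
The guards above rely on IS_XPU_SYSTEM from
text_generation_server.utils.import_utils, so these modeling files can be
imported on Intel XPU systems where the vllm package (and its fused_moe
kernel) is not installed. A minimal sketch of how such a flag could be
derived, assuming XPU support is surfaced as torch.xpu by
intel-extension-for-pytorch; the repository's actual detection logic may
differ:

    # Sketch of text_generation_server/utils/import_utils.py (assumed,
    # not the repository's actual implementation).
    import torch

    # torch.xpu is only present when an XPU-enabled build (e.g.
    # intel-extension-for-pytorch) is installed; hasattr() keeps plain
    # CUDA/CPU builds of torch from raising AttributeError.
    IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()

The globals.py hunk applies the same guard at runtime rather than import
time: the CUDA_GRAPHS environment variable is now ignored when
torch.cuda.is_available() is false, matching the existing guard on the
MEM_POOL line just above it.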