fix merge issues

This commit is contained in:
fxmarty 2024-05-15 12:06:54 +00:00
parent 3b011ed3ea
commit c683597b42
4 changed files with 8 additions and 10 deletions

View File

@ -34,7 +34,7 @@ class FastLinear(torch.nn.Module):
return F.linear(input, self.weight, self.bias)
class FastLinearROCm(nn.Module):
class FastLinearROCm(torch.nn.Module):
def __init__(
self,
weight,
@ -60,7 +60,7 @@ class FastLinearROCm(nn.Module):
weight = self.weight
bias = self.bias
if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
batched = False
if inp.dim() == 3:

View File

@ -228,7 +228,7 @@ class LlamaMLP(nn.Module):
)
def forward(self, hidden_states):
if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,

View File

@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -41,7 +41,7 @@ from text_generation_server.layers.layernorm import (
)
if IS_ROCM_SYSTEM:
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
@ -289,7 +289,7 @@ class MistralMLP(nn.Module):
)
def forward(self, hidden_states):
if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,

View File

@ -759,10 +759,8 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.empty_cache()
elif IS_XPU_SYSTEM:
torch.xpu.empty_cache()
empty_cache()
try:
cache_manager = set_cache_manager(
batch.blocks,