fix merge issues

This commit is contained in:
fxmarty 2024-05-15 12:06:54 +00:00
parent 3b011ed3ea
commit c683597b42
4 changed files with 8 additions and 10 deletions

View File

@ -34,7 +34,7 @@ class FastLinear(torch.nn.Module):
return F.linear(input, self.weight, self.bias) return F.linear(input, self.weight, self.bias)
class FastLinearROCm(nn.Module): class FastLinearROCm(torch.nn.Module):
def __init__( def __init__(
self, self,
weight, weight,
@ -60,7 +60,7 @@ class FastLinearROCm(nn.Module):
weight = self.weight weight = self.weight
bias = self.bias bias = self.bias
if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1: if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
batched = False batched = False
if inp.dim() == 3: if inp.dim() == 3:

View File

@ -228,7 +228,7 @@ class LlamaMLP(nn.Module):
) )
def forward(self, hidden_states): def forward(self, hidden_states):
if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1: if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
out = torch.empty( out = torch.empty(
hidden_states.shape[0], hidden_states.shape[0],
self.intermediate_size, self.intermediate_size,

View File

@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -41,7 +41,7 @@ from text_generation_server.layers.layernorm import (
) )
if IS_ROCM_SYSTEM: if SYSTEM == "rocm":
try: try:
from vllm import _custom_C from vllm import _custom_C
except Exception as e: except Exception as e:
@ -289,7 +289,7 @@ class MistralMLP(nn.Module):
) )
def forward(self, hidden_states): def forward(self, hidden_states):
if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1: if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
out = torch.empty( out = torch.empty(
hidden_states.shape[0], hidden_states.shape[0],
self.intermediate_size, self.intermediate_size,

View File

@ -759,10 +759,8 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: empty_cache()
torch.cuda.empty_cache()
elif IS_XPU_SYSTEM:
torch.xpu.empty_cache()
try: try:
cache_manager = set_cache_manager( cache_manager = set_cache_manager(
batch.blocks, batch.blocks,