mirror of https://github.com/huggingface/text-generation-inference.git

commit c683597b42 (parent 3b011ed3ea)

    fix merge issues

    Replaces the per-backend IS_CUDA_SYSTEM / IS_ROCM_SYSTEM / IS_XPU_SYSTEM
    flags with the SYSTEM string from import_utils, fully qualifies
    torch.nn.Module for FastLinearROCm, and routes warmup cache clearing
    through the shared empty_cache() helper.
@@ -34,7 +34,7 @@ class FastLinear(torch.nn.Module):
         return F.linear(input, self.weight, self.bias)


-class FastLinearROCm(nn.Module):
+class FastLinearROCm(torch.nn.Module):
     def __init__(
         self,
         weight,
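Side note on this hunk: the base class was referenced through a bare `nn` alias, which only resolves if the module does `from torch import nn`; fully qualifying it as `torch.nn.Module` removes that dependency, which is plausibly the merge issue being fixed. A minimal illustration:

```python
import torch

# With only `import torch` in scope, the bare alias is undefined:
#     class FastLinearROCm(nn.Module): ...
# raises NameError: name 'nn' is not defined.

# Fully qualifying the base class works with the plain import:
class FastLinearROCm(torch.nn.Module):
    pass
```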
@@ -60,7 +60,7 @@ class FastLinearROCm(nn.Module):
         weight = self.weight
         bias = self.bias

-        if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
             batched = False

             if inp.dim() == 3:
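The `inp.numel() // inp.size(-1) == 1` guard asks whether the product of all leading dimensions is 1, i.e. whether the input is effectively a single row regardless of rank. A quick check (shapes picked purely for illustration):

```python
import torch

def is_single_row(inp: torch.Tensor) -> bool:
    # numel() is the product of all dims; dividing by the last dim
    # leaves the product of the leading (batch/sequence) dims.
    return inp.numel() // inp.size(-1) == 1

print(is_single_row(torch.zeros(1, 1, 4096)))  # True: one decode token
print(is_single_row(torch.zeros(1, 4096)))     # True: one row
print(is_single_row(torch.zeros(8, 4096)))     # False: batched input
```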
@@ -228,7 +228,7 @@ class LlamaMLP(nn.Module):
         )

     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
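The hunk shows only the start of the ROCm fast path: an uninitialized buffer is allocated for a fused SiLU-GEMM kernel when decoding a single row. A runnable sketch of that dispatch, using a plain-PyTorch stand-in where the real code calls into vllm's `_custom_C` (the exact kernel name and signature are not visible in this diff):

```python
import torch
import torch.nn.functional as F

def mlp_forward(hidden_states, gate_up_weight, down_weight, system="rocm"):
    # Fast path: ROCm + SiLU activation + single-row (decode) input.
    if system == "rocm" and hidden_states.shape[0] == 1:
        intermediate_size = gate_up_weight.shape[0] // 2
        # torch.empty skips zero-initialization; the kernel overwrites it.
        out = torch.empty(
            hidden_states.shape[0],
            intermediate_size,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # Stand-in for the fused kernel: gate/up GEMM, then SiLU-and-multiply.
        gate, up = (hidden_states @ gate_up_weight.T).chunk(2, dim=-1)
        torch.mul(F.silu(gate), up, out=out)
        return out @ down_weight.T
    # Generic path used for batched inputs and non-ROCm systems.
    gate, up = (hidden_states @ gate_up_weight.T).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down_weight.T
```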
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

-from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
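This hunk swaps the `IS_ROCM_SYSTEM` boolean for a single `SYSTEM` string exported by `import_utils`. One plausible shape for that detection, offered as a sketch rather than the repository's exact code:

```python
import torch

# Sketch: one SYSTEM string instead of per-backend booleans.
# torch.version.hip is set on ROCm builds, torch.version.cuda on CUDA builds.
if torch.version.hip is not None:
    SYSTEM = "rocm"
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    SYSTEM = "xpu"
else:
    SYSTEM = "cpu"
```

Call sites then compare against one value (`SYSTEM == "rocm"`), and adding a new backend cannot leave two flags true at once.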
@@ -41,7 +41,7 @@ from text_generation_server.layers.layernorm import (
 )


-if IS_ROCM_SYSTEM:
+if SYSTEM == "rocm":
     try:
         from vllm import _custom_C
     except Exception as e:
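The guarded import at the end of this hunk is the usual optional-dependency pattern; the `except` body is cut off by the hunk boundary, so the completion below is an assumption, not the repository's wording:

```python
if SYSTEM == "rocm":
    try:
        from vllm import _custom_C  # optional fused ROCm kernels
    except Exception as e:
        # Assumed fallback: surface a clear error at import time instead
        # of a bare NameError deep inside the forward pass later on.
        raise ImportError(f"Could not load `vllm._custom_C`: {e}")
```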
@@ -289,7 +289,7 @@ class MistralMLP(nn.Module):
         )

     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
@@ -759,10 +759,8 @@ class FlashCausalLM(Model):

     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()

         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
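The warmup hunk collapses the removed per-backend branching into a single `empty_cache()` call, presumably imported from the same `import_utils` module as `SYSTEM`. A minimal sketch of such a helper, built from exactly the calls the removed lines used:

```python
import torch

def empty_cache():
    # Device-agnostic cache flush so warmup() carries no backend logic.
    # (Sketch; assumes the SYSTEM string from the earlier hunk.)
    if SYSTEM in ("cuda", "rocm"):
        torch.cuda.empty_cache()  # ROCm builds reuse the torch.cuda API
    elif SYSTEM == "xpu":
        torch.xpu.empty_cache()
```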