mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix merge issues
This commit is contained in:
parent
3b011ed3ea
commit
c683597b42
@@ -34,7 +34,7 @@ class FastLinear(torch.nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
-class FastLinearROCm(nn.Module):
+class FastLinearROCm(torch.nn.Module):
    def __init__(
        self,
        weight,
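Note: the pre-merge line used the bare `nn.Module`, and the hunk context (`class FastLinear(torch.nn.Module)`) suggests this module references torch by its fully qualified name, so `from torch import nn` was presumably not in scope and the class statement itself would raise at import time. A minimal repro of that failure mode, in an illustrative module that only imports `torch`:

import torch

try:
    class Broken(nn.Module):  # NameError: bare `nn` was never imported here
        pass
except NameError as err:
    print(err)  # name 'nn' is not defined

class Fixed(torch.nn.Module):  # fully qualified name needs no extra alias
    pass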
@@ -60,7 +60,7 @@ class FastLinearROCm(nn.Module):
        weight = self.weight
        bias = self.bias
 
-        if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
+        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
            batched = False
 
            if inp.dim() == 3:
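The guard `inp.numel() // inp.size(-1) == 1` is a shape-agnostic way of asking whether the input holds exactly one row (a single decode token), whatever its rank. A standalone sketch of that check (the function name is illustrative, not from the diff):

import torch

def is_single_token(inp: torch.Tensor) -> bool:
    # total elements divided by the feature dimension = number of rows;
    # exactly one row means a batch-1, sequence-length-1 decode step
    return inp.numel() // inp.size(-1) == 1

print(is_single_token(torch.zeros(1, 1, 4096)))  # True: one decode token
print(is_single_token(torch.zeros(2, 1, 4096)))  # False: batch of two
print(is_single_token(torch.zeros(1, 8, 4096)))  # False: prefill of 8 tokens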
@@ -228,7 +228,7 @@ class LlamaMLP(nn.Module):
        )
 
    def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
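Both MLP hunks (here and in MistralMLP below) gate a ROCm-only fused kernel on `hidden_act == "silu"` and a single-token batch, with `torch.empty` pre-allocating the kernel's `(tokens, intermediate_size)` output buffer. For reference, a plain-PyTorch sketch of the SwiGLU computation such a fused gate-up + SiLU kernel replaces (parameter names here are illustrative assumptions, not the module's actual attributes):

import torch
import torch.nn.functional as F

def swiglu_mlp_reference(
    hidden_states: torch.Tensor,   # (tokens, hidden_size)
    gate_up_weight: torch.Tensor,  # (2 * intermediate_size, hidden_size)
    down_weight: torch.Tensor,     # (hidden_size, intermediate_size)
) -> torch.Tensor:
    # one packed projection yields the gate and up halves side by side
    gate_up = F.linear(hidden_states, gate_up_weight)
    gate, up = gate_up.chunk(2, dim=-1)
    # SiLU on the gate, elementwise product with up, then down-projection
    return F.linear(F.silu(gate) * up, down_weight)

On the hot single-token decode path, a fused kernel can perform the packed projection and activation in one launch, writing into the pre-allocated buffer instead of materializing intermediates.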
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
    TensorParallelRowLinear,
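This import swap is the theme of the whole commit: per-backend booleans (`IS_ROCM_SYSTEM`, `IS_CUDA_SYSTEM`, `IS_XPU_SYSTEM`) give way to a single `SYSTEM` string compared against `"rocm"`, `"cuda"`, and so on. A minimal sketch of how such a detector could look (an illustration, not the actual contents of `import_utils.py`):

import torch

def detect_system() -> str:
    # ROCm builds of PyTorch expose torch.version.hip; CUDA builds leave it None
    if torch.cuda.is_available():
        return "rocm" if torch.version.hip is not None else "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

SYSTEM = detect_system()

One string is easier to extend to new backends than a growing set of mutually exclusive flags, and string equality keeps every call site a one-line change.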
@@ -41,7 +41,7 @@ from text_generation_server.layers.layernorm import (
 )
 
 
-if IS_ROCM_SYSTEM:
+if SYSTEM == "rocm":
    try:
        from vllm import _custom_C
    except Exception as e:
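The `try`/`except` around `from vllm import _custom_C` keeps the ROCm fused kernels optional: if the compiled extension is absent, the server can fall back to the unfused path rather than dying at import. A generic sketch of that pattern (the logger and fallback helper are illustrative):

import logging

logger = logging.getLogger(__name__)

_custom_C = None  # fused ROCm kernels, loaded opportunistically
try:
    from vllm import _custom_C  # compiled extension; present only on ROCm builds
except Exception as e:
    logger.warning(f"Could not load `vllm._custom_C`: {e}")

def has_fused_kernels() -> bool:
    return _custom_C is not None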
@@ -289,7 +289,7 @@ class MistralMLP(nn.Module):
        )
 
    def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
@@ -759,10 +759,8 @@ class FlashCausalLM(Model):
 
    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()
+
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
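The warmup hunk collapses the per-backend cache-clearing branches into one `empty_cache()` call, presumably a helper imported alongside `SYSTEM` (the import itself falls outside this hunk). A sketch of what such a dispatcher could look like, assuming the `SYSTEM` string from above:

import torch

def empty_cache(system: str) -> None:
    # release cached allocator blocks on whichever accelerator is active
    if system in ("cuda", "rocm"):
        torch.cuda.empty_cache()  # ROCm builds reuse the torch.cuda namespace
    elif system == "xpu":
        torch.xpu.empty_cache()
    # "cpu": nothing to release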