fix various merge errors

fxmarty 2024-05-15 12:20:48 +00:00
parent c683597b42
commit b7e98ba635
5 changed files with 7 additions and 5 deletions

@@ -41,9 +41,9 @@ class FastLinearROCm(torch.nn.Module):
         bias,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(weight)
+        self.weight = torch.nn.Parameter(weight)
         if bias is not None:
-            self.bias = nn.Parameter(bias)
+            self.bias = torch.nn.Parameter(bias)
         else:
             self.bias = None

@@ -48,6 +48,7 @@ if SYSTEM == "rocm":
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
         return TensorParallelColumnLinear.load_multi(
             config,

@@ -47,7 +47,7 @@ from text_generation_server.models.custom_modeling.idefics_vision import (
 from text_generation_server.models.custom_modeling.idefics_perceiver import (
     IdeficsPerceiverResampler,
 )
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,

@@ -1150,8 +1150,6 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0
-            logger.info(f"Accepted ids {n_accepted_ids}")
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token

@@ -5,12 +5,15 @@ from loguru import logger
 import math
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.flash_attn_triton import triton_attention
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
+ROCM_USE_FLASH_ATTN_V2_CK = False
+ROCM_USE_FLASH_ATTN_V2_TRITON = False
 if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex