fix: simplify changes and revert model changes

drbh 2024-08-02 19:01:58 +00:00
parent cf27954257
commit afc0fb5adf
5 changed files with 25 additions and 21 deletions

View File

@@ -484,7 +484,7 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)
-    if max_input_tokens <= sliding_window:
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
         sliding_window = -1
     if (
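Note on the hunk above: `max_input_tokens` may still be `None` at this point, and `None <= sliding_window` raises a `TypeError` on Python 3, so the comparison has to short-circuit on the `None` check first. A minimal sketch of the guarded logic (the helper name below is hypothetical, not code from the repo):

def resolve_sliding_window(max_input_tokens, sliding_window):
    # Hypothetical helper mirroring the guarded check in the hunk above.
    # The None test must come first: `None <= 4096` raises
    # TypeError: '<=' not supported between instances of 'NoneType' and 'int'.
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        return -1  # requested context fits inside the window, so disable windowing
    return sliding_window

assert resolve_sliding_window(None, 4096) == 4096    # guard short-circuits
assert resolve_sliding_window(2048, 4096) == -1      # window disabled
assert resolve_sliding_window(32768, 4096) == 4096   # window kept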

View File

@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -111,7 +110,9 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
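The reverted initializer keeps the pre-existing convention that `max_past` holds the sliding-window size, with `-1` meaning "no window". Flash-Attention v2 expresses the unbounded case the same way, as `window_size=(-1, -1)`. Below is a rough sketch of how such a sentinel is typically normalized before a kernel call; the adapter is illustrative only (it is not the repo's `attention` helper), and the exact left/right split for a finite window is a convention-dependent assumption:

def to_window_size(max_past):
    # Assumption: both None and -1 mean "attend over the full context".
    if max_past is None or max_past == -1:
        return (-1, -1)  # flash-attn v2 convention for an unbounded window
    # Assumption: allow up to `max_past` past tokens and no future tokens.
    return (max_past, 0)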
@@ -486,10 +487,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
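For context (not part of this diff): the model-level `max_past_tensor` built above is typically used at decode time to clamp the effective sequence lengths, since a sliding-window model only keeps the last `sliding_window` positions in its paged KV cache. A small sketch of that clamp with made-up lengths:

import torch

sliding_window = 4096
max_past_tensor = torch.tensor(sliding_window)  # built once in __init__, as in the hunk above

# Sequences longer than the window can only attend over the last
# `sliding_window` cached positions, so their lengths are clamped.
input_lengths = torch.tensor([10_000, 512, 4096])
print(torch.clamp(input_lengths, max=max_past_tensor))  # tensor([4096,  512, 4096])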

View File

@@ -35,7 +35,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -196,7 +195,9 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -614,10 +615,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

View File

@@ -9,7 +9,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -57,7 +56,9 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -348,10 +349,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

View File

@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -162,7 +161,9 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -508,10 +509,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )