fix: simplify changes and revert model changes

drbh 2024-08-02 19:01:58 +00:00
parent cf27954257
commit afc0fb5adf
5 changed files with 25 additions and 21 deletions

View File

@@ -484,7 +484,7 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)
-    if max_input_tokens <= sliding_window:
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
         sliding_window = -1
     if (
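
The added is not None guard matters because max_input_tokens is an optional value and may be unset; comparing None against an int raises a TypeError in Python 3. A minimal sketch of the intended behaviour (resolve_sliding_window is a hypothetical helper written for illustration, not code from this repository):

def resolve_sliding_window(max_input_tokens, config_dict):
    # Mirrors the get_model logic: only disable the window (-1) when a
    # finite max_input_tokens already fits inside it.
    sliding_window = config_dict.get("sliding_window", -1)
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1
    return sliding_window

assert resolve_sliding_window(None, {"sliding_window": 4096}) == 4096
assert resolve_sliding_window(1024, {"sliding_window": 4096}) == -1
assert resolve_sliding_window(8192, {"sliding_window": 4096}) == 4096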

View File

@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -111,7 +110,9 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
@@ -486,10 +487,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
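
With SUPPORTS_WINDOWING reverted out of the model files, max_past is taken straight from config.sliding_window and max_past_tensor is only materialized when a window is actually configured. A small self-contained sketch of that pattern (WindowedLM is an illustrative name, not a class from this repository):

import torch

class WindowedLM:
    # Illustrative only: mirrors the max_past / max_past_tensor setup in
    # FlashMistralForCausalLM (and the other models) after this commit.
    def __init__(self, sliding_window, device="cpu"):
        self.max_past = sliding_window
        self.max_past_tensor = (
            torch.tensor(sliding_window, device=device)
            if self.max_past is not None
            else None
        )

m = WindowedLM(4096)
assert int(m.max_past_tensor) == 4096  # window configured -> clamp tensor exists
m = WindowedLM(None)
assert m.max_past is None and m.max_past_tensor is None  # no window -> no clamping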

View File

@@ -35,7 +35,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -196,7 +195,9 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -614,10 +615,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

View File

@@ -9,7 +9,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -57,7 +56,9 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -348,10 +349,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

View File

@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -162,7 +161,9 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -508,10 +509,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )