Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00

fix: simplify changes and revert model changes

commit afc0fb5adf
parent cf27954257
@@ -484,7 +484,7 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)

-    if max_input_tokens <= sliding_window:
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
         sliding_window = -1

     if (
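The hunk above is the "simplify" part of the commit: the added guard only compares max_input_tokens against sliding_window when it is actually set, since the diff implies max_input_tokens may be None (and comparing None with an int raises TypeError in Python 3). Below is a minimal sketch of that reverted check; resolve_sliding_window is a hypothetical helper standing in for the inline code in get_model, not a function in the repository.

# Sketch (assumption): the guarded sliding-window check from get_model,
# extracted into a hypothetical helper for illustration.
from typing import Optional


def resolve_sliding_window(
    max_input_tokens: Optional[int], sliding_window: int
) -> int:
    # Only disable windowing (-1) when the input budget is known and
    # already fits inside the configured window.
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        sliding_window = -1
    return sliding_window


# Unknown input budget: keep the window as configured.
assert resolve_sliding_window(None, 4096) == 4096
# Prompts capped below the window: windowing is unnecessary.
assert resolve_sliding_window(2048, 4096) == -1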
@@ -31,7 +31,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -111,7 +110,9 @@ class MistralConfig(PretrainedConfig):
 class MistralAttention(torch.nn.Module):
     def __init__(self, prefix: str, config, weights, layer_id):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         if hasattr(config, "head_dim"):
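This is the attention-level part of the revert (presumably flash_mistral_modeling.py); the same change is repeated for MixtralAttention, Qwen2Attention and Starcoder2Attention further down: instead of gating max_past on SUPPORTS_WINDOWING, each layer falls back to -1 when the config defines no sliding window. A minimal sketch of the reverted initializer follows; SimpleConfig and SketchAttention are hypothetical stand-ins, not classes from the repository.

# Sketch (assumption): the reverted max_past initialisation shared by the
# Mistral/Mixtral/Qwen2/Starcoder2 attention modules in this commit.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SimpleConfig:
    sliding_window: Optional[int] = None
    num_attention_heads: int = 32
    hidden_size: int = 4096


class SketchAttention(torch.nn.Module):
    def __init__(self, config: SimpleConfig):
        super().__init__()
        # Reverted form: default to -1 (no window) instead of consulting
        # the SUPPORTS_WINDOWING capability flag.
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads


assert SketchAttention(SimpleConfig()).max_past == -1
assert SketchAttention(SimpleConfig(sliding_window=4096)).max_past == 4096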
@@ -486,10 +487,10 @@ class FlashMistralForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

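At the causal-LM level the revert keeps max_past as the raw config.sliding_window (possibly None) and only materialises max_past_tensor when a window is actually configured; FlashMixtralForCausalLM, Qwen2ForCausalLM and FlashStarcoder2ForCausalLM below receive the identical change. A hedged sketch of that setup; build_max_past is a hypothetical helper and the device is hard-coded to CPU, whereas the real code takes weights.device.

# Sketch (assumption): the reverted max_past / max_past_tensor setup used
# by the causal-LM wrappers touched in this commit.
from typing import Optional

import torch


def build_max_past(sliding_window: Optional[int], device: torch.device):
    max_past = sliding_window
    max_past_tensor = (
        torch.tensor(sliding_window, device=device)
        if max_past is not None
        else None
    )
    return max_past, max_past_tensor


# No sliding window configured: both values stay None.
assert build_max_past(None, torch.device("cpu")) == (None, None)
# Window configured: a scalar tensor is created on the target device.
mp, mp_t = build_max_past(4096, torch.device("cpu"))
assert mp == 4096 and int(mp_t) == 4096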
@@ -35,7 +35,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -196,7 +195,9 @@ class MixtralAttention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -614,10 +615,10 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
        self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

@@ -9,7 +9,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -57,7 +56,9 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -348,10 +349,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
             prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )

@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
-    SUPPORTS_WINDOWING,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -162,7 +161,9 @@ class Starcoder2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
@@ -508,10 +509,10 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             weights=weights,
         )

-        self.max_past = config.sliding_window if SUPPORTS_WINDOWING else None
+        self.max_past = config.sliding_window
         self.max_past_tensor = (
-            torch.tensor(self.max_past, device=weights.device)
-            if self.max_past
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
             else None
         )
