do not set sliding_window if SUPPORTS_WINDOWING is false

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2024-09-23 20:48:43 -07:00
parent 9263817c71
commit a05f3849e4

@@ -47,6 +47,7 @@ from text_generation_server.models.globals import (
     get_adapter_to_index,
 )
 from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import SUPPORTS_WINDOWING
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
@@ -992,6 +993,21 @@ class FlashCausalLM(Model):
         )
         prefix = ""
+        if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
+        text_config = getattr(config, "text_config", None)
+        if text_config:
+            if (
+                getattr(text_config, "sliding_window", None) is not None
+                and SUPPORTS_WINDOWING
+            ):
+                set_sliding_window(text_config.sliding_window)
+            else:
+                text_config.sliding_window = None
+
         model = model_class(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)
@@ -1000,11 +1016,6 @@ class FlashCausalLM(Model):
         if text_config is not None:
             config = text_config

-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
         # Validation is done in the model itself
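
The change moves the sliding-window check ahead of model construction and gates it on SUPPORTS_WINDOWING, so backends without windowed attention never see a sliding_window value. The following is a minimal, self-contained sketch of that gating pattern; the stub SUPPORTS_WINDOWING flag, the stand-in set_sliding_window helper, and the Config class are assumptions for illustration, not the library's actual definitions.

# Minimal sketch of the gating pattern introduced by this commit.
SUPPORTS_WINDOWING = False  # stub: pretend the attention backend lacks windowing

_sliding_window = None

def set_sliding_window(size):
    # Stand-in for the real set_sliding_window helper used in the diff.
    global _sliding_window
    _sliding_window = size

class Config:
    # Stand-in for a model config carrying sliding_window from config.json.
    sliding_window = 4096

config = Config()

# Honor the config's sliding_window only when the backend supports it;
# otherwise clear it so the model is built without windowed attention.
if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING:
    set_sliding_window(config.sliding_window)
else:
    config.sliding_window = None

assert config.sliding_window is None and _sliding_window is None

Running the check before model_class(prefix, config, weights), rather than after it as in the removed block, plausibly matters because the model may read config.sliding_window during construction; clearing the attribute afterwards would be too late.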