Return `self` directly in `clamp` for ipex, as is already done for rocm

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-10-10 23:02:56 -07:00
parent f213012b08
commit b392362e9e
2 changed files with 7 additions and 17 deletions

View File

@@ -66,7 +66,7 @@ else:
max_k: int max_k: int
def clamp(self, max): def clamp(self, max):
if SYSTEM == "rocm": if SYSTEM == "rocm" or SYSTEM == "ipex":
return self return self
raise NotImplementedError("Not implemented seqlen for paged") raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max)) return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
TGI_WIGGLE_ROOM, TGI_WIGGLE_ROOM,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import KVCache, Seqlen, SUPPORTS_WINDOWING from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.quantization import get_loader
@@ -993,21 +993,6 @@ class FlashCausalLM(Model):
) )
prefix = "" prefix = ""
if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
text_config = getattr(config, "text_config", None)
if text_config:
if (
getattr(text_config, "sliding_window", None) is not None
and SUPPORTS_WINDOWING
):
set_sliding_window(text_config.sliding_window)
else:
text_config.sliding_window = None
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@@ -1016,6 +1001,11 @@ class FlashCausalLM(Model):
if text_config is not None: if text_config is not None:
config = text_config config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads // self.process_group.size() self.num_heads = config.num_attention_heads // self.process_group.size()
# Validation is done in the model itself # Validation is done in the model itself