Mirror of https://github.com/huggingface/text-generation-inference.git
direct return in clamp on ipex, like rocm
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent f213012b08
commit b392362e9e
@@ -66,7 +66,7 @@ else:
     max_k: int

     def clamp(self, max):
-        if SYSTEM == "rocm":
+        if SYSTEM == "rocm" or SYSTEM == "ipex":
             return self
         raise NotImplementedError("Not implemented seqlen for paged")
         return Seqlen(torch.clamp(self.input_lengths, max=max))
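For illustration, a minimal, runnable sketch of how `clamp` behaves after this change. The `Seqlen` class is reduced to the fields visible in the hunk, and the hard-coded `SYSTEM` value is an assumed stand-in for text_generation_server's real platform detection:

```python
# Minimal sketch, not the real class: Seqlen is trimmed to the fields
# visible in the hunk, and SYSTEM is hard-coded instead of detected.
from dataclasses import dataclass

import torch

SYSTEM = "ipex"  # assumption: stand-in for the library's platform detection


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    max_k: int

    def clamp(self, max):
        # On rocm and (after this commit) ipex, clamp is a no-op:
        # the object is returned directly, with lengths unclamped.
        if SYSTEM == "rocm" or SYSTEM == "ipex":
            return self
        raise NotImplementedError("Not implemented seqlen for paged")


seqlen = Seqlen(input_lengths=torch.tensor([3, 7, 11]), max_k=11)
assert seqlen.clamp(max=8) is seqlen  # direct return, nothing clamped
```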
@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
-from text_generation_server.layers.attention import KVCache, Seqlen, SUPPORTS_WINDOWING
+from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
@@ -993,21 +993,6 @@ class FlashCausalLM(Model):
         )

         prefix = ""
-        if getattr(config, "sliding_window", None) is not None and SUPPORTS_WINDOWING:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        text_config = getattr(config, "text_config", None)
-        if text_config:
-            if (
-                getattr(text_config, "sliding_window", None) is not None
-                and SUPPORTS_WINDOWING
-            ):
-                set_sliding_window(text_config.sliding_window)
-            else:
-                text_config.sliding_window = None
-
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)

@@ -1016,6 +1001,11 @@ class FlashCausalLM(Model):
         if text_config is not None:
             config = text_config

+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
         # Validation is done in the model itself
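Taken together, the two FlashCausalLM hunks move the sliding-window setup so it runs once, after `config` has been swapped for `text_config` on multimodal models, and without the `SUPPORTS_WINDOWING` guard (hence the dropped import above). A minimal sketch of the resulting control flow, where `set_sliding_window` and the config objects are simplified stand-ins rather than the library's real ones:

```python
# Sketch under assumptions: set_sliding_window and the config objects
# are simplified stand-ins, not text_generation_server's real ones.
from types import SimpleNamespace

_SLIDING_WINDOW = None


def set_sliding_window(value):
    # Assumed stand-in for the library's global sliding-window setter.
    global _SLIDING_WINDOW
    _SLIDING_WINDOW = value


def resolve_sliding_window(config):
    # Multimodal configs nest the text fields under text_config; after
    # this commit the swap happens first, so one code path handles both.
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        config = text_config

    if getattr(config, "sliding_window", None) is not None:
        set_sliding_window(config.sliding_window)
    else:
        config.sliding_window = None
    return config


# Usage: a multimodal-style config resolves through its text_config.
cfg = SimpleNamespace(text_config=SimpleNamespace(sliding_window=4096))
resolve_sliding_window(cfg)
assert _SLIDING_WINDOW == 4096
```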