Mirror of https://github.com/huggingface/text-generation-inference.git
Change default values
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent: 3c6630c6e9
commit: c6f97fd884
@@ -110,8 +110,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         token_idx: Optional[torch.Tensor] = None,
-        use_flash_attention: Optional[bool] = False,
-        flash_attention_recompute: Optional[bool] = False,
+        use_flash_attention: Optional[bool] = True,
+        flash_attention_recompute: Optional[bool] = True,
     ):

         if token_idx is not None:
@@ -337,8 +337,8 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
                **kwargs,
            )
        else:
-           use_flash_attention = kwargs.get("use_flash_attention", False)
-           flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+           use_flash_attention = kwargs.get("use_flash_attention", True)
+           flash_attention_recompute = kwargs.get("flash_attention_recompute", True)

           position_ids = kwargs.get("position_ids", None)
           labels = kwargs.get("labels", None)
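In effect, both the forward signature defaults and the kwargs fallbacks now opt into Gaudi flash attention (with recompute) unless the caller explicitly disables it. A minimal sketch of the new behavior follows; the helper name `resolve_attention_flags` is hypothetical, for illustration only, and is not part of this commit.

# Sketch of the default flip in this commit. `resolve_attention_flags` is a
# hypothetical helper, not a function in the repository.
from typing import Any, Dict, Tuple

def resolve_attention_flags(kwargs: Dict[str, Any]) -> Tuple[bool, bool]:
    # After this commit, omitting a flag opts IN (both defaults were False before).
    use_flash_attention = kwargs.get("use_flash_attention", True)
    flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
    return use_flash_attention, flash_attention_recompute

# A caller that passes neither kwarg now gets flash attention with recompute:
assert resolve_attention_flags({}) == (True, True)
# An explicit opt-out still works:
assert resolve_attention_flags({"use_flash_attention": False}) == (False, True)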