Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
init

parent 715b2d19ed
commit e93ab925f9
@@ -12,7 +12,7 @@ import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
+from transformers.models.auto import modeling_auto, modeling_task
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
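The new `modeling_task` import backs the `except KeyError` fallback added in the next hunk, while `modeling_auto` is what maps a config's `model_type` to a concrete `*ForCausalLM` class. Below is a minimal sketch of that lookup, assuming a stock `transformers` install; the `model_type` value is a hypothetical example, and `modeling_task.AutoForCausalLM` only exists on the branch this commit targets, so the fallback is stubbed out here.

import transformers
from transformers.models.auto import modeling_auto

model_type = "llama"  # hypothetical value, normally read from config.json

try:
    # model_type -> class name (e.g. "LlamaForCausalLM"), then resolve that
    # name on the top-level transformers namespace, as get_model() does.
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
    transformers_model_class = getattr(transformers, class_name)
except KeyError:
    # The commit routes unknown model types to modeling_task.AutoForCausalLM;
    # that module is not part of released transformers, so just flag it here.
    transformers_model_class = None

print(transformers_model_class)  # e.g. transformers.models.llama.modeling_llama.LlamaForCausalLM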
@@ -375,30 +375,26 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
-    # transformers_causal_lm_class = CausalLM
-    transformers_causal_lm_class = TransformersFlashCausalLM
-    if (
-        not USE_CUSTOM_MODELING
-        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-    ):
+    transformers_causal_lm_class = CausalLM
+    if not USE_CUSTOM_MODELING:
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        transformers_model_class = getattr(
-            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
-        )
+        try:
+            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        except KeyError:
+            transformers_model_class = modeling_task.AutoForCausalLM
 
-        if (
-            transformers_model_class._supports_flash_attn_2
-            and transformers_model_class._supports_cache_class
-        ):
+        if transformers_model_class._supports_flash_attn_2:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
+                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
             )
             transformers_causal_lm_class = TransformersFlashCausalLM
         else:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+                f"Transformers' {model_type} implementation does not support the ragged tensors format. Will use classic "
+                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
             )
 
     quantization_config = config_dict.get("quantization_config", None)
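Taken together, the hunk above makes the Transformers fallback default to the padded CausalLM wrapper and only upgrades to TransformersFlashCausalLM when the resolved class advertises flash-attention support. A hedged sketch of that decision follows, assuming a stock `transformers` install; `pick_causal_lm_backend` and its return strings are illustrative stand-ins, not TGI APIs.

import transformers
from transformers.models.auto import modeling_auto


def pick_causal_lm_backend(model_type: str, use_custom_modeling: bool) -> str:
    """Illustrative stand-in: report which wrapper the hunk above would pick."""
    if use_custom_modeling:
        return "TGI custom flash model"  # handled before the fallback path

    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type)
    if class_name is None:
        # The commit sends this case to modeling_task.AutoForCausalLM, which is
        # branch-specific; treat it as the padded path in this sketch.
        return "CausalLM (padded)"

    model_class = getattr(transformers, class_name)
    if getattr(model_class, "_supports_flash_attn_2", False):
        # Ragged tensors: a single dimension for batch and sequence length.
        return "TransformersFlashCausalLM (ragged)"
    # Classic padded tensors: separate batch and sequence-length dimensions.
    return "CausalLM (padded)"


print(pick_causal_lm_backend("llama", use_custom_modeling=False))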