Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix: prefer comparing model enum over str
This commit is contained in:
parent: 9bfa340e34
commit: 72c97676fd
@ -298,8 +298,12 @@ class ModelType(enum.Enum):
|
|||||||
"multimodal": True,
|
"multimodal": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __str__(self):
|
@classmethod
|
||||||
return self.value["type"]
|
def from_str(cls, model_type: str) -> "ModelType":
|
||||||
|
for model in cls:
|
||||||
|
if model.value["type"] == model_type:
|
||||||
|
return model
|
||||||
|
raise ValueError(f"Unknown model type {model_type}")
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
@ -488,6 +492,9 @@ def get_model(
|
|||||||
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# convert model_type to ModelType enum
|
||||||
|
model_type = ModelType.from_str(model_type)
|
||||||
|
|
||||||
if model_type == ModelType.DEEPSEEK_V2:
|
if model_type == ModelType.DEEPSEEK_V2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
head_size = max(
|
head_size = max(
|
||||||
|
Loading…
Reference in New Issue
Block a user