fix: prefer comparing model enum over str

This commit is contained in:
drbh 2024-07-24 21:13:07 +00:00
parent 9bfa340e34
commit 72c97676fd

View File

@ -298,8 +298,12 @@ class ModelType(enum.Enum):
"multimodal": True,
}
def __str__(self):
    """Return the ``"type"`` entry of this member's value mapping."""
    descriptor = self.value
    return descriptor["type"]
@classmethod
def from_str(cls, model_type: str) -> "ModelType":
    """Look up the enum member whose ``"type"`` entry equals *model_type*.

    Members are scanned in the enum's iteration order and the first match
    is returned.

    Raises:
        ValueError: if no member's ``value["type"]`` equals *model_type*.
    """
    found = next(
        (member for member in cls if member.value["type"] == model_type),
        None,
    )
    if found is None:
        raise ValueError(f"Unknown model type {model_type}")
    return found
def get_model(
@ -488,6 +492,9 @@ def get_model(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
# convert model_type to ModelType enum
model_type = ModelType.from_str(model_type)
if model_type == ModelType.DEEPSEEK_V2:
if FLASH_ATTENTION:
head_size = max(