From 72c97676fd0c89336cbf49ad840ccea555f6a325 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 24 Jul 2024 21:13:07 +0000 Subject: [PATCH] fix: prefer comparing model enum over str --- server/text_generation_server/models/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 6593229e..e6afe9e6 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -298,8 +298,12 @@ class ModelType(enum.Enum): "multimodal": True, } - def __str__(self): - return self.value["type"] + @classmethod + def from_str(cls, model_type: str) -> "ModelType": + for model in cls: + if model.value["type"] == model_type: + return model + raise ValueError(f"Unknown model type {model_type}") def get_model( @@ -488,6 +492,9 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) + # convert model_type to ModelType enum + model_type = ModelType.from_str(model_type) + if model_type == ModelType.DEEPSEEK_V2: if FLASH_ATTENTION: head_size = max(