diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 6593229e..e6afe9e6 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -298,8 +298,12 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }
 
-    def __str__(self):
-        return self.value["type"]
+    @classmethod
+    def from_str(cls, model_type: str) -> "ModelType":
+        for model in cls:
+            if model.value["type"] == model_type:
+                return model
+        raise ValueError(f"Unknown model type {model_type}")
 
 
 def get_model(
@@ -488,6 +492,9 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )
 
+    # convert model_type to ModelType enum
+    model_type = ModelType.from_str(model_type)
+
     if model_type == ModelType.DEEPSEEK_V2:
         if FLASH_ATTENTION:
             head_size = max(
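
For context, a minimal standalone sketch of the pattern this diff introduces: enum members that wrap metadata dicts, resolved from a config's `model_type` string through a `from_str` classmethod and compared as enum members instead of relying on the removed `__str__`. The two members and their dict values below are trimmed, illustrative placeholders, not the full TGI enum.

    import enum

    class ModelType(enum.Enum):
        # Trimmed, illustrative entries; the real enum carries many more
        # model types plus extra metadata fields.
        DEEPSEEK_V2 = {
            "type": "deepseek_v2",
            "name": "Deepseek V2",
        }
        LLAVA_NEXT = {
            "type": "llava_next",
            "name": "Llava Next (1.6)",
            "multimodal": True,
        }

        @classmethod
        def from_str(cls, model_type: str) -> "ModelType":
            # Look up by the canonical "type" key, mirroring the lookup
            # added in the diff above.
            for model in cls:
                if model.value["type"] == model_type:
                    return model
            raise ValueError(f"Unknown model type {model_type}")

    # model_type typically comes from the model's config.json; unknown
    # strings now fail fast instead of falling through to a bad comparison.
    assert ModelType.from_str("deepseek_v2") is ModelType.DEEPSEEK_V2

With this change, checks like `model_type == ModelType.DEEPSEEK_V2` compare enum members directly rather than strings that depended on `__str__`, and an unrecognized model type surfaces as an explicit ValueError early in `get_model`.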