Mirror of https://github.com/huggingface/text-generation-inference.git
fix: prefer comparing model enum over str
commit 72c97676fd (parent 9bfa340e34)
@@ -298,8 +298,12 @@ class ModelType(enum.Enum):
        "multimodal": True,
    }

    def __str__(self):
        return self.value["type"]

    @classmethod
    def from_str(cls, model_type: str) -> "ModelType":
        for model in cls:
            if model.value["type"] == model_type:
                return model
        raise ValueError(f"Unknown model type {model_type}")


def get_model(
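For context, the from_str helper added above maps the model_type string reported by a model's configuration back to the corresponding ModelType member, so downstream checks can compare enum members instead of raw strings. A minimal, self-contained sketch of that pattern, assuming an enum whose members carry a metadata dict with a "type" key (the two members shown are illustrative, not the full set defined in TGI):

import enum


class ModelType(enum.Enum):
    # Illustrative members only; TGI defines many more, each with extra
    # metadata such as "name" and "url".
    DEEPSEEK_V2 = {"type": "deepseek_v2", "name": "Deepseek V2"}
    LLAMA = {"type": "llama", "name": "Llama"}

    def __str__(self):
        # The config-style string, e.g. "deepseek_v2".
        return self.value["type"]

    @classmethod
    def from_str(cls, model_type: str) -> "ModelType":
        # Linear scan over members; fails loudly on unknown strings.
        for model in cls:
            if model.value["type"] == model_type:
                return model
        raise ValueError(f"Unknown model type {model_type}")


assert ModelType.from_str("llama") is ModelType.LLAMA
assert str(ModelType.DEEPSEEK_V2) == "deepseek_v2"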
@@ -488,6 +492,9 @@ def get_model(
            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
        )

    # convert model_type to ModelType enum
    model_type = ModelType.from_str(model_type)

    if model_type == ModelType.DEEPSEEK_V2:
        if FLASH_ATTENTION:
            head_size = max(
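Building on the ModelType sketch above, the effect of the second hunk is that the config-supplied string is converted exactly once near the top of get_model, and every later branch compares enum members rather than strings; an unknown or misspelled model type now raises immediately in from_str instead of silently falling through each string comparison. A hypothetical dispatch sketch (the config payload and print are illustrative, not TGI code):

# Hypothetical config payload; in TGI the string ultimately comes from the
# model's config, but any source of a model_type string works here.
config_dict = {"model_type": "deepseek_v2"}

# Convert once at the boundary...
model_type = ModelType.from_str(config_dict["model_type"])

# ...then dispatch on enum members rather than strings.
if model_type == ModelType.DEEPSEEK_V2:
    print("dispatching to the Deepseek V2 code path")

# A typo such as "deepseek-v2" would raise ValueError("Unknown model type ...")
# up front instead of skipping every branch.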