This commit is contained in:
Mohit Sharma 2024-06-24 11:07:32 +00:00
parent fb83e3416b
commit f0d95b0f4b
3 changed files with 3 additions and 3 deletions

View File

@ -237,7 +237,7 @@ struct Args {
dtype: Option<Dtype>, dtype: Option<Dtype>,
// Specify the data type for KV cache. By default, it uses the model's data type. // Specify the data type for KV cache. By default, it uses the model's data type.
// CUDA 11.8+ supports `fp8(fp8_e4m3)` and 'fp8_e5m2', while ROCm (AMD GPU) supports `fp8(fp8_e4m3fn)'. // CUDA 11.8+ supports `fp8(fp8_e4m3)` and 'fp8_e5m2', while ROCm (AMD GPU) supports `fp8(fp8_e4m3)'.
// If 'fp8_e4m3' is chosen, a model checkpoint with scales for the KV cache should be provided. // If 'fp8_e4m3' is chosen, a model checkpoint with scales for the KV cache should be provided.
// If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy." // If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy."
#[clap(long, env, value_enum)] #[clap(long, env, value_enum)]

View File

@ -294,7 +294,7 @@ def get_model(
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto": if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
raise RuntimeError( raise RuntimeError(
f"kv_cache_dtype is only supported for {", ".join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}" f"kv_cache_dtype is only supported for {', '.join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
) )
speculator = None speculator = None

View File

@ -3,11 +3,11 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError from safetensors import safe_open, SafetensorError
from server.text_generation_server.utils.import_utils import SYSTEM
import torch import torch
from loguru import logger from loguru import logger
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import json import json
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once