This commit is contained in:
Mohit Sharma 2024-06-24 11:07:32 +00:00
parent fb83e3416b
commit f0d95b0f4b
3 changed files with 3 additions and 3 deletions

View File

@ -237,7 +237,7 @@ struct Args {
dtype: Option<Dtype>,
// Specify the data type for KV cache. By default, it uses the model's data type.
// CUDA 11.8+ supports `fp8(fp8_e4m3)` and 'fp8_e5m2', while ROCm (AMD GPU) supports `fp8(fp8_e4m3fn)'.
// CUDA 11.8+ supports `fp8(fp8_e4m3)` and 'fp8_e5m2', while ROCm (AMD GPU) supports `fp8(fp8_e4m3)'.
// If 'fp8_e4m3' is chosen, a model checkpoint with scales for the KV cache should be provided.
// If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy."
#[clap(long, env, value_enum)]

View File

@ -294,7 +294,7 @@ def get_model(
if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
raise RuntimeError(
f"kv_cache_dtype is only supported for {", ".join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
f"kv_cache_dtype is only supported for {', '.join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
)
speculator = None

View File

@ -3,11 +3,11 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
from server.text_generation_server.utils.import_utils import SYSTEM
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once