address review

fxmarty 2024-06-07 12:36:32 +00:00
parent 4220423d76
commit b884d2b9b8
6 changed files with 27 additions and 90 deletions

View File

@@ -476,7 +476,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -550,8 +550,8 @@ fn shard_manager(
     }

     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
-    shard_args.push("--max-batch-prefill-tokens".to_string());
-    shard_args.push(max_batch_prefill_tokens.to_string());
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());

     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -1009,7 +1009,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1067,7 +1067,7 @@ fn spawn_shards(
             rope_factor,
             max_total_tokens,
             max_batch_size,
-            max_batch_prefill_tokens,
+            max_input_tokens,
             otlp_endpoint,
             max_log_level,
             status_sender,
@@ -1542,7 +1542,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
-        max_batch_prefill_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
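Net effect of the launcher changes above: each shard process is now started with `--max-input-tokens` instead of `--max-batch-prefill-tokens`. Below is a minimal sketch of the resulting shard command line (Python, illustrative only; the command name, model id, example value, and the omission of every other flag are assumptions, not taken from this diff).

# Illustrative only: the launcher now forwards --max-input-tokens to each shard.
max_input_tokens = 4096  # example value
shard_argv = [
    "text-generation-server",
    "serve",
    "mistralai/Mistral-7B-v0.1",  # placeholder model id
    "--max-input-tokens",
    str(max_input_tokens),
]
print(" ".join(shard_argv))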

View File

@@ -41,7 +41,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
-    max_batch_prefill_tokens: Optional[int] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,7 +98,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
-        max_batch_prefill_tokens,
+        max_input_tokens,
     )
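A minimal sketch of how the renamed parameter now threads from the CLI entry point down to model construction (illustrative only; the simplified signatures below are assumptions, the real functions take many more arguments).

# Illustrative only: simplified call chain after the rename.
from typing import Optional

def get_model(model_id: str, max_input_tokens: int) -> None:
    print(f"loading {model_id} with max_input_tokens={max_input_tokens}")

def server_serve(model_id: str, max_input_tokens: int) -> None:
    get_model(model_id, max_input_tokens)

def cli_serve(model_id: str, max_input_tokens: Optional[int] = None) -> None:
    # The launcher always passes --max-input-tokens (see the hunk above),
    # so None is only the standalone-CLI default.
    server_serve(model_id, max_input_tokens)

cli_serve("mistralai/Mistral-7B-v0.1", max_input_tokens=4096)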

View File

@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
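For reference, a standalone reconstruction of the guard removed from the ROCm (ck) and Triton paths above (illustrative only; the tensor layout used below is an assumption). The old guard keyed off `q.shape[0]`, i.e. the total token count of the current prefill, so it only fired once a long enough batch actually arrived at runtime, whereas the replacement check in `get_model` fails fast at model load.

import torch

# Illustrative reconstruction of the removed per-call guard (not the new code path).
def old_window_guard(q: torch.Tensor, window_size_left: int) -> None:
    if window_size_left != -1 and q.shape[0] > window_size_left:
        raise ValueError(
            "ROCm version of Flash Attention v2 does not support window attention "
            f"(window_size_left != -1, got window_size_left={window_size_left})."
        )

q = torch.zeros(8192, 32, 128)  # assumed (total_tokens, num_heads, head_dim) layout
try:
    old_window_guard(q, window_size_left=4096)
except ValueError as err:
    print(err)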

View File

@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1 and q.shape[0] > window_size_left:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,

View File

@@ -50,7 +50,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. TGI webserver was started max_batch_prefill_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-batch-prefill-tokens` smaller than {}."
+SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. TGI webserver was started with max_input_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-input-tokens` smaller than {}."

 FLASH_ATTENTION = True
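For illustration, how the updated template renders (the backend name and token counts below are made-up sample values, not taken from this diff).

# Illustrative only: sample rendering of the updated SLIDING_WINDOW_MESSAGE template.
SLIDING_WINDOW_MESSAGE = (
    "The backend {} does not support sliding window attention. TGI webserver was "
    "started with max_input_tokens={} larger than sliding_window={}. To use this "
    "model with the {} backend, please launch TGI with the argument "
    "`--max-input-tokens` smaller than {}."
)
print(SLIDING_WINDOW_MESSAGE.format("rocm", 8192, 4096, "rocm", 4096))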
@@ -261,7 +261,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -416,9 +416,13 @@ def get_model(
         )

     sliding_window = config_dict.get("sliding_window", -1)
-    if sliding_window != -1 and not SUPPORTS_WINDOWING:
-        logger.warning(
-            f"Flash attention is available, but on the backend {SYSTEM} doesn't support sliding window attention, which is required by model {model_id} for long contexts."
-        )
+    if (
+        (sliding_window is not None and sliding_window != -1)
+        and not SUPPORTS_WINDOWING
+        and max_input_tokens > sliding_window
+    ):
+        raise ValueError(
+            f"The backend {SYSTEM} does not support sliding window attention. TGI webserver was started with max_input_tokens={max_input_tokens} larger than sliding_window={sliding_window}. To use this model with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than {sliding_window}."
         )

     if model_type == MAMBA:
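In isolation, the consolidated load-time check introduced above looks like this (a runnable sketch; the stand-in values for the module's SYSTEM and SUPPORTS_WINDOWING globals and the sample token counts are assumptions).

# Illustrative only: standalone version of the new load-time sliding-window check.
SYSTEM = "rocm"             # stand-in for the detected backend
SUPPORTS_WINDOWING = False  # e.g. the ROCm, Triton and XPU paths touched above

def check_sliding_window(sliding_window, max_input_tokens):
    if (
        sliding_window is not None
        and sliding_window != -1
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
    ):
        raise ValueError(
            f"The backend {SYSTEM} does not support sliding window attention. "
            f"TGI webserver was started with max_input_tokens={max_input_tokens} "
            f"larger than sliding_window={sliding_window}."
        )

check_sliding_window(sliding_window=4096, max_input_tokens=2048)  # passes silently
# check_sliding_window(sliding_window=4096, max_input_tokens=8192)  # would raise ValueError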
@@ -706,22 +710,7 @@ def get_model(
             )

     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
                 revision,
@@ -743,22 +732,7 @@ def get_model(
             )

     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
                 revision,
@@ -780,21 +754,7 @@ def get_model(
             )

     if model_type == STARCODER2:
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -817,22 +777,7 @@ def get_model(
             )

     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif sliding_window is None or sliding_window != -1:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
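The net effect on the Mistral, Mixtral, Starcoder2 and Qwen2 branches, sketched below (model classes replaced by placeholder strings; illustrative only): each duplicated sliding-window block disappears and every branch reduces to the same FLASH_ATTENTION dispatch, because the single check earlier in get_model has already run.

# Illustrative only: what the per-model branches reduce to after this commit.
FLASH_ATTENTION = True  # assumed: set once at import time based on the backend

def build_model(model_type: str) -> str:
    if model_type in ("mistral", "mixtral", "starcoder2", "qwen2"):
        if FLASH_ATTENTION:
            return f"Flash{model_type.capitalize()}"  # stands in for FlashMistral(...) etc.
        return "non-flash fallback (unchanged branch, not shown in this diff)"
    raise NotImplementedError(model_type)

print(build_model("mistral"))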

View File

@@ -199,7 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -230,7 +230,7 @@ def serve(
                 speculate,
                 dtype,
                 trust_remote_code,
-                max_batch_prefill_tokens,
+                max_input_tokens,
             )
         except Exception:
             logger.exception("Error when initializing model")