Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit: b884d2b9b8 ("address review")
Parent: 4220423d76
@@ -476,7 +476,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -550,8 +550,8 @@ fn shard_manager(
     }

     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
-    shard_args.push("--max-batch-prefill-tokens".to_string());
-    shard_args.push(max_batch_prefill_tokens.to_string());
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());

     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -1009,7 +1009,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
-    max_batch_prefill_tokens: u32,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1067,7 +1067,7 @@ fn spawn_shards(
             rope_factor,
             max_total_tokens,
             max_batch_size,
-            max_batch_prefill_tokens,
+            max_input_tokens,
             otlp_endpoint,
             max_log_level,
             status_sender,
@@ -1542,7 +1542,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
-        max_batch_prefill_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
|
@@ -41,7 +41,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
-    max_batch_prefill_tokens: Optional[int] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,7 +98,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
-        max_batch_prefill_tokens,
+        max_input_tokens,
     )

|
@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1 and q.shape[0] > window_size_left:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
|
@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1 and q.shape[0] > window_size_left:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,
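
All three kernel wrappers above (the "ck" and "triton" ROCm paths and the XPU path) drop the same per-call check and now rely on the model-load guard instead. A minimal, self-contained sketch of the resulting wrapper shape, using hypothetical names rather than the actual TGI kernels:

# Sketch only: the wrapper keeps window_size_left in its signature for API
# compatibility, but no longer validates it per call; model load is now the
# single place where unsupported sliding-window setups are rejected.
from typing import Callable


def make_attention(backend_varlen_attention: Callable):
    def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, window_size_left=-1):
        # No window check here: it is already enforced ahead of time at model load.
        return backend_varlen_attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)

    return attention
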
|
@@ -50,7 +50,7 @@ __all__ = [

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. TGI webserver was started max_batch_prefill_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-batch-prefill-tokens` smaller than {}."
+SLIDING_WINDOW_MESSAGE = "The backend {} does not support sliding window attention. TGI webserver was started max_input_tokens={} larger than sliding_window={}. To use this model with the {} backend, please launch TGI with the argument `--max-batch-prefill-tokens` smaller than {}."

 FLASH_ATTENTION = True

@@ -261,7 +261,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -416,9 +416,13 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)

-    if sliding_window != -1 and not SUPPORTS_WINDOWING:
-        logger.warning(
-            f"Flash attention is available, but on the backend {SYSTEM} doesn't support sliding window attention, which is required by model {model_id} for long contexts."
+    if (
+        (sliding_window is not None and sliding_window != -1)
+        and not SUPPORTS_WINDOWING
+        and max_input_tokens > sliding_window
+    ):
+        raise ValueError(
+            f"The backend {SYSTEM} does not support sliding window attention. TGI webserver was started max_input_tokens={max_input_tokens} larger than sliding_window={sliding_window}. To use this model with the {SYSTEM} backend, please launch TGI with the argument `--max-batch-prefill-tokens` smaller than {sliding_window}."
         )

     if model_type == MAMBA:
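
The hunk above replaces the old per-architecture warning with a single hard failure at model load. A minimal, self-contained sketch of that consolidated guard; the function name, parameter names, and the example values are illustrative only, with the module globals from the diff (SYSTEM, SUPPORTS_WINDOWING) passed in as arguments so the sketch runs on its own:

# Sketch only: standalone version of the load-time sliding-window guard.
from typing import Optional


def check_sliding_window(
    sliding_window: Optional[int],
    max_input_tokens: int,
    system: str,
    supports_windowing: bool,
) -> None:
    # The model uses sliding-window attention when the config sets a real window size.
    uses_window = sliding_window is not None and sliding_window != -1
    # Fail only when prompts could actually exceed the window on a backend that
    # cannot enforce it; shorter inputs stay correct without windowing support.
    if uses_window and not supports_windowing and max_input_tokens > sliding_window:
        raise ValueError(
            f"The backend {system} does not support sliding window attention. "
            f"max_input_tokens={max_input_tokens} is larger than sliding_window={sliding_window}."
        )


# Example: a Mistral-style config with sliding_window=4096 on a backend without
# windowing support passes with short inputs and fails with long ones.
check_sliding_window(4096, 2048, "rocm", supports_windowing=False)   # no error
# check_sliding_window(4096, 8192, "rocm", supports_windowing=False) # raises ValueError
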
@@ -706,22 +710,7 @@ def get_model(
         )

     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
                 revision,
@@ -743,22 +732,7 @@ def get_model(
         )

     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
                 revision,
@@ -780,21 +754,7 @@ def get_model(
         )

     if model_type == STARCODER2:
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif FLASH_ATTENTION:
+        if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -817,22 +777,7 @@ def get_model(
         )

     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (
-            (sliding_window is not None and sliding_window != -1)
-            and not SUPPORTS_WINDOWING
-            and max_batch_prefill_tokens > sliding_window
-        ):
-            raise ValueError(
-                SLIDING_WINDOW_MESSAGE.format(
-                    SYSTEM,
-                    max_batch_prefill_tokens,
-                    sliding_window,
-                    SYSTEM,
-                    sliding_window,
-                )
-            )
-        elif sliding_window is None or sliding_window != -1:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
|
@@ -199,7 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
-    max_batch_prefill_tokens: int,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -230,7 +230,7 @@ def serve(
             speculate,
             dtype,
             trust_remote_code,
-            max_batch_prefill_tokens,
+            max_input_tokens,
         )
     except Exception:
         logger.exception("Error when initializing model")
|