From e5dfd41ed459d5ef91fc23910daa03d0f34bdef0 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 14 Mar 2025 17:06:36 +0100
Subject: [PATCH] Upgrading `from_env` to get token from file when necessary + fix pali_gemma.

---
 launcher/src/main.rs                                |  7 +++----
 router/src/server.rs                                |  2 +-
 .../custom_modeling/flash_pali_gemma_modeling.py    |  2 +-
 .../models/custom_modeling/flash_qwen2_modeling.py  | 16 ++++++++--------
 4 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d1041e26e..212f8a3b4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -97,11 +97,10 @@ fn get_config(
     let filename = if !path.exists() {
         // Assume it's a hub id
 
-        let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
+        let mut builder = ApiBuilder::from_env();
+        if let Ok(token) = std::env::var("HF_TOKEN") {
             // env variable has precedence over on file token.
-            ApiBuilder::new().with_token(Some(token))
-        } else {
-            ApiBuilder::new()
+            builder = builder.with_token(Some(token))
         };
         if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
             builder = builder.with_user_agent("origin", origin.as_str());
diff --git a/router/src/server.rs b/router/src/server.rs
index 0346b1f19..45d2b9f3c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1522,7 +1522,7 @@ pub async fn run(
 
     // Shared API builder initialization
     let api_builder = || {
-        let mut builder = ApiBuilder::new().with_progress(false);
+        let mut builder = ApiBuilder::from_env().with_progress(false);
         if let Some(token) = authorization_token {
             builder = builder.with_token(Some(token));
         }
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 4ea604510..b1f89eff4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
-            prefix="vision_model" if not prefix else f"{prefix}.vision_model",
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 75d519e45..c06e5dcce 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
+        self.window_size = (
             config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
@@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )
         # Decode
         else:
@@ -185,7 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )
 
         return self.o_proj(
@@ -406,8 +406,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
 
-        self.max_past = config.sliding_window
-        self.max_past_tensor = (
+        self.window_size = config.sliding_window
+        self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            if self.window_size is not None
             else None
@@ -434,7 +434,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
-        elif self.max_past is not None:
+        elif self.window_size is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.window_size_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
 
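
Note (not part of the patch): a minimal sketch of the token-resolution behavior the launcher change relies on, assuming hf-hub's sync API; the `fetch_config` helper name and the choice of "config.json" are illustrative, not taken from the launcher.

use hf_hub::api::sync::ApiBuilder;

// Sketch only: per this patch, ApiBuilder::from_env() seeds the builder from the
// environment and the cached token file (e.g. the one written by `huggingface-cli login`),
// unlike ApiBuilder::new(), which starts without any token.
fn fetch_config(model_id: &str) -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
    let mut builder = ApiBuilder::from_env();
    if let Ok(token) = std::env::var("HF_TOKEN") {
        // The HF_TOKEN env variable still takes precedence over the on-file token.
        builder = builder.with_token(Some(token));
    }
    let api = builder.build()?;
    let repo = api.model(model_id.to_string());
    Ok(repo.get("config.json")?)
}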