Upgrading from_env to get token from file when necessary + fix pali_gemma.
Nicolas Patry 2025-03-14 17:06:36 +01:00
parent 659ce4f3fc
commit e5dfd41ed4
GPG Key ID: 4242CEF24CB6DBF9 (no known key found for this signature in database)
4 changed files with 11 additions and 12 deletions

View File

@@ -97,11 +97,10 @@ fn get_config(
     let filename = if !path.exists() {
         // Assume it's a hub id
-        let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
+        let mut builder = ApiBuilder::from_env();
+        if let Ok(token) = std::env::var("HF_TOKEN") {
             // env variable has precedence over on file token.
-            ApiBuilder::new().with_token(Some(token))
-        } else {
-            ApiBuilder::new()
+            builder = builder.with_token(Some(token))
         };
         if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
             builder = builder.with_user_agent("origin", origin.as_str());
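
The launcher change above keeps one rule intact: HF_TOKEN in the environment takes precedence, and ApiBuilder::from_env() supplies the on-disk token as the fallback. A minimal standalone sketch of that precedence follows; the function name and the default cache path (~/.cache/huggingface/token) are assumptions for illustration, not TGI code:

import os
from pathlib import Path

def resolve_hf_token() -> str | None:
    # HF_TOKEN set in the environment wins outright.
    if token := os.environ.get("HF_TOKEN"):
        return token
    # Otherwise fall back to the token the CLI login stores on disk,
    # assuming the default Hugging Face cache layout.
    hf_home = Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface")))
    token_file = hf_home / "token"
    if token_file.is_file():
        return token_file.read_text().strip() or None
    return None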

View File

@@ -1522,7 +1522,7 @@ pub async fn run(
     // Shared API builder initialization
     let api_builder = || {
-        let mut builder = ApiBuilder::new().with_progress(false);
+        let mut builder = ApiBuilder::from_env().with_progress(false);
         if let Some(token) = authorization_token {
             builder = builder.with_token(Some(token));
         }
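
The router gets the same from_env switch, plus one more precedence layer: an explicitly passed authorization_token still overrides whatever the environment or token file provided. Sketched below with a hypothetical helper name:

import os

def resolve_router_token(authorization_token: str | None) -> str | None:
    # An explicitly supplied token (e.g. from a CLI flag) beats the
    # environment, which in turn beats the on-disk token file that
    # from_env handles.
    if authorization_token is not None:
        return authorization_token
    return os.environ.get("HF_TOKEN")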

View File

@@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
-            prefix="vision_model" if not prefix else f"{prefix}.vision_model",
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )
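
The pali_gemma fix is a weight-prefix correction: the vision weights live under vision_tower, not vision_model, so the loader must build its lookup prefix accordingly. A small illustration (the helper name and the example weight key are hypothetical):

def vision_prefix(prefix: str | None) -> str:
    # Mirrors the corrected expression in the hunk above.
    return "vision_tower" if not prefix else f"{prefix}.vision_tower"

assert vision_prefix(None) == "vision_tower"
assert vision_prefix("model") == "model.vision_tower"
# Lookups would then hit keys shaped like (hypothetical example):
#   vision_tower.vision_model.embeddings.patch_embedding.weight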

View File

@@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
        weights,
    ):
        super().__init__()
-        self.max_past = (
+        self.window_size = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
@@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
            )
        # Decode
        else:
@ -185,7 +185,7 @@ class Qwen2Attention(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales, kv_scales=self.kv_scales,
window_size_left=self.max_past, window_size_left=self.window_size,
) )
return self.o_proj( return self.o_proj(
@@ -406,8 +406,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
            weights=weights,
        )
-        self.max_past = config.sliding_window
-        self.max_past_tensor = (
+        self.window_size = config.sliding_window
+        self.window_size_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            if self.window_size is not None
            else None
        )
@@ -434,7 +434,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
-        elif self.max_past is not None:
+        elif self.window_size is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.window_size_tensor)

        inputs_embeds = self.embed_tokens(input_ids)
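
The qwen2 hunks are a consistent rename (max_past becomes window_size), and the surrounding code shows the two sliding-window conventions being bridged: the flash attention path takes the raw window with -1 meaning no window, while the paged attention decode path wants sequence lengths clamped to the window up front, per the in-code comment. A standalone sketch of both conventions (function names are mine):

import torch

def effective_window(sliding_window: int | None) -> int:
    # Flash attention convention: -1 disables the sliding window.
    return sliding_window if sliding_window is not None else -1

def clamp_for_paged_attention(
    seqlen: torch.Tensor, window_size_tensor: torch.Tensor | None
) -> torch.Tensor:
    # Paged attention convention: clamp lengths to the window up front;
    # the flash attention kernel instead wants the true lengths.
    if window_size_tensor is None:
        return seqlen
    return seqlen.clamp(max=window_size_tensor)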