Upgrading from_env to get token from file when necessary + fix pali_gemma.
Nicolas Patry 2025-03-14 17:06:36 +01:00
parent 659ce4f3fc
commit e5dfd41ed4
4 changed files with 11 additions and 12 deletions


@@ -97,11 +97,10 @@ fn get_config(
     let filename = if !path.exists() {
         // Assume it's a hub id
-        let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
+        let mut builder = ApiBuilder::from_env();
+        if let Ok(token) = std::env::var("HF_TOKEN") {
             // env variable has precedence over on file token.
-            ApiBuilder::new().with_token(Some(token))
-        } else {
-            ApiBuilder::new()
+            builder = builder.with_token(Some(token))
         };
         if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
             builder = builder.with_user_agent("origin", origin.as_str());
         }

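The change above switches the client setup from ApiBuilder::new() to ApiBuilder::from_env(), so a token stored on disk by `huggingface-cli login` is picked up even when HF_TOKEN is unset, while an explicit HF_TOKEN still takes precedence. A minimal sketch of the intended lookup order, in Python for illustration (the token path is an assumption of this sketch; the real lookup also honors variables such as HF_HOME):

import os
from pathlib import Path

def resolve_token():
    # 1. An explicit HF_TOKEN environment variable wins.
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    # 2. Otherwise fall back to the on-disk token written by `huggingface-cli login`.
    #    (Default cache path assumed for this sketch.)
    token_file = Path.home() / ".cache" / "huggingface" / "token"
    if token_file.exists():
        return token_file.read_text().strip()
    # 3. No token at all: requests go out anonymously.
    return None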

@@ -1522,7 +1522,7 @@ pub async fn run(
     // Shared API builder initialization
     let api_builder = || {
-        let mut builder = ApiBuilder::new().with_progress(false);
+        let mut builder = ApiBuilder::from_env().with_progress(false);
         if let Some(token) = authorization_token {
             builder = builder.with_token(Some(token));
         }

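The router applies the same upgrade inside its shared api_builder closure, so every hub client it creates starts from the from_env base configuration, and a token supplied through the router's own settings still overrides whatever the environment provides. The factory pattern, sketched in Python with hypothetical names:

import os

def make_api_builder(authorization_token=None):
    # Every call site gets a fresh, identically configured client config.
    def api_builder():
        config = {"progress": False, "token": os.environ.get("HF_TOKEN")}
        if authorization_token is not None:
            # An explicitly supplied token overrides the environment.
            config["token"] = authorization_token
        return config
    return api_builder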

@@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
-            prefix="vision_model" if not prefix else f"{prefix}.vision_model",
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )

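The pali_gemma fix corrects a weight prefix: load_vision_model selects checkpoint tensors by name prefix, and PaliGemma checkpoints store the vision encoder under vision_tower rather than vision_model, so the old prefix matched no tensors. A small illustration with invented checkpoint keys:

# Checkpoint keys invented for this example.
checkpoint_keys = [
    "vision_tower.encoder.layers.0.self_attn.q_proj.weight",
    "language_model.model.embed_tokens.weight",
]

def keys_under(prefix, keys):
    # A prefix-based loader only sees tensors whose names start with the prefix.
    return [k for k in keys if k.startswith(prefix + ".")]

print(keys_under("vision_model", checkpoint_keys))  # [] -- nothing to load
print(keys_under("vision_tower", checkpoint_keys))  # finds the vision weights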

@@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
+        self.window_size = (
             config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
@@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )
         # Decode
         else:
@@ -185,7 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )

         return self.o_proj(
@@ -406,8 +406,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
             weights=weights,
         )
-        self.max_past = config.sliding_window
-        self.max_past_tensor = (
+        self.window_size = config.sliding_window
+        self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
-            if self.max_past is not None
+            if self.window_size is not None
             else None
         )
@@ -434,7 +434,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
-        elif self.max_past is not None:
+        elif self.window_size is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.window_size_tensor)

         inputs_embeds = self.embed_tokens(input_ids)
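The Qwen2 changes rename max_past to window_size, which describes the value directly: it is the sliding-attention window, passed as window_size_left to the prefill flash attention kernel. The decode-path clamp exists because paged attention only keeps the last window_size slots, so the sequence lengths it receives must be capped at the window, while prefill needs the true lengths. A minimal sketch of the clamping rule, assuming a 4096-token window:

import torch

window_size = 4096  # stands in for config.sliding_window
window_size_tensor = torch.tensor(window_size)

# True sequence lengths; anything older than the window is invisible to
# sliding-window attention, so decode-side lengths are capped.
seqlen = torch.tensor([100, 5000, 12000])
print(seqlen.clamp(max=window_size_tensor))  # tensor([ 100, 4096, 4096])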