Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-19 22:02:06 +00:00.
Upgrading from_env to get token from file when necessary + fix pali_gemma.

commit e5dfd41ed4
parent 659ce4f3fc
@@ -97,11 +97,10 @@ fn get_config(
     let filename = if !path.exists() {
         // Assume it's a hub id

-        let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
+        let mut builder = ApiBuilder::from_env();
+        if let Ok(token) = std::env::var("HF_TOKEN") {
             // env variable has precedence over on file token.
-            ApiBuilder::new().with_token(Some(token))
-        } else {
-            ApiBuilder::new()
+            builder = builder.with_token(Some(token))
         };
         if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
             builder = builder.with_user_agent("origin", origin.as_str());
@@ -1522,7 +1522,7 @@ pub async fn run(

     // Shared API builder initialization
     let api_builder = || {
-        let mut builder = ApiBuilder::new().with_progress(false);
+        let mut builder = ApiBuilder::from_env().with_progress(false);
         if let Some(token) = authorization_token {
             builder = builder.with_token(Some(token));
         }
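As a note on the two Rust hunks above, here is a minimal sketch of the resulting token lookup order. It assumes the hf-hub crate's ApiBuilder as used in the diff; the import path and the helper name hub_api_builder are illustrative assumptions, not part of the commit. ApiBuilder::from_env() seeds the builder from the environment (and, per the commit title, can pick the token up from the on-disk token file), after which an explicit HF_TOKEN environment variable takes precedence.

    use hf_hub::api::sync::ApiBuilder;

    // Illustrative helper (name and import path are assumptions):
    // start from the environment-aware builder, then let an explicit
    // HF_TOKEN env variable override whatever token from_env() found.
    fn hub_api_builder() -> ApiBuilder {
        let mut builder = ApiBuilder::from_env();
        if let Ok(token) = std::env::var("HF_TOKEN") {
            // Env variable has precedence over the on-file token.
            builder = builder.with_token(Some(token));
        }
        builder
    }

The closure changed in run() applies the same pattern, additionally disabling progress bars and layering the router's authorization token on top when one is provided.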
@@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
-            prefix="vision_model" if not prefix else f"{prefix}.vision_model",
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )
@@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
+        self.window_size = (
             config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
@@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )
         # Decode
         else:
@@ -185,7 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )

         return self.o_proj(
@@ -406,8 +406,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
             weights=weights,
         )

-        self.max_past = config.sliding_window
-        self.max_past_tensor = (
+        self.window_size = config.sliding_window
+        self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
             if self.max_past is not None
             else None
@@ -434,7 +434,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.window_size_tensor)

         inputs_embeds = self.embed_tokens(input_ids)