mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00

Upgrading from_env to get token from file when necessary + fix pali_gemma.

This commit is contained in:
parent 659ce4f3fc
commit e5dfd41ed4
@@ -97,11 +97,10 @@ fn get_config(
     let filename = if !path.exists() {
         // Assume it's a hub id

-        let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
+        let mut builder = ApiBuilder::from_env();
+        if let Ok(token) = std::env::var("HF_TOKEN") {
             // env variable has precedence over on file token.
-            ApiBuilder::new().with_token(Some(token))
-        } else {
-            ApiBuilder::new()
+            builder = builder.with_token(Some(token))
         };
         if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
             builder = builder.with_user_agent("origin", origin.as_str());
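Per the commit title, the point of switching to ApiBuilder::from_env() is that the hub client can pick up a token saved on disk (for example by a previous login) when HF_TOKEN is not set, while the explicit HF_TOKEN check kept in the hunk above leaves the environment variable as the higher-precedence source. A rough Python sketch of that resolution order, assuming the conventional ~/.cache/huggingface/token location (hf-hub's exact lookup may differ):

import os
from pathlib import Path

def resolve_hub_token() -> str | None:
    # Environment variable has precedence over the on-file token,
    # mirroring the explicit HF_TOKEN check kept in the Rust code above.
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    # Fall back to a token written by a previous login; this path is the
    # conventional cache location and is assumed here for illustration.
    token_file = Path.home() / ".cache" / "huggingface" / "token"
    if token_file.exists():
        return token_file.read_text().strip() or None
    return None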
@@ -1522,7 +1522,7 @@ pub async fn run(

     // Shared API builder initialization
     let api_builder = || {
-        let mut builder = ApiBuilder::new().with_progress(false);
+        let mut builder = ApiBuilder::from_env().with_progress(false);
         if let Some(token) = authorization_token {
             builder = builder.with_token(Some(token));
         }
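The second call site applies the same change inside the shared api_builder closure: environment-derived defaults are set first, and a token passed explicitly by the caller still overrides them. A hypothetical Python equivalent of that layering (the names and the dict stand-in are illustrative, not the router's API):

import os

def make_api_client(authorization_token: str | None = None) -> dict:
    # Environment-derived defaults first, then an explicitly passed token wins.
    config = {"progress": False, "token": os.environ.get("HF_TOKEN")}
    if authorization_token is not None:
        config["token"] = authorization_token
    return config

# An explicit token always takes precedence over environment resolution.
assert make_api_client("explicit")["token"] == "explicit"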
@@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
-            prefix="vision_model" if not prefix else f"{prefix}.vision_model",
+            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
             weights=weights,
         )
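The pali_gemma fix changes the prefix under which the vision weights are looked up. Since weights are resolved by fully qualified tensor name, a wrong prefix simply fails to find the keys in the checkpoint. A short sketch of that failure mode; the key names below are illustrative, not the exact PaliGemma checkpoint layout:

# Illustrative only: a wrong prefix ("vision_model" instead of "vision_tower")
# means the composed key never matches anything in the checkpoint.
checkpoint_keys = {
    "vision_tower.vision_model.embeddings.patch_embedding.weight": "...",
    "language_model.model.embed_tokens.weight": "...",
}

def get_weight(prefix: str, name: str):
    key = f"{prefix}.{name}"
    if key not in checkpoint_keys:
        raise KeyError(f"weight {key} not found in checkpoint")
    return checkpoint_keys[key]

get_weight("vision_tower", "vision_model.embeddings.patch_embedding.weight")  # found
# get_weight("vision_model", "vision_model.embeddings.patch_embedding.weight")  # KeyError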
@@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
         weights,
     ):
         super().__init__()
-        self.max_past = (
+        self.window_size = (
             config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
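The renamed self.window_size (formerly self.max_past) is the sliding-window width handed to the attention kernels as window_size_left, with -1 conventionally meaning "no window". A small, hedged sketch of what a left sliding window means for the attention pattern, in plain PyTorch and independent of the flash/paged kernels used here:

import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # True where query position q may attend to key position k:
    # causal (k <= q) and, if window_size >= 0, only the last `window_size` keys.
    q = torch.arange(seq_len).unsqueeze(1)
    k = torch.arange(seq_len).unsqueeze(0)
    mask = k <= q
    if window_size >= 0:
        mask &= (q - k) <= window_size
    return mask

print(sliding_window_mask(6, 2).int())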
@@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )
         # Decode
         else:
@@ -185,7 +185,7 @@ class Qwen2Attention(torch.nn.Module):
                 seqlen,
                 max_s,
                 kv_scales=self.kv_scales,
-                window_size_left=self.max_past,
+                window_size_left=self.window_size,
             )

         return self.o_proj(
@@ -406,8 +406,8 @@ class Qwen2ForCausalLM(torch.nn.Module):
             weights=weights,
         )

-        self.max_past = config.sliding_window
-        self.max_past_tensor = (
+        self.window_size = config.sliding_window
+        self.window_size_tensor = (
             torch.tensor(config.sliding_window, device=weights.device)
             if self.max_past is not None
             else None
@@ -434,7 +434,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.window_size_tensor)

         inputs_embeds = self.embed_tokens(input_ids)

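The renamed window_size_tensor is the same window bound kept pre-built as a tensor on the model's device, so the decode-time clamp of the sequence lengths does not rebuild it on every forward pass. A hedged sketch of that clamp with illustrative values (the real bound comes from config.sliding_window and the real lengths come from the batch):

import torch

# Paged attention in decode wants lengths clamped to the window, while the
# flash attention prefill branch is given the true lengths.
window_size_tensor = torch.tensor(4096)
seqlen = torch.tensor([12, 5000, 4096, 9000])

clamped = seqlen.clamp(max=window_size_tensor)
print(clamped)  # tensor([  12, 4096, 4096, 4096])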