Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Using both values from config as they might not be correct. (#2817)
* Using both values from config as they might not be correct.
* Fixing max_position_embeddings for falcon.
* Simple attempt to fix the healthcheck block allocation.
* Much simpler solution.
* Default value for Backend start_health.
parent a2d878fa0f
commit 82c24f7420
@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
@@ -217,8 +217,8 @@ impl Health for ShardedClient {
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -241,7 +241,7 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
@@ -33,6 +33,13 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
 
     async fn health(&self, current_health: bool) -> bool;
+
+    /// The state of the health on startup
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool {
+        false
+    }
 }
 
 /// Inference struct
@@ -75,7 +82,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
 
         Self {
             validation,
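
The `start_health` default added to the `Backend` trait above, together with this change to `Infer`, lets a backend that has already warmed up report healthy from the start, while everything else stays unhealthy until the first probe succeeds. Below is a minimal Python sketch of the same default-method pattern; the class names are hypothetical stand-ins, not the router's actual types:

class Backend:
    """Simplified stand-in for the router's Backend trait (illustrative only)."""

    def start_health(self) -> bool:
        # Conservative default: report unhealthy until a real health probe passes.
        return False


class WarmedUpBackend(Backend):
    def start_health(self) -> bool:
        # A backend that runs a warmup phase has already exercised allocation,
        # so it can report healthy immediately, mirroring BackendV2/V3 above.
        return True


# Rough equivalent of `AtomicBool::new(backend.start_health())` in Infer::new.
backend_health = WarmedUpBackend().start_health()
print(backend_health)  # -> True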
@@ -78,6 +78,7 @@ class RWConfig(PretrainedConfig):
         self.alibi = False
         self.rotary = True
         self.rope_theta = rope_theta
+        self.max_position_embeddings = 2048
 
         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
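
For the Falcon part of the change: some RW/Falcon checkpoints ship configs without `max_position_embeddings`, so `RWConfig` now pins it to 2048 (Falcon's original context length). A small sketch of the failure mode this guards against, using an illustrative bare config object rather than a real checkpoint:

from types import SimpleNamespace

# Hypothetical config object missing the attribute, as some older
# Falcon/RW checkpoints do.
bare_config = SimpleNamespace(num_hidden_layers=32)

# Reading bare_config.max_position_embeddings directly would raise
# AttributeError; pinning a default in the config class (as RWConfig now
# does) or falling back explicitly keeps downstream code well defined.
max_position_embeddings = getattr(bare_config, "max_position_embeddings", 2048)
print(max_position_embeddings)  # -> 2048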
@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
 
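
A worked example of the new clamp in `FlashCausalLM`, with illustrative numbers (all values below are assumptions, not taken from the diff). Hugging Face tokenizers report a very large sentinel for `model_max_length` when no limit is configured, so without also consulting the config's `max_position_embeddings` the KV-cache capacity alone would decide `max_total_tokens`:

BLOCK_SIZE = 16                   # assumed KV-cache block size
num_blocks = 8192                 # assumed number of free blocks found at warmup
model_max_length = int(1e30)      # tokenizer sentinel when no limit is configured
max_position_embeddings = 2048    # from the model config (e.g. the RWConfig default)

max_total_tokens = min(
    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
)
print(max_total_tokens)  # -> 2048 rather than 131072 (the raw KV-cache capacity)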