Using both values from config as they might not be correct. (#2817)

* Using both values, as either might not be correct: max_total_tokens is now clamped to both the tokenizer's model_max_length and the config's max_position_embeddings.

* Fixing max_position_embeddings for Falcon (RWConfig now sets a default of 2048).

* Simple attempt to fix the healthcheck block allocation.

* Much simpler solution.

* Default value for Backend::start_health.
Author: Nicolas Patry, 2024-12-11 00:07:09 +05:30, committed by GitHub
Parent: a2d878fa0f
Commit: 82c24f7420
6 changed files with 25 additions and 5 deletions

File 1/6: impl Backend for BackendV2

@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }

 /// Batching logic

File 2/6: impl Backend for BackendV3

@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }

 /// Batching logic

File 3/6: impl Health for ShardedClient

@@ -217,8 +217,8 @@ impl Health for ShardedClient {
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -241,7 +241,7 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,

File 4/6: Backend trait and Infer

@@ -33,6 +33,13 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;

     async fn health(&self, current_health: bool) -> bool;
+
+    /// The state of the health on startup
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool {
+        false
+    }
 }

 /// Inference struct
@@ -75,7 +82,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));

         Self {
             validation,
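
To make the startup-health contract concrete, here is a minimal sketch of the same pattern in Python (illustration only; the real trait above is Rust and also carries scheduling methods, and WarmedUpBackend is a hypothetical name): the base class defaults to "not healthy yet", a backend with a warmup phase overrides it, and the server seeds its health flag from that value instead of hardcoding false.

    # Hypothetical Python sketch of the trait-default pattern above.
    class Backend:
        def health(self, current_health: bool) -> bool:
            raise NotImplementedError

        # Default: assume a backend is not healthy until its first successful
        # health check, mirroring `fn start_health(&self) -> bool { false }`.
        def start_health(self) -> bool:
            return False

    class WarmedUpBackend(Backend):
        # A backend that runs a warmup pass before serving can report healthy
        # immediately, like BackendV2/BackendV3 returning true above.
        def health(self, current_health: bool) -> bool:
            return True

        def start_health(self) -> bool:
            return True

    backend = WarmedUpBackend()
    # Mirrors the Infer change: seed the shared health flag from the backend
    # instead of hardcoding False.
    backend_health = backend.start_health()
    assert backend_health is True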

File 5/6: class RWConfig(PretrainedConfig)

@@ -78,6 +78,7 @@ class RWConfig(PretrainedConfig):
         self.alibi = False
         self.rotary = True
         self.rope_theta = rope_theta
+        self.max_position_embeddings = 2048

         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
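
A minimal sketch of why RWConfig pins this value, assuming a Falcon RW checkpoint whose config.json does not provide max_position_embeddings (bare_config, rw_config, and read_max_positions below are hypothetical stand-ins, not code from the repository): without the attribute set on the config object, the lookup added in FlashCausalLM would raise.

    from types import SimpleNamespace

    # Hypothetical stand-ins for a parsed model config: one without the field,
    # one where RWConfig has pinned max_position_embeddings = 2048.
    bare_config = SimpleNamespace(num_hidden_layers=32, num_attention_heads=32)
    rw_config = SimpleNamespace(
        num_hidden_layers=32, num_attention_heads=32, max_position_embeddings=2048
    )

    def read_max_positions(config):
        # Same attribute access as `self.config.max_position_embeddings` below.
        return config.max_position_embeddings

    try:
        read_max_positions(bare_config)
    except AttributeError as err:
        print(f"missing field: {err}")

    print(read_max_positions(rw_config))  # 2048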

File 6/6: class FlashCausalLM(Model)

@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
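
A small worked example of the new clamp, with made-up numbers: tokenizers can report an effectively unbounded model_max_length (a huge sentinel) when the checkpoint does not set one, so also clamping to config.max_position_embeddings keeps max_total_tokens within what the model can actually handle.

    BLOCK_SIZE = 16                   # hypothetical KV-cache block size
    num_blocks = 8192                 # hypothetical free blocks found during warmup
    model_max_length = int(1e30)      # sentinel-style value from a tokenizer with no limit set
    max_position_embeddings = 2048    # from the model config (e.g. the Falcon default above)

    # Old behaviour: the oversized tokenizer value never constrains the result.
    old_max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
    assert old_max_total_tokens == 131_072

    # New behaviour: also clamp to the model's position-embedding limit.
    new_max_total_tokens = min(
        num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
    )
    assert new_max_total_tokens == 2048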