This commit is contained in:
Nicolas Patry 2024-01-29 12:11:57 +01:00
parent e19d6f3589
commit a7a98c0253
2 changed files with 25 additions and 13 deletions

View File

@ -982,14 +982,16 @@ fn spawn_shards(
Ok(())
}
fn compute_type(num_shard: usize) -> Option<String>{
let output = Command::new("nvidia-smi").args(["--query-gpu=gpu_name", "--format=csv"]).output().ok()?;
fn compute_type(num_shard: usize) -> Option<String> {
let output = Command::new("nvidia-smi")
.args(["--query-gpu=gpu_name", "--format=csv"])
.output()
.ok()?;
let output = String::from_utf8(output.stdout).ok()?;
let fullname = output.split("\n").nth(1)?;
let cardname = fullname.replace(" ", "-").to_lowercase();
let compute_type = format!("{num_shard}-{cardname}");
Some(compute_type)
}
fn spawn_webserver(
@ -1086,7 +1088,7 @@ fn spawn_webserver(
// Parse Compute type
if let Ok(compute_type) = env::var("COMPUTE_TYPE") {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}else if let Some(compute_type) = compute_type(num_shard){
} else if let Some(compute_type) = compute_type(num_shard) {
envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
}
@ -1283,8 +1285,8 @@ fn main() -> Result<(), LauncherError> {
return Ok(());
}
let mut webserver =
spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
.map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;

View File

@ -67,7 +67,7 @@ async fn compat_generate(
// switch on stream
if req.stream {
Ok(generate_stream(infer,compute_type, Json(req.into()))
Ok(generate_stream(infer, compute_type, Json(req.into()))
.await
.into_response())
} else {
@ -372,7 +372,7 @@ async fn generate_stream_internal(
let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type",compute_type.parse().unwrap());
headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
@ -649,13 +649,22 @@ async fn chat_completions(
)
};
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), on_message_callback).await;
let (headers, response_stream) = generate_stream_internal(
infer,
compute_type,
Json(generate_request),
on_message_callback,
)
.await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, Json(generation)) =
generate(Extension(infer), Extension(compute_type), Json(generate_request)).await?;
let (headers, Json(generation)) = generate(
Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -943,7 +952,8 @@ pub async fn run(
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
// Combine routes and layers
let app = Router::new()