Improve logging from python shards

I found these changes necessary to enable reasonable problem diagnosis.

- Ensure output from the shards makes its way to the launcher's stdout
- Log full exceptions raised by server rpc functions
- Fix stderr logging in integration_tests.rs
This commit is contained in:
Nick Hill 2022-12-27 16:19:10 -08:00
parent 3efa5bbbfd
commit 337e1f8795
3 changed files with 26 additions and 6 deletions

View File

@@ -299,6 +299,10 @@ fn shard_manager(
"SAFETENSORS_FAST_GPU".parse().unwrap(),
"1".to_string().parse().unwrap(),
),
(
"PYTHONUNBUFFERED".parse().unwrap(),
"1".to_string().parse().unwrap(),
),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
@@ -347,6 +351,12 @@ fn shard_manager(
}
};
// Redirect STDOUT to the console
let shard_stdout = p.stdout.take().unwrap();
thread::spawn(move || BufReader::new(shard_stdout).lines().for_each(|line|
println!("Shard {}: {}", rank, line.unwrap())
));
let mut ready = false;
let start_time = Instant::now();
let mut wait_time = Instant::now();

View File

@@ -41,25 +41,21 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
&argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
stderr: Redirection::Merge,
..Default::default()
},
)
.expect("Could not start launcher");
// Redirect STDOUT and STDERR to the console
// (STDERR is merged into STDOUT)
let launcher_stdout = launcher.stdout.take().unwrap();
let launcher_stderr = launcher.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(launcher_stdout);
let stderr = BufReader::new(launcher_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
for _ in 0..60 {

View File

@@ -1,4 +1,5 @@
import asyncio
import logging
import os
from grpc import aio
@@ -12,6 +13,16 @@ from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
def log_errs(func):
    """Decorator for async gRPC handler methods: log any exception with its
    full traceback before re-raising it to the gRPC layer.

    Without this, server-side errors surface to the client as opaque RPC
    failures with no server-side record of what went wrong.

    Args:
        func: the async function to wrap.

    Returns:
        An async wrapper with the same call signature as ``func``.
    """
    import functools

    @functools.wraps(func)  # preserve func's __name__/__doc__ on the wrapper
    async def func_with_log(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception:
            # logging.exception records the currently-active exception's
            # traceback along with the message.
            logging.exception(f"{func.__name__} failed")
            # Bare `raise` re-raises the active exception with its original
            # traceback intact (`raise e` would reset the raise point).
            raise

    return func_with_log
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
@@ -21,10 +32,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def ServiceDiscovery(self, request, context):
    # Report the URLs of every shard server; self.server_urls is captured
    # at construction time (see __init__), so this never changes at runtime.
    # Presumably the caller uses these to open a connection per shard —
    # confirm against the router/launcher side.
    return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
@log_errs
async def ClearCache(self, request, context):
    # Drop all cached batches held by this shard. Any exception is logged
    # with a full traceback by @log_errs before propagating to gRPC.
    self.cache.clear()
    return generate_pb2.ClearCacheResponse()
@log_errs
async def Generate(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
@@ -40,6 +53,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch=next_batch.to_pb() if next_batch else None,
)
@log_errs
async def GenerateWithCache(self, request, context):
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")