revert back to normal allocator

OlivierDehaene 2023-07-18 16:11:18 +02:00
parent 79616a8796
commit de892fb434
2 changed files with 4 additions and 16 deletions
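Note: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync is the PyTorch switch that swaps the native caching allocator for CUDA's stream-ordered cudaMallocAsync backend. With the push in the first hunk below commented out, the variable stays unset and the shards fall back to the default ("normal") caching allocator. A minimal, illustrative sketch of how the backend gets picked up on the Python side (not code from this repository):

    import os

    # The backend is chosen when CUDA is first initialized, so the variable
    # must already be in the shard's environment; the launcher does this by
    # pushing it into `envs` before spawning the shard process.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

    import torch

    if torch.cuda.is_available():
        # The first CUDA allocation initializes the configured backend.
        x = torch.zeros((2, 2), device="cuda")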


@@ -370,10 +370,10 @@ fn shard_manager(
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
     // Use cuda allocator. It leads to less memory fragmentation
-    envs.push((
-        "PYTORCH_CUDA_ALLOC_CONF".into(),
-        "backend:cudaMallocAsync".into(),
-    ));
+    // envs.push((
+    //     "PYTORCH_CUDA_ALLOC_CONF".into(),
+    //     "backend:cudaMallocAsync".into(),
+    // ));
 
     // Torch Distributed Env vars
     envs.push(("RANK".into(), rank.to_string().into()));


@@ -51,9 +51,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
@@ -62,9 +59,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )
         max_supported_total_tokens = self.model.warmup(batch)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
         return generate_pb2.WarmupResponse(
             max_supported_total_tokens=max_supported_total_tokens
         )
@@ -78,8 +72,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if next_batch is not None:
             self.cache.set(next_batch)
-        else:
-            torch.cuda.empty_cache()
 
         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
@@ -102,8 +94,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
         else:
             batch = batches[0]
@@ -111,8 +101,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if next_batch is not None:
             self.cache.set(next_batch)
-        else:
-            torch.cuda.empty_cache()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
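If it matters which allocator a running shard actually ended up with, recent PyTorch releases expose the choice directly; a small check along these lines (illustrative, and assuming a PyTorch version that ships this helper):

    import torch

    # Returns "native" for the default caching allocator, or "cudaMallocAsync"
    # when PYTORCH_CUDA_ALLOC_CONF selects the async backend.
    print(torch.cuda.get_allocator_backend())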