remove profiling

OlivierDehaene 2023-04-06 17:58:54 +02:00
parent 26fc232afb
commit c3779fa859


@@ -39,18 +39,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
-        from torch.profiler import profile, ProfilerActivity
-
-        with profile(
-            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
-        ) as prefill_prof:
-            batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.device
-            )
-            generations, next_batch = self.model.generate_token(batch)
-
-        prefill_prof.export_chrome_trace("prefill.json")
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.device
+        )
+        generations, next_batch = self.model.generate_token(batch)
 
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -69,20 +62,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
             batches.append(batch)
 
-        from torch.profiler import profile, ProfilerActivity
-
-        with profile(
-            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
-        ) as decode_prof:
-            if len(batches) > 1:
-                batch = self.model.batch_type.concatenate(batches)
-            else:
-                batch = batches[0]
-
-            generations, next_batch = self.model.generate_token(batch)
-
-        decode_prof.export_chrome_trace("decode.json")
+        if len(batches) > 1:
+            batch = self.model.batch_type.concatenate(batches)
+        else:
+            batch = batches[0]
+        generations, next_batch = self.model.generate_token(batch)
 
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
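
For reference, the code removed above is the stock torch.profiler tracing idiom. The same measurement can be reproduced ad hoc outside the request handlers; a minimal sketch, assuming `model` is a loaded model exposing generate_token and `batch` is an already-prepared batch (both hypothetical stand-ins, not defined in this diff):

    # Standalone sketch of the profiling pattern this commit removes
    # from the serving path.
    from torch.profiler import profile, ProfilerActivity

    # `model` and `batch` are hypothetical placeholders for a loaded model
    # and a prepared input batch.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        generations, next_batch = model.generate_token(batch)

    # Write a Chrome-trace file, viewable in chrome://tracing or Perfetto.
    prof.export_chrome_trace("generate_token.json")

Running this once per code path yields the same kind of trace the deleted prefill.json / decode.json exports produced, without paying the profiler overhead on every gRPC call.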