OlivierDehaene 2023-04-06 17:27:32 +02:00
parent 7816a47697
commit 26fc232afb
5 changed files with 47 additions and 23 deletions

View File

@@ -575,6 +575,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, process_group=None):
         super().__init__()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
         self.model = FlashLlamaModel(config, process_group)
         if self.model.tp_embeddings:

View File

@@ -624,13 +624,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, process_group=None):
         super().__init__(config)
-        if config.tp_parallel:
-            process_group = torch.distributed.distributed_c10d._get_default_group()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
         else:
-            process_group = None
+            self.world_size = 1
+            self.rank = 0
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
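
Both model classes now take an optional process_group argument and derive their tensor-parallel rank and world size from it, instead of reading config.tp_parallel. A minimal, self-contained sketch of that fallback pattern (the ShardedModule class below is illustrative only, not code from this repository):

```python
import torch.distributed as dist


class ShardedModule:
    """Illustrative only: optional process group with single-process defaults."""

    def __init__(self, process_group=None):
        self.process_group = process_group
        if self.process_group is not None:
            # Sharded path: rank/world_size come from the group.
            self.world_size = self.process_group.size()
            self.rank = self.process_group.rank()
        else:
            # Single-process path: behave as rank 0 of a world of size 1.
            self.world_size = 1
            self.rank = 0


# Single-process usage needs no torch.distributed setup at all.
module = ShardedModule()
assert (module.rank, module.world_size) == (0, 1)

# Sharded usage passes an explicit group, e.g. the default group created by
# dist.init_process_group(...); callers no longer rely on config.tp_parallel.
# module = ShardedModule(dist.distributed_c10d._get_default_group())
```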

View File

@@ -43,7 +43,8 @@ class FlashLlama(FlashCausalLM):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         # We do not use from_pretrained as we modified the model internal module layout
@@ -57,12 +58,7 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
-        self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype
-        )
+        self.load_weights(model, filenames, device, dtype)
         self.model = model.eval()
         super(FlashCausalLM, self).__init__(
@@ -163,7 +159,8 @@ class FlashLlamaSharded(FlashLlama):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         torch.distributed.barrier(group=self.process_group)
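
These loaders build the model under accelerate's init_empty_weights context before calling load_weights, so no real tensors are allocated at construction time. A small sketch of that meta-device pattern, using a throwaway nn.Linear rather than the actual model classes:

```python
import torch
from accelerate import init_empty_weights

# Construct the module on the "meta" device: parameter shapes and dtypes
# exist, but no memory is allocated for the weights.
with init_empty_weights():
    model = torch.nn.Linear(4096, 4096)

print(model.weight.device)  # meta

# Real tensors (e.g. read from .safetensors shards) are copied in afterwards,
# which is the role of load_weights(model, filenames, device, dtype) above.
```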

View File

@@ -49,14 +49,15 @@ class FlashNeoXSharded(FlashNeoX):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

View File

@@ -39,11 +39,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()

     async def Prefill(self, request, context):
-        batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
-        )
+        from torch.profiler import profile, ProfilerActivity

-        generations, next_batch = self.model.generate_token(batch)
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as prefill_prof:
+            batch = self.model.batch_type.from_pb(
+                request.batch, self.model.tokenizer, self.model.device
+            )
+            generations, next_batch = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("prefill.json")
+
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
@@ -62,12 +69,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
             batches.append(batch)

-        if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
-        else:
-            batch = batches[0]
+        from torch.profiler import profile, ProfilerActivity

-        generations, next_batch = self.model.generate_token(batch)
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as decode_prof:
+            if len(batches) > 1:
+                batch = self.model.batch_type.concatenate(batches)
+            else:
+                batch = batches[0]
+            generations, next_batch = self.model.generate_token(batch)
+        decode_prof.export_chrome_trace("decode.json")
+
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
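
Prefill and Decode are wrapped in torch.profiler so each pass dumps a Chrome trace (prefill.json / decode.json). A standalone sketch of that profiling pattern; the matmul stands in for generate_token, and CUDA profiling is only requested when a GPU is actually present:

```python
import torch
from torch.profiler import profile, ProfilerActivity

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

x = torch.randn(512, 512)

# Everything executed inside the context manager is recorded.
with profile(activities=activities) as prof:
    y = x @ x  # stand-in for model.generate_token(batch)

# The resulting trace can be inspected in chrome://tracing or https://ui.perfetto.dev.
prof.export_chrome_trace("prefill.json")
```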