Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fix tp

parent 7816a47697
commit 26fc232afb
@@ -575,6 +575,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, process_group=None):
         super().__init__()
 
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
         self.model = FlashLlamaModel(config, process_group)
 
         if self.model.tp_embeddings:
@@ -624,13 +624,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
 
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, process_group=None):
         super().__init__(config)
 
-        if config.tp_parallel:
-            process_group = torch.distributed.distributed_c10d._get_default_group()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
         else:
-            process_group = None
+            self.world_size = 1
+            self.rank = 0
 
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
 
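For context, here is a minimal sketch (not part of this commit) of how a launcher could build the process group that the rewritten constructors above now expect as an argument. The environment variables, the backend choice, and the commented constructor call are illustrative assumptions; only the torch.distributed calls mirror what the pre-patch and patched code rely on.

import os

import torch.distributed as dist

# RANK/WORLD_SIZE (plus MASTER_ADDR/MASTER_PORT) are assumed to be provided
# by the launcher, e.g. torchrun.
dist.init_process_group(
    backend="nccl",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

# Same call the pre-patch FlashGPTNeoXForCausalLM made internally when
# config.tp_parallel was set; after this commit the group is passed in instead.
process_group = dist.distributed_c10d._get_default_group()

# With the new signatures the group is handed over explicitly, e.g.
# FlashGPTNeoXForCausalLM(config, process_group); passing None keeps the
# single-device fallback (world_size=1, rank=0).
print(process_group.size(), process_group.rank())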
@@ -43,7 +43,8 @@ class FlashLlama(FlashCausalLM):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
@@ -57,12 +58,7 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
 
-        self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype
-        )
+        self.load_weights(model, filenames, device, dtype)
         self.model = model.eval()
 
         super(FlashCausalLM, self).__init__(
@@ -163,7 +159,8 @@ class FlashLlamaSharded(FlashLlama):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
@@ -49,14 +49,15 @@ class FlashNeoXSharded(FlashNeoX):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
@@ -39,11 +39,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
-        batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
-        )
+        from torch.profiler import profile, ProfilerActivity
 
-        generations, next_batch = self.model.generate_token(batch)
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as prefill_prof:
+            batch = self.model.batch_type.from_pb(
+                request.batch, self.model.tokenizer, self.model.device
+            )
+
+            generations, next_batch = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("prefill.json")
+
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
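The Prefill hunk above wraps batch construction and generate_token in torch.profiler and writes a Chrome trace once the context exits. Below is a self-contained sketch of the same pattern; the dummy workload is an assumption standing in for the real server calls.

import torch
from torch.profiler import profile, ProfilerActivity

def dummy_prefill(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for batch_type.from_pb + generate_token in the real server.
    return x @ x.T

# CPU-only here so the sketch runs anywhere; the diff also records CUDA activity.
with profile(activities=[ProfilerActivity.CPU]) as prefill_prof:
    dummy_prefill(torch.randn(512, 512))

# Same export the diff performs after the with-block exits.
prefill_prof.export_chrome_trace("prefill.json")

The resulting prefill.json (and decode.json from the next hunk) can be opened in chrome://tracing or Perfetto to inspect where time is spent.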
@@ -62,12 +69,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
         batches.append(batch)
 
-        if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
-        else:
-            batch = batches[0]
+        from torch.profiler import profile, ProfilerActivity
 
-        generations, next_batch = self.model.generate_token(batch)
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as decode_prof:
+
+            if len(batches) > 1:
+                batch = self.model.batch_type.concatenate(batches)
+            else:
+                batch = batches[0]
+
+            generations, next_batch = self.model.generate_token(batch)
+        decode_prof.export_chrome_trace("decode.json")
+
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(