mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Trying back to but EXl2 + TP>1,
the issue might have been cleaned memory by torch allocator
This commit is contained in:
parent
16958fe312
commit
97d9ff3a71
@ -185,6 +185,7 @@ class QuantLinear(nn.Module):
|
||||
"g_idx": self.g_idx,
|
||||
}
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
self.temp_dq = temp_dq
|
||||
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
|
||||
|
||||
def forward(self, x, force_cuda=False):
|
||||
|
@ -35,12 +35,12 @@ except Exception:
|
||||
HAS_EXLLAMA = False
|
||||
CAN_EXLLAMA = major >= 8
|
||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||
V2 = False
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Disabling exllama v2 and using v1 instead because there are issues when sharding",
|
||||
)
|
||||
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||
# V2 = False
|
||||
# log_once(
|
||||
# logger.warning,
|
||||
# "Disabling exllama v2 and using v1 instead because there are issues when sharding",
|
||||
# )
|
||||
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA = False
|
||||
|
Loading…
Reference in New Issue
Block a user