Disable tensor caching in HPU Graph execution (#4)

jkaniecki authored 2023-12-22 13:51:16 +01:00 (committed by GitHub)
parent b1897acfd6
commit e3dcd7f2c2


@@ -632,7 +632,7 @@ class CausalLM(Model):
         model = model.eval().to(device)
         #wrap in hpu_graph only if self.enable_hpu_graph is set
         if self.enable_hpu_graph:
-            model = wrap_in_hpu_graph(model)
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             self.is_optimized_for_gaudi = True
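
For context, a minimal sketch of the call pattern this commit changes, assuming a Gaudi host with the habana_frameworks PyTorch bridge installed; TinyModel and its dimensions are illustrative placeholders, not code from this repository, and the description of the flag's effect reflects my reading of the Habana docs rather than this commit:

import torch
import torch.nn as nn
# The Habana bridge registers the "hpu" device; importing core has that side effect.
import habana_frameworks.torch.core  # noqa: F401
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

class TinyModel(nn.Module):
    # Illustrative stand-in for the CausalLM model wrapped in the real code.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

device = torch.device("hpu")
model = TinyModel().eval().to(device)

# As in the commit: disable_tensor_cache=True asks the HPU graph wrapper not to
# keep the graph's input/output tensors cached between replays, reducing the
# steady-state device-memory footprint of a long-running serving process.
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

with torch.no_grad():
    # The first call records the HPU graph; subsequent calls replay it.
    out = model(torch.randn(4, 16, device=device))

The trade-off, as I understand it, is a possible extra copy or reallocation on replay in exchange for not pinning the graph's I/O tensors in device memory, which matters for a text-generation server that keeps the wrapped model alive indefinitely.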