diff --git a/CHANGELOG.md b/CHANGELOG.md index 709cced4b680..5fb4867a4a14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Fixed +- Fixed DimeNet pretrained checkpoint loading for compatibility with Tensorflow v2.X + ### Security ## [2.7.0] - 2025-10-14 diff --git a/torch_geometric/nn/models/dimenet.py b/torch_geometric/nn/models/dimenet.py index 09e74c0d5028..d31d19fc2f60 100644 --- a/torch_geometric/nn/models/dimenet.py +++ b/torch_geometric/nn/models/dimenet.py @@ -593,7 +593,7 @@ def from_qm9_pretrained( download_url(f'{url}/ckpt.index', path) path = osp.join(path, 'ckpt') - reader = tf.train.load_checkpoint(path) + reader = tf.compat.v1.train.load_checkpoint(path) model = cls( hidden_channels=128, @@ -865,7 +865,7 @@ def from_qm9_pretrained( download_url(f'{url}/ckpt.index', path) path = osp.join(path, 'ckpt') - reader = tf.train.load_checkpoint(path) + reader = tf.compat.v1.train.load_checkpoint(path) # Configuration from DimeNet++: # https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml