diff --git a/pyscf/dft/rks.py b/pyscf/dft/rks.py
index fbb2c9357..349f813ec 100644
--- a/pyscf/dft/rks.py
+++ b/pyscf/dft/rks.py
@@ -34,6 +34,103 @@ from pyscf.dft import gen_grid
 from pyscf.dft import numint
 from pyscf import __config__
 
+_AO_CACHE_MAX_MEMORY = getattr(__config__, 'dft_rks_ao_cache_max_memory', 512)
+
+def _cached_numint_call(ks, ni, mol, grids, max_memory, fn):
+    '''Run ``fn()`` while ``ni.block_loop`` serves cached AO blocks.
+
+    ``ni.block_loop`` is temporarily replaced by a wrapper that stores the
+    blocks it yields in ``ks._numint_ao_cache`` and replays them when the
+    same loop is requested again.  The wrapper forwards to the original
+    ``block_loop`` without caching when it is called for a different
+    ``mol``/``grids`` object, with an explicit ``non0tab`` or ``buf``, or
+    when the estimated size of the AO values exceeds the cache budget.
+
+    Args:
+        ks : object that owns the cache dict ``_numint_ao_cache``
+        ni : integrator whose ``block_loop`` is wrapped while ``fn`` runs
+        mol, grids : only loops over exactly these objects are cached
+        max_memory : memory budget; no caching is attempted if it is <= 0
+        fn : zero-argument callable that performs the integration
+
+    Returns:
+        The return value of ``fn()``.
+
+    Note:
+        ``ni`` is modified while ``fn`` runs, so it must not be used by
+        another thread at the same time.
+    '''
+    coords = getattr(grids, 'coords', None)
+    weights = getattr(grids, 'weights', None)
+    if coords is None or weights is None or max_memory <= 0:
+        return fn()
+
+    ao_cache = getattr(ks, '_numint_ao_cache', None)
+    if ao_cache is None:
+        ao_cache = ks._numint_ao_cache = {}
+
+    block_loop = ni.block_loop
+    # Remember whether block_loop was already an instance attribute, so that
+    # ni can be left exactly as it was found.
+    had_override = 'block_loop' in getattr(ni, '__dict__', {})
+
+    def cached_block_loop(mol1, grids1, nao=None, deriv=0, max_memory=2000,
+                          non0tab=None, blksize=None, buf=None):
+        if (mol1 is not mol or grids1 is not grids or
+                non0tab is not None or buf is not None):
+            yield from block_loop(mol1, grids1, nao, deriv, max_memory,
+                                  non0tab, blksize, buf)
+            return
+
+        if nao is None:
+            nao = mol1.nao
+        ngrids = grids1.coords.shape[0]
+        comp = (deriv + 1) * (deriv + 2) * (deriv + 3) // 6
+        cache_mb = comp * ngrids * nao * numpy.dtype('float64').itemsize / 1e6
+        # Cap the cache at _AO_CACHE_MAX_MEMORY and a quarter of max_memory
+        max_cache_mb = min(_AO_CACHE_MAX_MEMORY, max_memory * .25)
+        if cache_mb > max_cache_mb:
+            yield from block_loop(mol1, grids1, nao, deriv, max_memory,
+                                  non0tab, blksize, buf)
+            return
+
+        # NOTE(review): mol and grids are identified by id() and buffer
+        # address only, so modifying mol in place without rebuilding the
+        # grids would go unnoticed -- confirm reset() is always called.
+        non0tab_key = getattr(grids1, 'non0tab', None)
+        key = (id(mol1), id(grids1), grids1.coords.ctypes.data,
+               grids1.weights.ctypes.data,
+               None if non0tab_key is None else id(non0tab_key),
+               grids1.coords.shape, grids1.weights.shape, nao, deriv, blksize)
+        blocks = ao_cache.get(key)
+        if blocks is None:
+            blocks = []
+            for ao, mask, weight, blk_coords in block_loop(
+                    mol1, grids1, nao, deriv, max_memory, non0tab, blksize,
+                    buf):
+                # ao and mask are copied before being stored; weight and
+                # blk_coords are stored as yielded.
+                if mask is not None:
+                    mask = numpy.array(mask, copy=True)
+                blocks.append((numpy.array(ao, copy=True), mask, weight,
+                               blk_coords))
+            # Only the most recently built entry is kept
+            ao_cache.clear()
+            ao_cache[key] = blocks
+
+        yield from blocks
+
+    ni.block_loop = cached_block_loop
+    try:
+        return fn()
+    finally:
+        if had_override:
+            ni.block_loop = block_loop
+        else:
+            # block_loop came from the class: remove the temporary instance
+            # attribute rather than leave a bound method stored on ni
+            ni.__dict__.pop('block_loop', None)
+
 
 def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
     '''Coulomb + XC functional
@@ -77,7 +174,9 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
         n, exc, vxc = 0, 0, 0
     else:
         max_memory = ks.max_memory - lib.current_memory()[0]
-        n, exc, vxc = ni.nr_rks(mol, ks.grids, ks.xc, dm, max_memory=max_memory)
+        n, exc, vxc = _cached_numint_call(
+            ks, ni, mol, ks.grids, max_memory,
+            lambda: ni.nr_rks(mol, ks.grids, ks.xc, dm, max_memory=max_memory))
         logger.debug(ks, 'nelec by numeric integration = %s', n)
         if ks.do_nlc():
             if ni.libxc.is_nlc(ks.xc):
@@ -343,6 +442,7 @@ class KohnShamDFT:
         ##################################################
         # don't modify the following attributes, they are not input options
         self._numint = numint.NumInt()
+        self._numint_ao_cache = {}
 
     @property
     def omega(self):
@@ -482,6 +582,7 @@ class KohnShamDFT:
         hf.SCF.reset(self, mol)
         self.grids.reset(mol)
         self.nlcgrids.reset(mol)
+        self._numint_ao_cache = {}
         return self
 
     def check_sanity(self):
diff --git a/pyscf/dft/uks.py b/pyscf/dft/uks.py
index 3c47e4447..ad4b4ed4d 100644
--- a/pyscf/dft/uks.py
+++ b/pyscf/dft/uks.py
@@ -49,7 +49,9 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
         n, exc, vxc = (0,0), 0, 0
     else:
         max_memory = ks.max_memory - lib.current_memory()[0]
-        n, exc, vxc = ni.nr_uks(mol, ks.grids, ks.xc, dm, max_memory=max_memory)
+        n, exc, vxc = rks._cached_numint_call(
+            ks, ni, mol, ks.grids, max_memory,
+            lambda: ni.nr_uks(mol, ks.grids, ks.xc, dm, max_memory=max_memory))
         logger.debug(ks, 'nelec by numeric integration = %s', n)
         if ks.do_nlc():
             if ni.libxc.is_nlc(ks.xc):