diff --git a/pyscf/dft/rks.py b/pyscf/dft/rks.py index 349f813ec..17e8006d4 100644 --- a/pyscf/dft/rks.py +++ b/pyscf/dft/rks.py @@ -36,6 +36,17 @@ from pyscf import __config__ _AO_CACHE_MAX_MEMORY = getattr(__config__, 'dft_rks_ao_cache_max_memory', 512) +def _ao_cache_size_mb(mol, grids, deriv): + ngrids = grids.coords.shape[0] + comp = (deriv + 1) * (deriv + 2) * (deriv + 3) // 6 + return comp * ngrids * mol.nao * numpy.dtype('float64').itemsize / 1e6 + +def _ao_cache_key(mol, grids, nao, deriv, blksize): + non0tab_key = getattr(grids, 'non0tab', None) + return (id(mol), id(grids), grids.coords.ctypes.data, + grids.weights.ctypes.data, None if non0tab_key is None else id(non0tab_key), + grids.coords.shape, grids.weights.shape, nao, deriv, blksize) + def _cached_numint_call(ks, ni, mol, grids, max_memory, fn): coords = getattr(grids, 'coords', None) weights = getattr(grids, 'weights', None) @@ -58,18 +69,14 @@ def _cached_numint_call(ks, ni, mol, grids, max_memory, fn): if nao is None: nao = mol1.nao ngrids = grids1.coords.shape[0] - comp = (deriv + 1) * (deriv + 2) * (deriv + 3) // 6 - cache_mb = comp * ngrids * nao * numpy.dtype('float64').itemsize / 1e6 + cache_mb = _ao_cache_size_mb(mol1, grids1, deriv) max_cache_mb = min(_AO_CACHE_MAX_MEMORY, max_memory * .25) if cache_mb > max_cache_mb: yield from block_loop(mol1, grids1, nao, deriv, max_memory, non0tab, blksize, buf) return - non0tab_key = getattr(grids1, 'non0tab', None) - key = (id(mol1), id(grids1), grids1.coords.ctypes.data, - grids1.weights.ctypes.data, None if non0tab_key is None else id(non0tab_key), - grids1.coords.shape, grids1.weights.shape, nao, deriv, blksize) + key = _ao_cache_key(mol1, grids1, nao, deriv, blksize) blocks = ao_cache.get(key) if blocks is None: blocks = [] @@ -561,7 +568,7 @@ class KohnShamDFT: ground_state = getattr(dm, 'ndim', 0) == 2 if self.grids.coords is None: t0 = (logger.process_clock(), logger.perf_counter()) - self.grids.build(with_non0tab=True) + self.grids.build(with_non0tab=True, sort_grids=False) if self.small_rho_cutoff > 1e-20 and ground_state: # Filter grids the first time setup grids self.grids = prune_small_rho_grids_(self, self.mol, dm,