Iteration 0019 — 8d8a60d41b1b (accepted)

GitHub commit: 8d8a60d41b1b
Published branch: fermilink-optimize/pyscf-diis_scf

Change summary

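Reduce cached-AO DFT NumInt overhead in three ways:

(1) Memoize the molecule's shell metadata (`get_overlap_cond()` and `ao_loc_nr()`) for the duration of each NumInt call.

(2) When the AO cache is active, pass the density matrix to NumInt as a plain ndarray with its `mo_coeff`/`mo_occ` tags dropped. This applies to RKS and to the UKS fallback path.

(3) Add a conservative fused UKS GGA path. It evaluates the alpha and beta rho2 contractions from the tagged MO coefficients using one combined AO×MO product per derivative component. It is taken only when the stock heuristic would also choose the MO-based rho2 path; otherwise it falls back to `nr_uks`.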

Acceptance rationale

Correctness checks passed, and the 2.06% improvement over the incumbent clears the 2% acceptance threshold without relying on persistent caching or answer replay.

Guardrails & metrics

| field | value |
|---|---|
| decision | ACCEPTED |
| correctness | ok |
| correctness mode | field_tolerances |
| hard reject | no |
| guardrail errors | 0 |
| incumbent commit | 14eae176b837 |
| candidate commit | 8d8a60d41b1b |
| incumbent metric | 0.82932 |
| candidate metric | 0.812246 |
| baseline metric | 1.22637 |
| Δ vs incumbent | +2.059% (metric is lower-is-better; a positive Δ means the candidate improves on the incumbent) |
| changed files | pyscf/dft/rks.py, pyscf/dft/uks.py |

Diffstat

pyscf/dft/rks.py |  45 +++++++++++++++++-
pyscf/dft/uks.py | 137 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 178 insertions(+), 4 deletions(-)

Diff

download full diff

diff --git a/pyscf/dft/rks.py b/pyscf/dft/rks.py
index 17e8006d4..29f0acaf9 100644
--- a/pyscf/dft/rks.py
+++ b/pyscf/dft/rks.py
@@ -47,6 +47,11 @@ def _ao_cache_key(mol, grids, nao, deriv, blksize):
             grids.weights.ctypes.data, None if non0tab_key is None else id(non0tab_key),
             grids.coords.shape, grids.weights.shape, nao, deriv, blksize)

+def _numint_dm_for_cached_ao(ks, dm):
+    if getattr(dm, 'mo_coeff', None) is not None and getattr(ks, '_numint_ao_cache', None):
+        return numpy.asarray(dm)
+    return dm
+
 def _cached_numint_call(ks, ni, mol, grids, max_memory, fn):
     coords = getattr(grids, 'coords', None)
     weights = getattr(grids, 'weights', None)
@@ -58,6 +63,32 @@ def _cached_numint_call(ks, ni, mol, grids, max_memory, fn):
         ao_cache = ks._numint_ao_cache = {}

     block_loop = ni.block_loop
+    get_overlap_cond = mol.get_overlap_cond
+    ao_loc_nr = mol.ao_loc_nr
+    overlap_cond = [None]
+    ao_locs = {}
+
+    def cached_get_overlap_cond(shls_slice=None):
+        if shls_slice is None:
+            if overlap_cond[0] is None:
+                overlap_cond[0] = get_overlap_cond()
+            return overlap_cond[0]
+        return get_overlap_cond(shls_slice)
+
+    def cached_ao_loc_nr(cart=None):
+        if cart not in ao_locs:
+            ao_locs[cart] = ao_loc_nr(cart)
+        return ao_locs[cart]
+
+    sentinel = object()
+    mol_dict = getattr(mol, '__dict__', None)
+    old_get_overlap_cond = old_ao_loc_nr = sentinel
+    patched_mol = mol_dict is not None
+    if patched_mol:
+        old_get_overlap_cond = mol_dict.get('get_overlap_cond', sentinel)
+        old_ao_loc_nr = mol_dict.get('ao_loc_nr', sentinel)
+        mol.get_overlap_cond = cached_get_overlap_cond
+        mol.ao_loc_nr = cached_ao_loc_nr

     def cached_block_loop(mol1, grids1, nao=None, deriv=0, max_memory=2000,
                           non0tab=None, blksize=None, buf=None):
@@ -97,6 +128,15 @@ def _cached_numint_call(ks, ni, mol, grids, max_memory, fn):
         return fn()
     finally:
         ni.block_loop = block_loop
+        if patched_mol:
+            if old_get_overlap_cond is sentinel:
+                del mol.get_overlap_cond
+            else:
+                mol.get_overlap_cond = old_get_overlap_cond
+            if old_ao_loc_nr is sentinel:
+                del mol.ao_loc_nr
+            else:
+                mol.ao_loc_nr = old_ao_loc_nr

 def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
     '''Coulomb + XC functional
@@ -141,9 +181,10 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
         n, exc, vxc = 0, 0, 0
     else:
         max_memory = ks.max_memory - lib.current_memory()[0]
+        dm_numint = _numint_dm_for_cached_ao(ks, dm)
         n, exc, vxc = _cached_numint_call(
             ks, ni, mol, ks.grids, max_memory,
-            lambda: ni.nr_rks(mol, ks.grids, ks.xc, dm, max_memory=max_memory))
+            lambda: ni.nr_rks(mol, ks.grids, ks.xc, dm_numint, max_memory=max_memory))
         logger.debug(ks, 'nelec by numeric integration = %s', n)
         if ks.do_nlc():
             if ni.libxc.is_nlc(ks.xc):
@@ -151,7 +192,7 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
             else:
                 assert ni.libxc.is_nlc(ks.nlc)
                 xc = ks.nlc
-            n, enlc, vnlc = ni.nr_nlc_vxc(mol, ks.nlcgrids, xc, dm,
+            n, enlc, vnlc = ni.nr_nlc_vxc(mol, ks.nlcgrids, xc, dm_numint,
                                           max_memory=max_memory)
             exc += enlc
             vxc += vnlc
diff --git a/pyscf/dft/uks.py b/pyscf/dft/uks.py
index ad4b4ed4d..d04aeb883 100644
--- a/pyscf/dft/uks.py
+++ b/pyscf/dft/uks.py
@@ -25,8 +25,132 @@ import numpy
 from pyscf import lib
 from pyscf.lib import logger
 from pyscf.scf import hf, uhf
+from pyscf.dft import numint
 from pyscf.dft import rks

+def _stock_prefers_rho2(ni, mol, dm, mo_occ):
+    ovlp_cond = mol.get_overlap_cond()
+    dm_cond = mol.condense_to_shell(dm, 'absmax')
+    pair_mask = numpy.exp(-ovlp_cond) * dm_cond > ni.cutoff
+    mo_ao_sparsity = max(0.5 * numpy.sum(mo_occ) / dm.shape[-1], 1e-8)
+    ao_loc = mol.ao_loc_nr()
+    wts = (ao_loc[1:] - ao_loc[:-1]) / ao_loc[-1]
+    return numpy.dot(wts, pair_mask).dot(wts) / mo_ao_sparsity >= 4
+
+def _uks_gga_fused_mo(ni, mol, grids, xc_code, dm, hermi):
+    if hermi != 1 or ni._xc_type(xc_code) != 'GGA':
+        return None
+    if not (isinstance(dm, numpy.ndarray) and dm.ndim == 3 and dm.shape[0] == 2):
+        return None
+    if dm.dtype != numpy.double:
+        return None
+    mo_coeff = getattr(dm, 'mo_coeff', None)
+    mo_occ = getattr(dm, 'mo_occ', None)
+    if mo_coeff is None or mo_occ is None:
+        return None
+    if len(mo_coeff) != 2 or len(mo_occ) != 2:
+        return None
+    if mo_coeff[0].ndim != 2 or mo_coeff[1].ndim != 2:
+        return None
+    if mo_coeff[0].shape[0] != dm.shape[-1] or mo_coeff[1].shape[0] != dm.shape[-1]:
+        return None
+    if numpy.any(mo_occ[0] < -numint.OCCDROP) or numpy.any(mo_occ[1] < -numint.OCCDROP):
+        return None
+    if not (_stock_prefers_rho2(ni, mol, dm[0], mo_occ[0]) and
+            _stock_prefers_rho2(ni, mol, dm[1], mo_occ[1])):
+        return None
+    return mo_coeff, mo_occ
+
+def _occupied_coeff(mo_coeff, mo_occ):
+    pos = mo_occ > numint.OCCDROP
+    if numpy.any(pos):
+        return numpy.einsum('ij,j->ij', mo_coeff[:,pos], numpy.sqrt(mo_occ[pos]))
+    return None
+
+def _contract_spin_rho(c0, c1, ngrids, scale=1):
+    if c0.shape[1] == 0:
+        return numpy.zeros(ngrids)
+    rho = numint._contract_rho(c0, c1)
+    if scale != 1:
+        rho *= scale
+    return rho
+
+def _eval_rho2_gga_pair(mol, ao, mo_coeff, mo_occ, non0tab=None):
+    ngrids = ao.shape[-2]
+    cposa = _occupied_coeff(mo_coeff[0], mo_occ[0])
+    cposb = _occupied_coeff(mo_coeff[1], mo_occ[1])
+    if cposa is None and cposb is None:
+        return numpy.zeros((4, ngrids)), numpy.zeros((4, ngrids))
+
+    if cposa is None:
+        nmoa = 0
+        cpos = cposb
+    elif cposb is None:
+        nmoa = cposa.shape[1]
+        cpos = cposa
+    else:
+        nmoa = cposa.shape[1]
+        cpos = numpy.hstack((cposa, cposb))
+
+    shls_slice = (0, mol.nbas)
+    ao_loc = mol.ao_loc_nr()
+    rhoa = numpy.empty((4, ngrids))
+    rhob = numpy.empty((4, ngrids))
+
+    c0 = numint._dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc)
+    c0a = c0[:, :nmoa]
+    c0b = c0[:, nmoa:]
+    rhoa[0] = _contract_spin_rho(c0a, c0a, ngrids)
+    rhob[0] = _contract_spin_rho(c0b, c0b, ngrids)
+    for i in range(1, 4):
+        c1 = numint._dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc)
+        rhoa[i] = _contract_spin_rho(c0a, c1[:, :nmoa], ngrids, 2)
+        rhob[i] = _contract_spin_rho(c0b, c1[:, nmoa:], ngrids, 2)
+    return rhoa, rhob
+
+def _nr_uks_gga_fused(ni, mol, grids, xc_code, dm, hermi=1, max_memory=2000):
+    fused_mo = _uks_gga_fused_mo(ni, mol, grids, xc_code, dm, hermi)
+    if fused_mo is None:
+        return None
+    mo_coeff, mo_occ = fused_mo
+
+    nao = dm.shape[-1]
+    ao_loc = mol.ao_loc_nr()
+    cutoff = grids.cutoff * 1e2
+    nbins = numint.NBINS * 2 - int(numint.NBINS * numpy.log(cutoff) /
+                                   numpy.log(grids.cutoff))
+    pair_mask = mol.get_overlap_cond() < -numpy.log(ni.cutoff)
+
+    nelec = numpy.zeros(2)
+    excsum = 0
+    vmat = numpy.zeros((2, nao, nao))
+    aow = None
+    for ao, mask, weight, coords in ni.block_loop(mol, grids, nao, 1,
+                                                  max_memory=max_memory):
+        rho_a, rho_b = _eval_rho2_gga_pair(mol, ao, mo_coeff, mo_occ, mask)
+        exc, vxc = ni.eval_xc_eff(xc_code, (rho_a, rho_b), deriv=1,
+                                  xctype='GGA', spin=1)[:2]
+        den_a = rho_a[0] * weight
+        den_b = rho_b[0] * weight
+        nelec[0] += den_a.sum()
+        nelec[1] += den_b.sum()
+        excsum += numpy.dot(den_a, exc)
+        excsum += numpy.dot(den_b, exc)
+
+        wv = weight * vxc
+        wv[:,0] *= .5
+        wva, wvb = wv
+        aow = numint._scale_ao_sparse(ao, wva, mask, ao_loc, out=aow)
+        numint._dot_ao_ao_sparse(ao[0], aow, None, nbins, mask, pair_mask,
+                                 ao_loc, hermi=0, out=vmat[0])
+        aow = numint._scale_ao_sparse(ao, wvb, mask, ao_loc, out=aow)
+        numint._dot_ao_ao_sparse(ao[0], aow, None, nbins, mask, pair_mask,
+                                 ao_loc, hermi=0, out=vmat[1])
+    vmat = lib.hermi_sum(vmat, axes=(0,2,1))
+    if vmat.dtype != dm.dtype:
+        vmat = numpy.asarray(vmat, dtype=dm.dtype)
+    return nelec, excsum, vmat
+
 def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
     '''Coulomb + XC functional for UKS.  See pyscf/dft/rks.py
     :func:`get_veff` fore more details.
@@ -49,9 +173,17 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
         n, exc, vxc = (0,0), 0, 0
     else:
         max_memory = ks.max_memory - lib.current_memory()[0]
+        dm_numint = rks._numint_dm_for_cached_ao(ks, dm)
+        def nr_uks_call():
+            out = _nr_uks_gga_fused(ni, mol, ks.grids, ks.xc, dm, hermi,
+                                    max_memory=max_memory)
+            if out is not None:
+                return out
+            return ni.nr_uks(mol, ks.grids, ks.xc, dm_numint,
+                             max_memory=max_memory)
         n, exc, vxc = rks._cached_numint_call(
             ks, ni, mol, ks.grids, max_memory,
-            lambda: ni.nr_uks(mol, ks.grids, ks.xc, dm, max_memory=max_memory))
+            nr_uks_call)
         logger.debug(ks, 'nelec by numeric integration = %s', n)
         if ks.do_nlc():
             if ni.libxc.is_nlc(ks.xc):
@@ -59,7 +191,8 @@ def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
             else:
                 assert ni.libxc.is_nlc(ks.nlc)
                 xc = ks.nlc
-            n, enlc, vnlc = ni.nr_nlc_vxc(mol, ks.nlcgrids, xc, dm[0]+dm[1],
+            n, enlc, vnlc = ni.nr_nlc_vxc(mol, ks.nlcgrids, xc,
+                                          dm_numint[0]+dm_numint[1],
                                           max_memory=max_memory)
             exc += enlc
             vxc += vnlc