diff --git a/pyscf/mcscf/newton_casscf.py b/pyscf/mcscf/newton_casscf.py index da2c677e4..b695365d3 100644 --- a/pyscf/mcscf/newton_casscf.py +++ b/pyscf/mcscf/newton_casscf.py @@ -29,6 +29,7 @@ from pyscf.lib import logger from pyscf.mcscf import casci, mc1step, addons from pyscf.mcscf.casci import get_fock, cas_natorb, canonicalize from pyscf import scf +from pyscf import df from pyscf.soscf import ciah def _pack_ci_get_H (mc, mo, ci0): @@ -851,6 +852,29 @@ class CASSCF(mc1step.CASSCF): e_tot, e_tot-elast, ss[0]) return e_tot, e_cas, fcivec + def update_jk_in_ah(self, mo, r, casdm1, eris): + ncore = self.ncore + ncas = self.ncas + nocc = ncore + ncas + + dm3 = reduce(numpy.dot, (mo[:,:ncore], r[:ncore,ncore:], mo[:,ncore:].T)) + dm3 = dm3 + dm3.T + dm4 = reduce(numpy.dot, (mo[:,ncore:nocc], casdm1, r[ncore:nocc], mo.T)) + dm4 = dm4 + dm4.T + + with_df = getattr(self, '_ah_jk_df', None) + if with_df is None or with_df.mol is not self.mol: + with_df = df.DF(self.mol) + with_df.max_memory = self.max_memory + with_df.stdout = self.stdout + with_df.verbose = self.verbose + self._ah_jk_df = with_df + + vj, vk = with_df.get_jk((dm3, dm3*2+dm4), hermi=1) + va = reduce(numpy.dot, (casdm1, mo[:,ncore:nocc].T, vj[0]*2-vk[0], mo)) + vc = reduce(numpy.dot, (mo[:,:ncore].T, vj[1]*2-vk[1], mo[:,ncore:])) + return va, vc + def update_ao2mo(self, mo): raise DeprecationWarning('update_ao2mo was obsoleted since pyscf v1.0. ' 'Use .ao2mo method instead')