Skip to content

Commit

Permalink
clearer trexio eri interface
Browse files Browse the repository at this point in the history
  • Loading branch information
kgasperich committed Jan 20, 2025
1 parent 0c25868 commit 9403924
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions pyscf/tools/trexio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyscf import scf
from pyscf import pbc
from pyscf import fci
from pyscf import ao2mo

import trexio

Expand Down Expand Up @@ -307,7 +308,8 @@ def scf_from_trexio(filename):
mf.mo_occ = mo_occ
return mf

def write_eri(eri, filename, backend='h5'):
def write_eri(eri, filename, backend='h5', basis='mo'):
assert basis.upper() in ['MO','AO']
num_integrals = eri.size
if eri.ndim == 4:
n = eri.shape[0]
Expand All @@ -330,15 +332,31 @@ def write_eri(eri, filename, backend='h5'):
idx = idx[np.tril_indices(npair)]

with trexio.File(filename, 'w', back_end=_mode(backend)) as tf:
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())
if basis.upper() == 'MO':
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())
else:
trexio.write_ao_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())

def write_scf_eri(mf, filename, backend='h5', basis='mo'):
assert basis.upper() in ['MO','AO']
if basis.upper() == 'MO':
write_eri(ao2mo.kernel(mf._eri, mf.mo_coeff), filename, backend, basis)
else:
write_eri(mf._eri, filename, backend, basis)

def read_eri(filename):

def read_eri(filename, basis='mo'):
'''Read ERIs in AO basis, 8-fold symmetry is assumed'''
assert basis.upper() in ['MO','AO']
basis_is_mo = (basis.upper() == 'MO')
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
nmo = trexio.read_mo_num(tf)
nao_pair = nmo * (nmo+1) // 2
eri_size = nao_pair * (nao_pair+1) // 2
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size)
norb = trexio.read_mo_num(tf) if basis_is_mo else trexio.read_ao_num(tf)
norb_pair = norb * (norb+1) // 2
eri_size = norb_pair * (norb_pair+1) // 2
if basis_is_mo:
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size)
else:
idx, data, n_read, eof_flag = trexio.read_ao_2e_int_eri(tf, 0, eri_size)
eri = np.zeros(eri_size)
x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,1]
y = idx[:,2]*(idx[:,2]+1)//2 + idx[:,3]
Expand Down

0 comments on commit 9403924

Please sign in to comment.