Skip to content

Commit

Permalink
Merge pull request #202 from pni-lab/198-rework-fieldmap-correction-w…
Browse files Browse the repository at this point in the history
…orkflow-topup

TOPUP fieldmap correction workflow
  • Loading branch information
khoffschlag authored Jan 17, 2025
2 parents 0a0665d + 0350316 commit 49287c1
Show file tree
Hide file tree
Showing 4 changed files with 722 additions and 7 deletions.
222 changes: 221 additions & 1 deletion PUMI/pipelines/func/deconfound.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,87 @@
from PUMI.plot.carpet_plot import plot_carpet


@QcPipeline(inputspec_fields=['main', 'fmap', 'func_corrected'],
outputspec_fields=['out_file'])
def qc_fieldmap_correction_topup(wf, volume='first', **kwargs):
"""
Generate quality control image for the fieldmap correction consisting of a montage image
comparing a main volume, a fieldmap volume and a volume of the corrected fieldmap.
Parameters:
volume (str): The volume of the functional data to be used for comparison (e.g., 'middle').
Default is 'first'.
Inputs:
main (str): Path to the main sequence functional image (e.g., functional data).
fmap (str): Path to the fieldmap image (e.g., uncorrected fieldmap data).
func_corrected (str): Path to the fieldmap-corrected functional image.
Outputs:
out_file (str): Path to the saved QC montage image comparing the original and corrected images.
Sinking:
- Path to QC comparison image (PNG file showing the original and corrected volumes).
"""

def create_montage(vol_main, vol_fmap, vol_corrected, n_slices=3):
from matplotlib import pyplot as plt
from pathlib import Path
from nilearn import plotting
import os

def get_cut_cords(func, n_slices=3):
import nibabel as nib
import numpy as np

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape

slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(12, 18))
plt.subplots_adjust(hspace=0.4)
plotting.plot_anat(vol_main, display_mode='y', cut_coords=get_cut_cords(vol_main, n_slices=n_slices),
title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_fmap, display_mode='y', cut_coords=get_cut_cords(vol_fmap, n_slices=n_slices),
title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices),
title='Corrected', black_bg=True, axes=axes[2])

#path = Path.cwd() / 'fieldmap_correction_comparison.png'
path = os.path.join(os.getcwd(), 'fieldmap_correction_comparison.png')
plt.savefig(path, dpi=300)
plt.close(fig)
return path

vol_main = pick_volume('vol_main', volume=volume)
wf.connect('inputspec', 'main', vol_main, 'in_file')

vol_fmap = pick_volume('vol_fmap', volume=volume)
wf.connect('inputspec', 'fmap', vol_fmap, 'in_file')

vol_corrected = pick_volume('vol_corrected', volume=volume)
wf.connect('inputspec', 'func_corrected', vol_corrected, 'in_file')

montage = Node(Function(
input_names=['vol_main', 'vol_fmap', 'vol_corrected'],
output_names=['out_file'],
function=create_montage),
name='montage_node'
)
wf.connect(vol_main, 'out_file', montage, 'vol_main')
wf.connect(vol_fmap, 'out_file', montage, 'vol_fmap')
wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected')

wf.connect(montage, 'out_file', 'outputspec', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction')


@QcPipeline(inputspec_fields=['background', 'overlay'],
outputspec_fields=['out_file'])
def qc_fieldmap_correction_fugue(wf, overlay_volume='middle', **kwargs):
Expand Down Expand Up @@ -57,7 +138,146 @@ def create_fieldmap_plot(overlay, background):
# output
wf.connect(plot, 'out_file', 'outputspec', 'out_file')


@FuncPipeline(inputspec_fields=['main', 'main_json', 'fmap', 'fmap_json'],
outputspec_fields=['out_file'])
def fieldmap_correction_topup(wf, num_volumes=5, **kwargs):
"""
Perform fieldmap correction on the functional data using FSL's TOPUP.
Parameters:
num_volumes (int): Number of volumes to extract from the main functional sequence for averaging.
Default is 5.
Inputs:
main (str): Path to the main functional image (e.g., 4D functional MRI data).
main_json (str): Path to the JSON metadata for the main sequence.
fmap (str): Path to the fieldmap image (e.g., 4D fieldmap data).
fmap_json (str): Path to the JSON metadata for the fieldmap sequence.
Outputs:
out_file (str): Path to the corrected 4D functional image after fieldmap correction.
Sinking:
- Corrected functional sequence after fieldmap correction.
- QC results for fieldmap correction.
"""

# Extract how many volumes from the main sequence we are told to extract
num_volumes = int(wf.cfg_parser.get('FIELDMAP-CORRECTION', 'num_volumes', fallback=num_volumes))

# Extract the first num_volumes volumes from main sequence
extract_main_volumes = Node(fsl.ExtractROI(t_min=0, t_size=num_volumes), name='extract_main_volumes')
wf.connect('inputspec', 'main', extract_main_volumes, 'in_file')

# Compute the mean of extracted main volumes
mean_main = Node(fsl.MeanImage(), name='mean_main')
wf.connect(extract_main_volumes, 'roi_file', mean_main, 'in_file')

# Average all fieldmap volumes
mean_fmap = Node(fsl.MeanImage(), name='mean_fmap')
wf.connect('inputspec', 'fmap', mean_fmap, 'in_file')

# Retrieve encoding direction, total readout time and repetition time
def retrieve_image_params_function(main_json, fmap_json):
import json

with open(main_json, 'r') as f:
main_metadata = json.load(f)

with open(fmap_json, 'r') as f:
fmap_metadata = json.load(f)

for key in ['PhaseEncodingDirection', 'TotalReadoutTime', 'RepetitionTime']:
main_value = main_metadata.get(key, None)
fmap_value = fmap_metadata.get(key, None)

if main_value is None:
raise ValueError(f'JSON of main sequence is missing the key {key}!')

if fmap_value is None:
raise ValueError(f'JSON of fieldmap sequence is missing the key {key}!')

main_encoding_direction = main_metadata.get('PhaseEncodingDirection')
main_total_readout_time = main_metadata.get('TotalReadoutTime')
main_repetition_time = main_metadata.get('RepetitionTime')

fmap_encoding_direction = fmap_metadata.get('PhaseEncodingDirection')
fmap_total_readout_time = fmap_metadata.get('TotalReadoutTime')
fmap_repetition_time = fmap_metadata.get('RepetitionTime')

if main_encoding_direction == fmap_encoding_direction:
raise ValueError(f'Encoding direction of main sequence and fieldmap sequence are not allowed to be the same, but found {main_encoding_direction} and {fmap_encoding_direction}!')

if main_total_readout_time != fmap_total_readout_time:
raise ValueError(f'TRT of main sequence IS NOT EQUAL to fieldmap TRT ({main_total_readout_time}) != {fmap_total_readout_time})')

if main_repetition_time != fmap_repetition_time:
raise ValueError(f'TR of main sequence IS NOT EQUAL to fieldmap TR ({main_repetition_time}) != {fmap_repetition_time})')

# In case we have Siemens j-notation instead of y, replace j by y
main_encoding_direction = main_encoding_direction.replace('j', 'y')
fmap_encoding_direction = fmap_encoding_direction.replace('j', 'y')

encoding_direction = [main_encoding_direction, fmap_encoding_direction]
total_readout_time = main_total_readout_time
repetition_time = main_repetition_time

return encoding_direction, total_readout_time, repetition_time

retrieve_image_params = Node(
utility.Function(
input_names=['main_json', 'fmap_json'],
output_names=['encoding_direction', 'total_readout_time', 'repetition_time'],
function=retrieve_image_params_function
),
name='retrieve_image_params'
)
wf.connect('inputspec', 'main_json', retrieve_image_params, 'main_json')
wf.connect('inputspec', 'fmap_json', retrieve_image_params, 'fmap_json')

def combine_items_to_list(item_1, item_2):
return [item_1, item_2]

avg_volumes_to_list = Node(Function(
input_names=['item_1', 'item_2'],
output_names=['output'],
function=combine_items_to_list),
name='avg_volumes_to_list'
)
wf.connect(mean_main, 'out_file', avg_volumes_to_list, 'item_1')
wf.connect(mean_fmap, 'out_file', avg_volumes_to_list, 'item_2')

# Combine averaged main and averaged fieldmap into a 4D image
merge_avg_images = Node(fsl.Merge(dimension='t'), name='merge_avg_images')
wf.connect(avg_volumes_to_list, 'output', merge_avg_images, 'in_files')
wf.connect(retrieve_image_params, 'repetition_time', merge_avg_images, 'tr')

# Estimate susceptibility induced distortions
topup = Node(fsl.TOPUP(), name='topup')
wf.connect(merge_avg_images, 'merged_file', topup, 'in_file')
wf.connect(retrieve_image_params, 'total_readout_time', topup, 'readout_times')
wf.connect(retrieve_image_params, 'encoding_direction', topup, 'encoding_direction')

# Apply result of fsl.TOPUP to our original data
# Result will be one 4D distortion corrected image
apply_topup = Node(fsl.ApplyTOPUP(method='jac'), name='apply_topup')
wf.connect('inputspec', 'main', apply_topup, 'in_files')
wf.connect(topup, 'out_fieldcoef', apply_topup, 'in_topup_fieldcoef')
wf.connect(topup, 'out_movpar', apply_topup, 'in_topup_movpar')
wf.connect(topup, 'out_enc_file', apply_topup, 'encoding_file')

qc_fieldmap_correction = qc_fieldmap_correction_topup('qc_fieldmap_correction')
wf.connect('inputspec', 'main', qc_fieldmap_correction, 'main')
wf.connect('inputspec', 'fmap', qc_fieldmap_correction, 'fmap')
wf.connect(topup, 'out_corrected', qc_fieldmap_correction, 'func_corrected')

wf.connect(apply_topup, 'out_corrected', 'outputspec', 'out_file')
wf.connect(apply_topup, 'out_corrected', 'sinker', 'out_file')


@FuncPipeline(inputspec_fields=['main_img', 'main_json', 'anat_img', 'phasediff_img', 'phasediff_json',
'magnitude_img'],
outputspec_fields=['out_file'])
Expand Down Expand Up @@ -142,7 +362,7 @@ def get_fieldmap_parameters(main_json, phasediff_json):
wf.connect('inputspec', 'anat_img', qc, 'background')

wf.connect(fugue, 'unwarped_file', 'outputspec', 'out_file')


@FuncPipeline(inputspec_fields=['in_file'],
outputspec_fields=['out_file'])
Expand Down
3 changes: 3 additions & 0 deletions PUMI/settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ save_mask = 1
# set overwrite_existing to '1' to overwrite existing predictions otherwise set to '0'
overwrite_existing = 1

[FIELDMAP-CORRECTION]
num_volumes = 5

[TEMPLATES]
head = data/standard/MNI152_T1_2mm.nii.gz
#also okay: head = tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz; source=templateflow
Expand Down
Loading

0 comments on commit 49287c1

Please sign in to comment.