Commit 08cd2c9a authored by Seema Mirchandaney's avatar Seema Mirchandaney
Browse files

legion checkpoint related tasks

parent 36d833d7
import numpy as np
import PyNVTX as nvtx
import os
import pygion
from pygion import task, Partition, Region, Tunable, WD, RO
from spinifel import settings, utils, contexts, checkpoint
from . import utils as lgutils
@task(privileges=[WD("ac","support_","rho_"), WD("quaternions")])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def checkpoint_load_task(phased, orientations, out_dir, load_gen, tag_gen):
print(f"Loading checkpoint: {checkpoint.generate_checkpoint_name(out_dir, load_gen, tag_gen)}", flush=True)
myRes = checkpoint.load_checkpoint(out_dir,
load_gen,
tag_gen)
# Unpack dictionary
phased.ac[:] = myRes['ac_phased']
phased.support_[:] = myRes['support_']
phased.rho_[:] = myRes['rho_']
orientations.quaternions[:] = myRes['orientations']
''' Create and Fill Regions [phased:{ac,support_,rho}],
[orientations:{quaternions}]
'''
def load_checkpoint(outdir: str, gen_num: int, tag=''):
phased = Region((settings.M,)*3, {
"ac": pygion.float32, "support_": pygion.float32,
"rho_": pygion.float32})
# setup the orientation region
N_images_per_rank = settings.N_images_per_rank
fields_dict = {"quaternions": pygion.float32}
sec_shape = (4,)
orientations, orientations_p = lgutils.create_distributed_region(
N_images_per_rank, fields_dict, sec_shape)
checkpoint_load_task(phased,orientations, outdir,gen_num,tag)
return phased, orientations, orientations_p
'''
Save pixel_position_reciprocal, pixel_distance_reciprocal
Save Regions [slices:{data}],
[solved:{ac}]
'''
@task(privileges=[RO("data"), RO("ac"), RO, RO])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def save_checkpoint_solve_ac(
slices, solved,
pixel_position,
pixel_distance,
out_dir: str, gen_num: int, tag=''):
# Pack dictionary
myRes = {
'pixel_position_reciprocal': pixel_position.reciprocal,
'pixel_distance_reciprocal': pixel_distance.reciprocal,
'slices_': slices.data,
'ac': solved.ac
}
checkpoint.save_checkpoint(myRes, out_dir, gen_num, tag)
''' Save Regions [solved:{ac}:float32],
[phased:{ac,support_,rho_}]
'''
@task(privileges=[RO("ac"), RO("ac"), RO("support_"), RO("rho_")])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def save_checkpoint_phase(solved, phased,
out_dir: str, gen_num: int, tag=''):
# Pack dictionary
myRes = {
'ac': solved.ac,
'ac_phased': phased.ac,
'support_': phased.support_,
'rho_': phased.rho_
}
checkpoint.save_checkpoint(myRes, out_dir, gen_num, tag)
''' Save Regions [slices:{data}:float32],
[phased:{ac,support_,rho_}]: what about support/rho?
[orientations:{quaternions}]
[pixel_position:{reciprocal}]
[pixel_distance:{reciprocal}]
'''
@task(privileges=[RO("data"), RO("ac"), RO("quaternions"), RO, RO])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def save_checkpoint_match(slices, phased, orientations,
pixel_position,
pixel_distance,
out_dir: str, gen_num: int, tag=''):
# Pack dictionary
myRes = {'ac_phased': phased.ac,
'slices_': slices.data,
'pixel_position_reciprocal': pixel_position.reciprocal,
'pixel_distance_reciprocal': pixel_distance.reciprocal,
'orientations': orientations.quaternions
}
checkpoint.save_checkpoint(myRes, out_dir, gen_num, tag)
''' Save Regions [solved:{ac}:float32],
[prev_phased:{prev_rho_}]:
[phased:{ac,support_,rho_}]:
'''
@task(privileges=[RO("ac"), RO("prev_rho_"), RO("ac","support_","rho_")])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def save_checkpoint_phase_prev(solved, prev_phased, phased,
prev_support,
out_dir: str, gen_num: int, tag=''):
myRes = {
'ac': solved.ac,
'prev_support_': prev_support,
'prev_rho_': prev_phased.prev_rho_,
'ac_phased': phased.ac,
'support_': phased.support_,
'rho_': phased.rho_
}
checkpoint.save_checkpoint(myRes,out_dir, gen_num, tag)
''' Save Regions [phased:{ac,support_,rho_}]
[orientations:{quaternions}]
'''
@task(privileges=[RO("ac", "support_", "rho_"), RO("quaternions")])
@lgutils.gpu_task_wrapper
@nvtx.annotate("legion/checkpoint.py", is_prefix=True)
def save_checkpoint(phased, orientations,
out_dir: str, gen_num: int, tag=''):
# Pack dictionary
myRes = {
'ac_phased': phased.ac,
'support_': phased.support_,
'rho_': phased.rho_,
'orientations': orientations.quaternions
}
checkpoint.save_checkpoint(myRes, out_dir, gen_num, tag)
......@@ -13,7 +13,7 @@ from .autocorrelation import solve_ac
from .phasing import phase, prev_phase, cov
from .orientation_matching import match
from . import mapper
from . import checkpoint
@task(replicable=True)
@nvtx.annotate("legion/main.py", is_prefix=True)
......@@ -53,26 +53,31 @@ def main():
slices, slices_p) = get_data(ds)
logger.log(f"Loaded in {timer.lap():.2f}s.")
solved = solve_ac(0, pixel_position, pixel_distance, slices, slices_p)
logger.log(f"AC recovered in {timer.lap():.2f}s.")
if settings.load_gen > 0: # Load input from previous generation
curr_gen = settings.load_gen
phased, orientations, orientations_p = checkpoint.load_checkpoint(settings.out_dir, settings.load_gen)
else:
solved = solve_ac(0, pixel_position, pixel_distance, slices, slices_p)
logger.log(f"AC recovered in {timer.lap():.2f}s.")
phased = phase(0, solved)
logger.log(f"Problem phased in {timer.lap():.2f}s.")
phased = phase(0, solved)
logger.log(f"Problem phased in {timer.lap():.2f}s.")
rho = np.fft.ifftshift(phased.rho_)
print('rho =', rho)
rho = np.fft.ifftshift(phased.rho_)
print('rho =', rho)
save_mrc(settings.out_dir / f"ac-0.mrc", phased.ac)
save_mrc(settings.out_dir / f"rho-0.mrc", rho)
save_mrc(settings.out_dir / f"ac-0.mrc", phased.ac)
save_mrc(settings.out_dir / f"rho-0.mrc", rho)
# Use improvement of cc(prev_rho, cur_rho) to dertemine if we should
# terminate the loop
prev_phased = None
cov_xy = 0
cov_delta = .05
curr_gen +=1
N_generations = settings.N_generations
for generation in range(1, N_generations):
for generation in range(curr_gen, N_generations+1):
logger.log(f"#"*27)
logger.log(f"##### Generation {generation}/{N_generations} #####")
logger.log(f"#"*27)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment