Skip to content

Commit

Permalink
Merge pull request #5 from issp-center-dev/update.main
Browse files Browse the repository at this point in the history
refactoring main function
  • Loading branch information
aoymt authored Oct 31, 2024
2 parents 0a0b7a6 + 0db5975 commit 8b5611f
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 80 deletions.
1 change: 1 addition & 0 deletions src/odatse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
from ._runner import Runner
from . import algorithm
from ._main import main
from ._initialize import initialize

__version__ = "3.0-dev"
49 changes: 49 additions & 0 deletions src/odatse/_initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import odatse

def initialize():
"""
Initialize for main function by parsing commandline arguments and loading input files
Returns
-------
Tuple(Info, str)
an Info object having parameter values, and a run_mode string
"""
import argparse

parser = argparse.ArgumentParser(
description=(
"Data-analysis software of quantum beam "
"diffraction experiments for 2D material structure"
)
)
parser.add_argument("inputfile", help="input file with TOML format")
parser.add_argument("--version", action="version", version=odatse.__version__)

mode_group = parser.add_mutually_exclusive_group()
mode_group.add_argument("--init", action="store_true", help="initial start (default)")
mode_group.add_argument("--resume", action="store_true", help="resume intterupted run")
mode_group.add_argument("--cont", action="store_true", help="continue from previous run")

parser.add_argument("--reset_rand", action="store_true", default=False, help="new random number series in resume or continue mode")

args = parser.parse_args()


if args.init is True:
run_mode = "initial"
elif args.resume is True:
run_mode = "resume"
if args.reset_rand is True:
run_mode = "resume-resetrand"
elif args.cont is True:
run_mode = "continue"
if args.reset_rand is True:
run_mode = "continue-resetrand"
else:
run_mode = "initial" # default

info = odatse.Info.from_file(args.inputfile)
# info.algorithm.update({"run_mode": run_mode})

return info, run_mode
89 changes: 9 additions & 80 deletions src/odatse/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,101 +14,30 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from sys import exit

import sys
import odatse
import odatse.mpi
import odatse.util.toml


def main():
"""
Main function to run the data-analysis software for quantum beam diffraction experiments
on 2D material structures. It parses command-line arguments, loads the input file,
selects the appropriate algorithm and solver, and executes the analysis.
"""
import argparse

parser = argparse.ArgumentParser(
description=(
"Data-analysis software of quantum beam "
"diffraction experiments for 2D material structure"
)
)
parser.add_argument("inputfile", help="input file with TOML format")
parser.add_argument("--version", action="version", version=odatse.__version__)

mode_group = parser.add_mutually_exclusive_group()
mode_group.add_argument("--init", action="store_true", help="initial start (default)")
mode_group.add_argument("--resume", action="store_true", help="resume intterupted run")
mode_group.add_argument("--cont", action="store_true", help="continue from previous run")

parser.add_argument("--reset_rand", action="store_true", default=False, help="new random number series in resume or continue mode")

args = parser.parse_args()

file_name = args.inputfile
# inp = {}
# if odatse.mpi.rank() == 0:
# inp = odatse.util.toml.load(file_name)
# if odatse.mpi.size() > 1:
# inp = odatse.mpi.comm().bcast(inp, root=0)
# info = odatse.Info(inp)
info = odatse.Info.from_file(file_name)
info, run_mode = odatse.initialize()

algname = info.algorithm["name"]
if algname == "mapper":
from .algorithm.mapper_mpi import Algorithm
elif algname == "minsearch":
from .algorithm.min_search import Algorithm
elif algname == "exchange":
from .algorithm.exchange import Algorithm
elif algname == "pamc":
from .algorithm.pamc import Algorithm
elif algname == "bayes":
from .algorithm.bayes import Algorithm
else:
print(f"ERROR: Unknown algorithm ({algname})")
exit(1)
alg_module = odatse.algorithm.choose_algorithm(info.algorithm["name"])

solvername = info.solver["name"]
if solvername == "surface":
if odatse.mpi.rank() == 0:
print(
'WARNING: solver name "surface" is deprecated and will be unavailable in future.'
' Use "sim-trhepd-rheed" instead.'
)
#from .solver.sim_trhepd_rheed import Solver
from sim_trhepd_rheed import Solver
elif solvername == "sim-trhepd-rheed":
#from .solver.sim_trhepd_rheed import Solver
from sim_trhepd_rheed import Solver
elif solvername == "sxrd":
#from .solver.sxrd import Solver
from sxrd import Solver
elif solvername == "leed":
#from .solver.leed import Solver
from leed import Solver
elif solvername == "analytical":
if solvername == "analytical":
from .solver.analytical import Solver
else:
print(f"ERROR: Unknown solver ({solvername})")
exit(1)

if args.init is True:
run_mode = "initial"
elif args.resume is True:
run_mode = "resume"
if args.reset_rand is True:
run_mode = "resume-resetrand"
elif args.cont is True:
run_mode = "continue"
if args.reset_rand is True:
run_mode = "continue-resetrand"
else:
run_mode = "initial" # default
if odatse.mpi.rank() == 0:
print(f"ERROR: Unknown solver ({solvername})")
sys.exit(1)

solver = Solver(info)
runner = odatse.Runner(solver, info)
alg = Algorithm(info, runner, run_mode=run_mode)
alg = alg_module.Algorithm(info, runner, run_mode=run_mode)

result = alg.main()
31 changes: 31 additions & 0 deletions src/odatse/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,34 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from ._algorithm import AlgorithmBase

def choose_algorithm(name):
"""
Search for algorithm module by name
Parameters
----------
name : str
name of the algorithm
Returns
-------
module
algorithm module
"""

alg_table = {
"mapper": "mapper_mpi",
"minsearch": "min_search",
}

try:
import importlib
alg_name = "odatse.algorithm.{}".format(alg_table.get(name, name))
alg_module = importlib.import_module(alg_name)
except ModuleNotFoundError as e:
print("ERROR: {}".format(e))
import sys
sys.exit(1)

return alg_module

0 comments on commit 8b5611f

Please sign in to comment.