diff --git a/src/odatse/__init__.py b/src/odatse/__init__.py index 1fd227d5..d90adf9d 100644 --- a/src/odatse/__init__.py +++ b/src/odatse/__init__.py @@ -22,5 +22,6 @@ from ._runner import Runner from . import algorithm from ._main import main +from ._initialize import initialize __version__ = "3.0-dev" diff --git a/src/odatse/_initialize.py b/src/odatse/_initialize.py new file mode 100644 index 00000000..e47dfd4e --- /dev/null +++ b/src/odatse/_initialize.py @@ -0,0 +1,49 @@ +import odatse + +def initialize(): + """ + Initialize for main function by parsing commandline arguments and loading input files + + Returns + ------- + Tuple(Info, str) + an Info object having parameter values, and a run_mode string + """ + import argparse + + parser = argparse.ArgumentParser( + description=( + "Data-analysis software of quantum beam " + "diffraction experiments for 2D material structure" + ) + ) + parser.add_argument("inputfile", help="input file with TOML format") + parser.add_argument("--version", action="version", version=odatse.__version__) + + mode_group = parser.add_mutually_exclusive_group() + mode_group.add_argument("--init", action="store_true", help="initial start (default)") + mode_group.add_argument("--resume", action="store_true", help="resume intterupted run") + mode_group.add_argument("--cont", action="store_true", help="continue from previous run") + + parser.add_argument("--reset_rand", action="store_true", default=False, help="new random number series in resume or continue mode") + + args = parser.parse_args() + + + if args.init is True: + run_mode = "initial" + elif args.resume is True: + run_mode = "resume" + if args.reset_rand is True: + run_mode = "resume-resetrand" + elif args.cont is True: + run_mode = "continue" + if args.reset_rand is True: + run_mode = "continue-resetrand" + else: + run_mode = "initial" # default + + info = odatse.Info.from_file(args.inputfile) + # info.algorithm.update({"run_mode": run_mode}) + + return info, run_mode diff --git a/src/odatse/_main.py b/src/odatse/_main.py index 3efac34b..f18923a2 100644 --- a/src/odatse/_main.py +++ b/src/odatse/_main.py @@ -14,12 +14,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -from sys import exit - +import sys import odatse -import odatse.mpi -import odatse.util.toml - def main(): """ @@ -27,88 +23,21 @@ def main(): on 2D material structures. It parses command-line arguments, loads the input file, selects the appropriate algorithm and solver, and executes the analysis. """ - import argparse - - parser = argparse.ArgumentParser( - description=( - "Data-analysis software of quantum beam " - "diffraction experiments for 2D material structure" - ) - ) - parser.add_argument("inputfile", help="input file with TOML format") - parser.add_argument("--version", action="version", version=odatse.__version__) - - mode_group = parser.add_mutually_exclusive_group() - mode_group.add_argument("--init", action="store_true", help="initial start (default)") - mode_group.add_argument("--resume", action="store_true", help="resume intterupted run") - mode_group.add_argument("--cont", action="store_true", help="continue from previous run") - - parser.add_argument("--reset_rand", action="store_true", default=False, help="new random number series in resume or continue mode") - - args = parser.parse_args() - file_name = args.inputfile - # inp = {} - # if odatse.mpi.rank() == 0: - # inp = odatse.util.toml.load(file_name) - # if odatse.mpi.size() > 1: - # inp = odatse.mpi.comm().bcast(inp, root=0) - # info = odatse.Info(inp) - info = odatse.Info.from_file(file_name) + info, run_mode = odatse.initialize() - algname = info.algorithm["name"] - if algname == "mapper": - from .algorithm.mapper_mpi import Algorithm - elif algname == "minsearch": - from .algorithm.min_search import Algorithm - elif algname == "exchange": - from .algorithm.exchange import Algorithm - elif algname == "pamc": - from .algorithm.pamc import Algorithm - elif algname == "bayes": - from .algorithm.bayes import Algorithm - else: - print(f"ERROR: Unknown algorithm ({algname})") - exit(1) + alg_module = odatse.algorithm.choose_algorithm(info.algorithm["name"]) solvername = info.solver["name"] - if solvername == "surface": - if odatse.mpi.rank() == 0: - print( - 'WARNING: solver name "surface" is deprecated and will be unavailable in future.' - ' Use "sim-trhepd-rheed" instead.' - ) - #from .solver.sim_trhepd_rheed import Solver - from sim_trhepd_rheed import Solver - elif solvername == "sim-trhepd-rheed": - #from .solver.sim_trhepd_rheed import Solver - from sim_trhepd_rheed import Solver - elif solvername == "sxrd": - #from .solver.sxrd import Solver - from sxrd import Solver - elif solvername == "leed": - #from .solver.leed import Solver - from leed import Solver - elif solvername == "analytical": + if solvername == "analytical": from .solver.analytical import Solver else: - print(f"ERROR: Unknown solver ({solvername})") - exit(1) - - if args.init is True: - run_mode = "initial" - elif args.resume is True: - run_mode = "resume" - if args.reset_rand is True: - run_mode = "resume-resetrand" - elif args.cont is True: - run_mode = "continue" - if args.reset_rand is True: - run_mode = "continue-resetrand" - else: - run_mode = "initial" # default + if odatse.mpi.rank() == 0: + print(f"ERROR: Unknown solver ({solvername})") + sys.exit(1) solver = Solver(info) runner = odatse.Runner(solver, info) - alg = Algorithm(info, runner, run_mode=run_mode) + alg = alg_module.Algorithm(info, runner, run_mode=run_mode) + result = alg.main() diff --git a/src/odatse/algorithm/__init__.py b/src/odatse/algorithm/__init__.py index a4a9c3fa..8a36b10b 100644 --- a/src/odatse/algorithm/__init__.py +++ b/src/odatse/algorithm/__init__.py @@ -15,3 +15,34 @@ # along with this program. If not, see http://www.gnu.org/licenses/. from ._algorithm import AlgorithmBase + +def choose_algorithm(name): + """ + Search for algorithm module by name + + Parameters + ---------- + name : str + name of the algorithm + + Returns + ------- + module + algorithm module + """ + + alg_table = { + "mapper": "mapper_mpi", + "minsearch": "min_search", + } + + try: + import importlib + alg_name = "odatse.algorithm.{}".format(alg_table.get(name, name)) + alg_module = importlib.import_module(alg_name) + except ModuleNotFoundError as e: + print("ERROR: {}".format(e)) + import sys + sys.exit(1) + + return alg_module