Skip to content

Commit

Permalink
Simplify API version validation (#1556)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Refactor**
- Simplified and streamlined the submission processes for various job
types, improving efficiency and reducing redundancy.
- Centralized API version checking with a new utility function,
enhancing maintainability and consistency across the application.

- **Bug Fixes**
- Improved error handling for API versions below 1.0, ensuring clearer
and more consistent error messages.

- **New Features**
- Introduced a new function for API version validation, ensuring
compatibility and proper error handling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
thangckt and pre-commit-ci[bot] authored May 29, 2024
1 parent 54e48c6 commit c5812fb
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 184 deletions.
153 changes: 69 additions & 84 deletions dpgen/data/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import dpdata
import numpy as np
from packaging.version import Version
from pymatgen.core import Structure
from pymatgen.io.vasp import Incar

Expand All @@ -28,7 +27,7 @@
make_abacus_scf_stru,
make_supercell_abacus,
)
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.utils import check_api_version, symlink_user_forward_files
from dpgen.generator.lib.vasp import incar_upper
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file
Expand Down Expand Up @@ -1158,27 +1157,23 @@ def run_vasp_relax(jdata, mdata):
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()


def coll_abacus_md(jdata):
Expand Down Expand Up @@ -1298,27 +1293,23 @@ def run_abacus_relax(jdata, mdata):
# relax_run_tasks.append(ii)
run_tasks = [os.path.basename(ii) for ii in relax_run_tasks]

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()


def run_vasp_md(jdata, mdata):
Expand Down Expand Up @@ -1359,27 +1350,24 @@ def run_vasp_md(jdata, mdata):
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()


def run_abacus_md(jdata, mdata):
Expand Down Expand Up @@ -1435,27 +1423,24 @@ def run_abacus_md(jdata, mdata):
run_tasks = [ii.replace(work_dir + "/", "") for ii in md_run_tasks]
# dlog.info("md_work_dir", work_dir)
# dlog.info("run_tasks",run_tasks)
api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()
### Submit jobs
check_api_version(mdata)

submission = make_submission(
mdata["fp_machine"],
mdata["fp_resources"],
commands=[fp_command],
work_path=work_dir,
run_tasks=run_tasks,
group_size=fp_group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog="fp.log",
errlog="fp.log",
)
submission.run_submission()


def gen_init_bulk(args):
Expand Down
31 changes: 15 additions & 16 deletions dpgen/dispatcher/Dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,20 @@ def make_submission_compat(
"""
if Version(api_version) < Version("1.0"):
raise RuntimeError(
f"API version {api_version} has been removed. Please upgrade to 1.0."
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
)

elif Version(api_version) >= Version("1.0"):
submission = make_submission(
machine,
resources,
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=outlog,
errlog=errlog,
)
submission.run_submission()
submission = make_submission(
machine,
resources,
commands=commands,
work_path=work_path,
run_tasks=run_tasks,
group_size=group_size,
forward_common_files=forward_common_files,
forward_files=forward_files,
backward_files=backward_files,
outlog=outlog,
errlog=errlog,
)
submission.run_submission()
11 changes: 11 additions & 0 deletions dpgen/generator/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import re
import shutil

from packaging.version import Version

iter_format = "%06d"
task_format = "%02d"
log_iter_head = "iter " + iter_format + " task " + task_format + ": "
Expand Down Expand Up @@ -110,3 +112,12 @@ def symlink_user_forward_files(mdata, task_type, work_path, task_format=None):
abs_file = os.path.abspath(file)
os.symlink(abs_file, os.path.join(task, os.path.basename(file)))
return


def check_api_version(mdata):
"""Check if the API version in mdata is at least 1.0."""
if Version(mdata.get("api_version", "1.0")) < Version("1.0"):
raise RuntimeError(
"API version below 1.0 is no longer supported. Please upgrade to version 1.0 or newer."
)
return
Loading

0 comments on commit c5812fb

Please sign in to comment.