Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ruff linter warnings #81

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pdebench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@


"""

from __future__ import annotations

import logging

_logger = logging.getLogger(__name__)
_logger.propagate = False

__version__ = "0.0.1"
__author__ = "Makoto Takamoto, Timothy Praditia, Raphael Leiteritz, Dan MacKinlay, Francesco Alesiani, Dirk Pflüger, Mathias Niepert"
__credits__ = "NEC labs Europe, University of Stuttgart, CSIRO" "s Data61"
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>

Expand Down Expand Up @@ -144,8 +143,10 @@

THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import time
from math import ceil

Expand All @@ -157,11 +158,13 @@
# Hydra
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


# Init arguments with Hydra
@hydra.main(config_path="config")
def main(cfg: DictConfig) -> None:
print(f"advection velocity: {cfg.args.beta}")
logger.info("advection velocity: %f", cfg.args.beta)

# cell edge coordinate
xe = jnp.linspace(cfg.args.xL, cfg.args.xR, cfg.args.nx + 1)
Expand All @@ -181,14 +184,14 @@ def evolve(u):
uu = uu.at[0].set(u)

while t < cfg.args.fin_time:
print(f"save data at t = {t:.3f}")
logger.info("save data at t = %f", t)
u = set_function(xc, t, cfg.args.beta)
uu = uu.at[i_save].set(u)
t += cfg.args.dt_save
i_save += 1

tm_fin = time.time()
print(f"total elapsed time is {tm_fin - tm_ini} sec")
logger.info("total elapsed time is %f sec", tm_fin - tm_ini)
uu = uu.at[-1].set(u)
return uu, t

Expand All @@ -199,9 +202,9 @@ def set_function(x, t, beta):
u = set_function(xc, t=0, beta=cfg.args.beta)
u = device_put(u) # putting variables in GPU (not necessary??)
uu, t = evolve(u)
print(f"final time is: {t:.3f}")
logger.info("final time is: %f", t)

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
jnp.save(cwd + cfg.args.save + "/Advection_beta" + str(cfg.args.beta), uu)
jnp.save(cwd + cfg.args.save + "/x_coordinate", xe)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>

Expand Down Expand Up @@ -144,22 +143,29 @@

THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import random
import sys

# Hydra
from math import ceil, exp, log
from pathlib import Path

import hydra
import jax
import jax.numpy as jnp
from jax import device_put, lax

# Hydra
from omegaconf import DictConfig

sys.path.append("..")
import logging

from utils import Courant, bc, init_multi, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -192,7 +198,7 @@ def main(cfg: DictConfig) -> None:
else:
beta = cfg.multi.beta

print("beta: ", beta)
logger.info("beta: %f", beta)

@jax.jit
def evolve(u):
Expand All @@ -204,7 +210,8 @@ def evolve(u):
uu = jnp.zeros([it_tot, u.shape[0]])
uu = uu.at[0].set(u)

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _show(_carry):
Expand All @@ -226,9 +233,7 @@ def _show(_carry):

carry = t, tsave, steps, i_save, dt, u, uu
t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
uu = uu.at[-1].set(u)

return uu
return uu.at[-1].set(u)

@jax.jit
def simulation_fn(i, carry):
Expand Down Expand Up @@ -265,12 +270,11 @@ def flux(u):
fL = uL * beta
fR = uR * beta
# upwind advection scheme
f_upwd = 0.5 * (
return 0.5 * (
fR[1 : cfg.multi.nx + 2]
+ fL[2 : cfg.multi.nx + 3]
- jnp.abs(beta) * (uL[2 : cfg.multi.nx + 3] - uR[1 : cfg.multi.nx + 2])
)
return f_upwd

u = init_multi(xc, numbers=cfg.multi.numbers, k_tot=4, init_key=cfg.multi.init_key)
u = device_put(u) # putting variables in GPU (not necessary??)
Expand All @@ -285,7 +289,7 @@ def flux(u):
# reshape before saving
uu = uu.reshape((-1, *uu.shape[2:]))

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
Path(cwd + cfg.multi.save).mkdir(parents=True, exist_ok=True)
jnp.save(cwd + cfg.multi.save + "1D_Advection_Sols_beta" + str(beta)[:5], uu)
Expand Down
14 changes: 9 additions & 5 deletions pdebench/data_gen/data_gen_NLE/BurgersEq/burgers_Hydra.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>

Expand Down Expand Up @@ -144,8 +143,10 @@

THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import sys
import time
from math import ceil
Expand All @@ -161,6 +162,8 @@
sys.path.append("..")
from utils import Courant, Courant_diff, bc, init, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -201,7 +204,8 @@ def evolve(u):

tm_ini = time.time()

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _save(_carry):
Expand All @@ -227,7 +231,7 @@ def _save(_carry):
uu = uu.at[-1].set(u)

tm_fin = time.time()
print(f"total elapsed time is {tm_fin - tm_ini} sec")
logger.info("total elapsed time is %f sec", tm_fin - tm_ini)
return uu, t

@jax.jit
Expand Down Expand Up @@ -285,9 +289,9 @@ def flux(u):
u = init(xc=xc, mode=cfg.args.init_mode, u0=cfg.args.u0, du=cfg.args.du)
u = device_put(u) # putting variables in GPU (not necessary??)
uu, t = evolve(u)
print(f"final time is: {t:.3f}")
logger.info("final time is: %.3f", t)

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
if cfg.args.init_mode == "sinsin":
jnp.save(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>

Expand Down Expand Up @@ -144,8 +143,10 @@

THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import random
import sys
from math import ceil, exp, log
Expand All @@ -162,6 +163,8 @@
sys.path.append("..")
from utils import Courant, Courant_diff, bc, init_multi, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -191,7 +194,7 @@ def main(cfg: DictConfig) -> None:
) # uniform number between 0.01 to 100
else:
epsilon = cfg.multi.epsilon
print("epsilon: ", epsilon)
logger.info("epsilon: %f", epsilon)
# t-coordinate
it_tot = ceil((fin_time - ini_time) / dt_save) + 1
tc = jnp.arange(it_tot + 1) * dt_save
Expand All @@ -206,7 +209,8 @@ def evolve(u):
uu = jnp.zeros([it_tot, u.shape[0]])
uu = uu.at[0].set(u)

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _show(_carry):
Expand All @@ -228,9 +232,7 @@ def _show(_carry):

carry = t, tsave, steps, i_save, dt, u, uu
t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
uu = uu.at[-1].set(u)

return uu
return uu.at[-1].set(u)

@jax.jit
def simulation_fn(i, carry):
Expand Down Expand Up @@ -301,7 +303,7 @@ def flux(u):
# reshape before saving
uu = uu.reshape((-1, *uu.shape[2:]))

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
Path(cwd + cfg.multi.save).mkdir(parents=True, exist_ok=True)
jnp.save(cwd + cfg.multi.save + "1D_Burgers_Sols_Nu" + str(epsilon)[:5], uu)
Expand Down
Loading
Loading