Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Releasing 0.2.0 #10

Merged
merged 7 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/unittest_py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9, '3.10']
python-version: [3.8, 3.9, '3.10', '3.11']

steps:
- uses: actions/checkout@v2
Expand All @@ -41,9 +41,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8
pip install -r requirements.txt
pip install jax[cpu]
python setup.py install
pip install -e '.[test]'
- name: Lint with flake8
working-directory: ./python
run: |
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ formats:

# Optionally set the version of Python and requirements required to build your docs
python:
version: "3.8"
version: "3.9"
install:
- requirements: docs/build_requirements.txt
12 changes: 9 additions & 3 deletions changelog.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Changelog

0.1.6 (working release)
0.2.1 (working release)
-----------------------
1. Try JET for efficient computation of recursive derivatives.
1. Try JET or forward Laplacian for efficient computation of recursive derivatives.

0.2.0 (31 October 2024)
-----------------------
1. Deprecated the old `setup.py`; now use `pyproject.toml`.
2. Added semantic typings for better documentation.
3. Added support for SDEs with time-dependent drift/dispersion, as well as for target function \phi. The time variable support is not implemented for `tme.mean_and_cov` as we don't have a consistent approximation for the covariance part for now. Note implemented for the `sympy` module either.
4. Migrated all the `unittest` to `pytest`.

0.1.5 (8 October, 2022)
-----------------------
1. Fixed a critical bug in computing the matrix-Hessian-matrix multiplication. Fortunately this bug does not affect the results when using constant dispersion coefficient.


0.1.4 (8 June, 2022)
-----------------------
1. Changed the verbose printing of TME matlab.
Expand Down
157 changes: 97 additions & 60 deletions python/examples/benes_jax.ipynb
100755 → 100644

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions python/examples/generate_index_figure.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""
Generate the animated figure in the index page.
"""
import jax
import math
from typing import Tuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
import tme.base_jax as tme
from jax import vmap
from jax.config import config
from matplotlib.animation import FuncAnimation
from typing import Tuple

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

alp = 1.
Qw = 0.1
Expand Down
3 changes: 1 addition & 2 deletions python/examples/generate_lorenz_anime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import matplotlib.pyplot as plt
import tme.base_jax as tme
from jax import jit, lax
from jax.config import config
from matplotlib.animation import FuncAnimation

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

sigma = 10.
rho = 28.
Expand Down
7 changes: 4 additions & 3 deletions python/examples/nonlinear_multidim_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
"outputs": [],
"source": [
"# Imports\n",
"import jax\n",
"import math\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import tme.base_jax as tme\n",
"import matplotlib.pyplot as plt\n",
"from jax import jit\n",
"from jax.config import config\n",
"from typing import Tuple\n",
"\n",
"config.update(\"jax_enable_x64\", True)"
"jax.config.update(\"jax_enable_x64\", True)"
]
},
{
Expand Down Expand Up @@ -272,7 +272,8 @@
],
"metadata": {
"collapsed": false
}
},
"id": "911e15b49a00f541"
}
],
"metadata": {
Expand Down
186 changes: 113 additions & 73 deletions python/examples/tme_lorenz_jax.ipynb

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
[build-system]
requires = ['setuptools >= 75.2.0']
build-backend = 'setuptools.build_meta'

[tool.setuptools]
packages = ['tme']

[project]
name = 'tme'
version = '0.2.0'
authors = [
{ name = 'Zheng Zhao', email = '[email protected]' },
]
description = 'Taylor moment expansion in Python'
readme = 'README.md'
license = { file = 'LICENSE' }
requires-python = '>=3.8'
classifiers = [
'Programming Language :: Python :: 3',
'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)',
'Operating System :: OS Independent',
'Intended Audience :: Science/Research',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Information Analysis',
'Topic :: Scientific/Engineering :: Mathematics'
]
keywords = ['Markov models', 'stochastic differential equations', 'statistics']
dependencies = [
'numpy>=1.19.2',
'scipy>=1.5.2',
'sympy>=1.8'
]

[project.optional-dependencies]
test = [
'pytest==8.3.3'
]

[project.urls]
homepage = 'https://github.com/zgbkdlm/tme'
documentation = 'https://tme.readthedocs.io'
repository = 'https://github.com/zgbkdlm/tme'
3 changes: 0 additions & 3 deletions python/requirements.txt

This file was deleted.

3 changes: 0 additions & 3 deletions python/setup.cfg

This file was deleted.

35 changes: 0 additions & 35 deletions python/setup.py

This file was deleted.

2 changes: 2 additions & 0 deletions python/test/test_experimental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def test_test():
pass
143 changes: 143 additions & 0 deletions python/test/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import jax
import jax.numpy as jnp
import numpy.testing as npt
import tme.base_jax as tme

jax.config.update('jax_enable_x64', True)

sdes = [(lambda x, t: t * jnp.sin(x),
lambda x, t: jnp.exp(t) * jnp.array([[x[0], 0.],
[0., x[1]]])),
(lambda x, t: -x,
lambda x, t: jnp.eye(2))]


def euler_maruyama(key, x0, t0, T, nsteps, drift, dispersion):
def scan_body(carry, elem):
x = carry
t, rnd = elem

x = x + drift(x, t) * dt + jnp.sqrt(dt) * dispersion(x, t) @ rnd
return x, x

rnds = jax.random.normal(key, (nsteps, *x0.shape))
dt = (T - t0) / nsteps
ts = jnp.linspace(t0, T, nsteps + 1)
xT, xs = jax.lax.scan(scan_body, x0, (ts[:-1], rnds))
return xT, xs


def test_generator():
"""Test a single generator vs handwritten result.
"""

drift, dispersion = sdes[0]

def phi(x, t):
return t * jnp.outer(x, x)

def truth(x, t):
a = drift(x, t)
v11 = jnp.array([2 * x[0], 0])
v12 = x[::-1]
v22 = jnp.array([0, 2 * x[1]])
return (jnp.outer(x, x)
+ jnp.array([[jnp.dot(v11, a), jnp.dot(v12, a)],
[jnp.dot(v12, a), jnp.dot(v22, a)]]) * t
+ jnp.exp(2 * t) * jnp.diag(x ** 2) * t)

def actual(x, t):
return tme.generator(phi, drift, dispersion)(x, t)

x_ = jnp.array([1.2, 0.3])
t_ = 0.4
npt.assert_allclose(actual(x_, t_), truth(x_, t_))


def test_expectation_monte_carlo():
drift, dispersion = sdes[0]

def phi(x, t):
return jnp.tanh(t * x)

x0 = jnp.array([1.2, 0.3])
t0 = 0.
T = 0.2

def mc_simulator(key_):
xT = euler_maruyama(key_, x0, t0, T, 1000, drift, dispersion)[0]
return phi(xT, T)

key = jax.random.PRNGKey(666)
keys = jax.random.split(key, num=100000)

mc_result = jnp.mean(jax.vmap(mc_simulator)(keys), axis=0)
tme_result = tme.expectation(phi, x0, t0, T, drift, dispersion, order=2)
npt.assert_allclose(mc_result, tme_result, rtol=3e-2)


def test_expectation_lti():
"""Test generator powers vs Monte Carlo approximations.
"""

drift, dispersion = sdes[1]

def phi(x, t):
return x * t

x0 = jnp.array([1., 2.])
t0 = 0.
T = 1.
true_mean = jnp.exp(-(T - t0)) * x0 * T
approx_mean = tme.expectation(phi, x0, t0, T, drift, dispersion, order=5)
npt.assert_allclose(approx_mean, true_mean, rtol=2e-2)


def test_mean_and_cov_lti():
drift, dispersion = sdes[1]

x0 = jnp.array([1., 2.])
t0 = 0.
T = 1.

true_mean = jnp.exp(-(T - t0)) * x0
true_cov = 0.5 * (1 - jnp.exp(-2 * (T - t0))) * jnp.eye(2)

approx_m, approx_cov = tme.mean_and_cov(x0, T - t0,
lambda x: drift(x, t0),
lambda x: dispersion(x, t0), order=5)

npt.assert_allclose(approx_m, true_mean, rtol=2e-2)
npt.assert_allclose(approx_cov, true_cov, rtol=2e-2)


def test_mean_and_cov_vs_euler_maruyama():
"""TME with order=1 is consistent with Euler--Maruyama.
"""
sigma = 10.
rho = 28.
beta = 8 / 3

def drift(u):
return jnp.array([sigma * (u[1] - u[0]),
u[0] * (rho - u[2]) - u[1],
u[0] * u[1] - beta * u[2]])

def dispersion(u):
return jnp.diag(jnp.array([1., u[1] * u[2], u[0]]))

def tme_m_cov(u, dt):
return tme.mean_and_cov(x=u, dt=dt, drift=drift, dispersion=dispersion, order=1)

def em_m_cov(u, dt):
return u + drift(u) * dt, dispersion(u) @ dispersion(u).T * dt

key = jax.random.PRNGKey(666)
x0 = jax.random.normal(key, (3,))
T = 1.

tme_m, tme_cov = tme_m_cov(x0, T)
em_m, em_cov = em_m_cov(x0, T)

npt.assert_allclose(tme_m, em_m, atol=1e-12)
npt.assert_allclose(tme_cov, em_cov, atol=1e-12)
Loading
Loading