Skip to content

Commit

Permalink
- Modify some examples.
Browse files Browse the repository at this point in the history
- Rename the folder `tests` to `examples` to avoid misleading.
- Update the `readme.md`
  • Loading branch information
lartpang committed Mar 13, 2021
1 parent c6d273a commit 58a2c2f
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 199 deletions.
106 changes: 106 additions & 0 deletions examples/metric_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# @Time : 2021/1/4
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

import numpy as np

from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure


def ndarray_to_basetype(data):
"""
将单独的ndarray,或者tuple,list或者dict中的ndarray转化为基本数据类型,
即列表(.tolist())和python标量
"""

def _to_list_or_scalar(item):
listed_item = item.tolist()
if isinstance(listed_item, list) and len(listed_item) == 1:
listed_item = listed_item[0]
return listed_item

if isinstance(data, (tuple, list)):
results = [_to_list_or_scalar(item) for item in data]
elif isinstance(data, dict):
results = {k: _to_list_or_scalar(item) for k, item in data.items()}
else:
assert isinstance(data, np.ndarray)
results = _to_list_or_scalar(data)
return results


class CalTotalMetric(object):
def __init__(self):
"""
用于统计各种指标的类
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
"""
self.mae = MAE()
self.fm = Fmeasure()
self.sm = Smeasure()
self.em = Emeasure()
self.wfm = WeightedFmeasure()

def step(self, pre: np.ndarray, gt: np.ndarray):
assert pre.shape == gt.shape
assert pre.dtype == np.uint8
assert gt.dtype == np.uint8

self.mae.step(pre, gt)
self.sm.step(pre, gt)
self.fm.step(pre, gt)
self.em.step(pre, gt)
self.wfm.step(pre, gt)

def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
"""
返回指标计算结果:
- 曲线数据(sequential): fm/em/p/r
- 数值指标(numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm
"""
fm_info = self.fm.get_results()
fm = fm_info["fm"]
pr = fm_info["pr"]
wfm = self.wfm.get_results()["wfm"]
sm = self.sm.get_results()["sm"]
em = self.em.get_results()["em"]
mae = self.mae.get_results()["mae"]

sequential_results = {
"fm": np.flip(fm["curve"]),
"em": np.flip(em["curve"]),
"p": np.flip(pr["p"]),
"r": np.flip(pr["r"]),
}
numerical_results = {
"SM": sm,
"MAE": mae,
"maxE": em["curve"].max(),
"avgE": em["curve"].mean(),
"adpE": em["adp"],
"maxF": fm["curve"].max(),
"avgF": fm["curve"].mean(),
"adpF": fm["adp"],
"wFm": wfm,
}
if num_bits is not None and isinstance(num_bits, int):
numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
if not return_ndarray:
sequential_results = ndarray_to_basetype(sequential_results)
numerical_results = ndarray_to_basetype(numerical_results)
return {"sequential": sequential_results, "numerical": numerical_results}


if __name__ == "__main__":
data_loader = ...
model = ...

cal_total_seg_metrics = CalTotalMetric()
for batch in data_loader:
seg_preds = model(batch)
for seg_pred in seg_preds:
mask_array = ...
cal_total_seg_metrics.step(seg_pred, mask_array)
fixed_seg_results = cal_total_seg_metrics.get_results()
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
53 changes: 53 additions & 0 deletions examples/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# @Time : 2020/11/21
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

import os

import cv2
from tqdm import tqdm

# pip install pysodmetrics
from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure

FM = Fmeasure()
WFM = WeightedFmeasure()
SM = Smeasure()
EM = Emeasure()
MAE = MAE()

data_root = "./test_data"
mask_root = os.path.join(data_root, "masks")
pred_root = os.path.join(data_root, "preds")
mask_name_list = sorted(os.listdir(mask_root))
for mask_name in tqdm(mask_name_list, total=len(mask_name_list)):
mask_path = os.path.join(mask_root, mask_name)
pred_path = os.path.join(pred_root, mask_name)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
FM.step(pred=pred, gt=mask)
WFM.step(pred=pred, gt=mask)
SM.step(pred=pred, gt=mask)
EM.step(pred=pred, gt=mask)
MAE.step(pred=pred, gt=mask)

fm = FM.get_results()["fm"]
wfm = WFM.get_results()["wfm"]
sm = SM.get_results()["sm"]
em = EM.get_results()["em"]
mae = MAE.get_results()["mae"]

results = {
"Smeasure": sm.round(3),
"wFmeasure": wfm.round(3),
"MAE": mae.round(3),
"adpEm": em["adp"].round(3),
"meanEm": em["curve"].mean().round(3),
"maxEm": em["curve"].max().round(3),
"adpFm": fm["adp"].round(3),
"meanFm": fm["curve"].mean().round(3),
"maxFm": fm["curve"].max().round(3),
}

print(results)
3 changes: 3 additions & 0 deletions py_sod_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from py_sod_metrics.sod_metrics import *

__version__ = "1.2.2"
18 changes: 8 additions & 10 deletions py_sod_metrics/sod_metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
from scipy.ndimage import convolve, distance_transform_edt as bwdist

__version__ = "1.2.1"

_EPS = 1e-16
_TYPE = np.float64

Expand Down Expand Up @@ -282,9 +280,9 @@ def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: flo
results_parts = []
for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
align_matrix_value = (
2
* (combination[0] * combination[1])
/ (combination[0] ** 2 + combination[1] ** 2 + _EPS)
2
* (combination[0] * combination[1])
/ (combination[0] ** 2 + combination[1] ** 2 + _EPS)
)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel)
Expand Down Expand Up @@ -324,9 +322,9 @@ def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.nd
results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = (
2
* (combination[0] * combination[1])
/ (combination[0] ** 2 + combination[1] ** 2 + _EPS)
2
* (combination[0] * combination[1])
/ (combination[0] ** 2 + combination[1] ** 2 + _EPS)
)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
Expand All @@ -336,7 +334,7 @@ def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.nd
return em

def generate_parts_numel_combinations(
self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel
self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel
):
bg_fg_numel = self.gt_fg_numel - fg_fg_numel
bg_bg_numel = pred_bg_numel - bg_fg_numel
Expand Down Expand Up @@ -428,7 +426,7 @@ def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndar
fspecial('gaussian',[shape],[sigma])
"""
m, n = [(ss - 1) / 2 for ss in shape]
y, x = np.ogrid[-m : m + 1, -n : n + 1]
y, x = np.ogrid[-m: m + 1, -n: n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
sumh = h.sum()
Expand Down
4 changes: 2 additions & 2 deletions readme_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pip install pysodmetrics

### 示例

* <./tests/test_metrics.py>
* <./tests/metric_recorder.py>
* <examples/metric_recorder.py>
* <examples/test_metrics.py>

## 感谢

Expand Down
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from setuptools import setup, find_packages
import py_sod_metrics as my_script

with open("readme.md", "r") as fh:
long_description = fh.read()

setup(
name="pysodmetrics",
packages=find_packages(),
version="1.2.1",
version=my_script.__version__,
license="MIT",
description="A simple and efficient implementation of SOD metrics.",
long_description=long_description,
long_description_content_type="text/markdown",
author="lartpang",
author_email="[email protected]",
url="https://github.com/lartpang/PySODMetrics",
Expand Down
58 changes: 0 additions & 58 deletions tests/metric_recorder.py

This file was deleted.

54 changes: 0 additions & 54 deletions tests/test_metrics.py

This file was deleted.

26 changes: 0 additions & 26 deletions tests/test_speed_for_count_nonzero.py

This file was deleted.

Loading

0 comments on commit 58a2c2f

Please sign in to comment.