-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Rename the folder `tests` to `examples` to avoid misleading. - Update the `readme.md`
- Loading branch information
Showing
17 changed files
with
179 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Time : 2021/1/4 | ||
# @Author : Lart Pang | ||
# @GitHub : https://github.com/lartpang | ||
|
||
import numpy as np | ||
|
||
from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure | ||
|
||
|
||
def ndarray_to_basetype(data): | ||
""" | ||
将单独的ndarray,或者tuple,list或者dict中的ndarray转化为基本数据类型, | ||
即列表(.tolist())和python标量 | ||
""" | ||
|
||
def _to_list_or_scalar(item): | ||
listed_item = item.tolist() | ||
if isinstance(listed_item, list) and len(listed_item) == 1: | ||
listed_item = listed_item[0] | ||
return listed_item | ||
|
||
if isinstance(data, (tuple, list)): | ||
results = [_to_list_or_scalar(item) for item in data] | ||
elif isinstance(data, dict): | ||
results = {k: _to_list_or_scalar(item) for k, item in data.items()} | ||
else: | ||
assert isinstance(data, np.ndarray) | ||
results = _to_list_or_scalar(data) | ||
return results | ||
|
||
|
||
class CalTotalMetric(object): | ||
def __init__(self): | ||
""" | ||
用于统计各种指标的类 | ||
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py | ||
""" | ||
self.mae = MAE() | ||
self.fm = Fmeasure() | ||
self.sm = Smeasure() | ||
self.em = Emeasure() | ||
self.wfm = WeightedFmeasure() | ||
|
||
def step(self, pre: np.ndarray, gt: np.ndarray): | ||
assert pre.shape == gt.shape | ||
assert pre.dtype == np.uint8 | ||
assert gt.dtype == np.uint8 | ||
|
||
self.mae.step(pre, gt) | ||
self.sm.step(pre, gt) | ||
self.fm.step(pre, gt) | ||
self.em.step(pre, gt) | ||
self.wfm.step(pre, gt) | ||
|
||
def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: | ||
""" | ||
返回指标计算结果: | ||
- 曲线数据(sequential): fm/em/p/r | ||
- 数值指标(numerical): SM/MAE/maxE/avgE/adpE/maxF/avgF/adpF/wFm | ||
""" | ||
fm_info = self.fm.get_results() | ||
fm = fm_info["fm"] | ||
pr = fm_info["pr"] | ||
wfm = self.wfm.get_results()["wfm"] | ||
sm = self.sm.get_results()["sm"] | ||
em = self.em.get_results()["em"] | ||
mae = self.mae.get_results()["mae"] | ||
|
||
sequential_results = { | ||
"fm": np.flip(fm["curve"]), | ||
"em": np.flip(em["curve"]), | ||
"p": np.flip(pr["p"]), | ||
"r": np.flip(pr["r"]), | ||
} | ||
numerical_results = { | ||
"SM": sm, | ||
"MAE": mae, | ||
"maxE": em["curve"].max(), | ||
"avgE": em["curve"].mean(), | ||
"adpE": em["adp"], | ||
"maxF": fm["curve"].max(), | ||
"avgF": fm["curve"].mean(), | ||
"adpF": fm["adp"], | ||
"wFm": wfm, | ||
} | ||
if num_bits is not None and isinstance(num_bits, int): | ||
numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()} | ||
if not return_ndarray: | ||
sequential_results = ndarray_to_basetype(sequential_results) | ||
numerical_results = ndarray_to_basetype(numerical_results) | ||
return {"sequential": sequential_results, "numerical": numerical_results} | ||
|
||
|
||
if __name__ == "__main__": | ||
data_loader = ... | ||
model = ... | ||
|
||
cal_total_seg_metrics = CalTotalMetric() | ||
for batch in data_loader: | ||
seg_preds = model(batch) | ||
for seg_pred in seg_preds: | ||
mask_array = ... | ||
cal_total_seg_metrics.step(seg_pred, mask_array) | ||
fixed_seg_results = cal_total_seg_metrics.get_results() |
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Time : 2020/11/21 | ||
# @Author : Lart Pang | ||
# @GitHub : https://github.com/lartpang | ||
|
||
import os | ||
|
||
import cv2 | ||
from tqdm import tqdm | ||
|
||
# pip install pysodmetrics | ||
from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure | ||
|
||
FM = Fmeasure() | ||
WFM = WeightedFmeasure() | ||
SM = Smeasure() | ||
EM = Emeasure() | ||
MAE = MAE() | ||
|
||
data_root = "./test_data" | ||
mask_root = os.path.join(data_root, "masks") | ||
pred_root = os.path.join(data_root, "preds") | ||
mask_name_list = sorted(os.listdir(mask_root)) | ||
for mask_name in tqdm(mask_name_list, total=len(mask_name_list)): | ||
mask_path = os.path.join(mask_root, mask_name) | ||
pred_path = os.path.join(pred_root, mask_name) | ||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) | ||
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE) | ||
FM.step(pred=pred, gt=mask) | ||
WFM.step(pred=pred, gt=mask) | ||
SM.step(pred=pred, gt=mask) | ||
EM.step(pred=pred, gt=mask) | ||
MAE.step(pred=pred, gt=mask) | ||
|
||
fm = FM.get_results()["fm"] | ||
wfm = WFM.get_results()["wfm"] | ||
sm = SM.get_results()["sm"] | ||
em = EM.get_results()["em"] | ||
mae = MAE.get_results()["mae"] | ||
|
||
results = { | ||
"Smeasure": sm.round(3), | ||
"wFmeasure": wfm.round(3), | ||
"MAE": mae.round(3), | ||
"adpEm": em["adp"].round(3), | ||
"meanEm": em["curve"].mean().round(3), | ||
"maxEm": em["curve"].max().round(3), | ||
"adpFm": fm["adp"].round(3), | ||
"meanFm": fm["curve"].mean().round(3), | ||
"maxFm": fm["curve"].max().round(3), | ||
} | ||
|
||
print(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from py_sod_metrics.sod_metrics import * | ||
|
||
__version__ = "1.2.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,17 @@ | ||
from setuptools import setup, find_packages | ||
import py_sod_metrics as my_script | ||
|
||
with open("readme.md", "r") as fh: | ||
long_description = fh.read() | ||
|
||
setup( | ||
name="pysodmetrics", | ||
packages=find_packages(), | ||
version="1.2.1", | ||
version=my_script.__version__, | ||
license="MIT", | ||
description="A simple and efficient implementation of SOD metrics.", | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
author="lartpang", | ||
author_email="[email protected]", | ||
url="https://github.com/lartpang/PySODMetrics", | ||
|
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.