Skip to content

Commit

Permalink
Merge pull request #15 from Relifest/main
Browse files Browse the repository at this point in the history
Basic function repair and maintenance.
  • Loading branch information
Relifest authored Jul 18, 2024
2 parents 7b89afd + 62b789d commit 582c8a9
Show file tree
Hide file tree
Showing 8 changed files with 2,731 additions and 353 deletions.
32 changes: 4 additions & 28 deletions pytdml/io/tdml_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@
#
# ------------------------------------------------------------------------------
import json


# from pytdml.type.basic_types_old import TrainingDataset
# from pytdml.type.extended_types_old import EOTrainingDataset
from pytdml.type.basic_types import TrainingDataset
from pytdml.type.extended_types import EOTrainingDataset

from pytdml.type import TrainingDataset, EOTrainingDataset


def read_from_json(file_path: str):
Expand All @@ -50,27 +44,9 @@ def read_from_json(file_path: str):

def parse_json(json_dict):
# Different kinds of training datasets are supported
if json_dict["type"] == "TrainingDataset":
return TrainingDataset(**json_dict)
elif json_dict["type"] == "EOTrainingDataset":
return EOTrainingDataset(**json_dict)
if json_dict["type"] == "AI_TrainingDataset":
return TrainingDataset.from_dict(json_dict)
elif json_dict["type"] == "AI_EOTrainingDataset":
return EOTrainingDataset(**json_dict)
return EOTrainingDataset.from_dict(json_dict)
else:
raise ValueError("Unknown TDML type: {}".format(json_dict["type"]))


# def read_from_json(file_path: str):
# """
# Reads a TDML JSON file and returns a TrainingDataset object.
# """
# with open(file_path, "r", encoding='utf-8') as f:
# json_dict = json.load(f)
# # Different kinds of training datasets are supported
# if json_dict["type"] == "TrainingDataset":
# return TrainingDataset.from_dict(json_dict)
# elif json_dict["type"] == "EOTrainingDataset":
# return EOTrainingDataset.from_dict(json_dict)
# else:
# raise ValueError("Unknown TDML type: {}".format(json_dict["type"]))

14 changes: 3 additions & 11 deletions pytdml/io/tdml_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,13 @@
# ------------------------------------------------------------------------------
import json
from typing import Union
from pytdml.type.basic_types import TrainingDataset
from pytdml.utils import remove_empty
from pytdml.type import TrainingDataset, EOTrainingDataset


# def write_to_json(td: TrainingDataset, file_path: str, indent: Union[None, int, str] = 4):
# """
# Writes a TrainingDataset to a JSON file.
# """
# with open(file_path, "w", encoding='utf-8') as f:
# json.dump(remove_empty(td.to_dict()), f, indent=indent, ensure_ascii=False)

def write_to_json(td: TrainingDataset, file_path: str, indent: Union[None, int, str] = 4):
def write_to_json(td: TrainingDataset or EOTrainingDataset, file_path: str, indent: Union[int, str] = 4):
"""
Writes a TrainingDataset to a JSON file.
"""
with open(file_path, "w", encoding='utf-8') as f:
json.dump(td.dict(by_alias=True,exclude_none=True), f, indent=indent, ensure_ascii=False)
json.dump(td.to_dict(), f, indent=indent, ensure_ascii=False)
# json.dump(remove_empty(td.dict()), f, indent=indent, ensure_ascii=False)
109 changes: 109 additions & 0 deletions pytdml/type/UiT_HCD_California_2017.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
{
"type": "AI_EOTrainingDataset",
"id": "uit_hcd_california_2017",
"name": "UiT HCD California 2017",
"description": "This dataset is composed of two images and a label image.",
"license": "CC BY-SA 4.0",
"version": "1.0",
"amountOfTrainingData": 1,
"createdTime": "2017-01-01",
"providers": [
"LP DAAC",
"ESA"
],
"classes": [
{
"key": "change",
"value": 1
},
{
"key": "unchanged",
"value": 0
}
],
"numberOfClasses": 2,
"bands": [
{
"name": [
{
"code": "red"
}
]
},
{
"name": [
{
"code": "green"
}
]
},
{
"name": [
{
"code": "blue"
}
]
},
{
"name": [
{
"code": "VH"
}
]
},
{
"name": [
{
"code": "VV"
}
]
},
{
"name": [
{
"code": "VV/VH"
}
]
}
],
"imageSize": "2000x3500",
"tasks": [
{
"type": "AI_EOTask",
"id": "uit_hcd_california_2017-task",
"description": "Multi-source images change detection",
"taskType": "http://demo#change_detection"
}
],
"data": [
{
"type": "AI_EOTrainingData",
"id": "0",
"dataTime": [
"2017-01-05",
"2017-02-18"
],
"dataURL": [
"t1_L8.png",
"t2_SAR.png"
],
"dataSources": [
{
"title": "Landsat-8"
}
],
"numberOfLabels": 1,
"labels": [
{
"type": "AI_PixelLabel",
"imageURL": [
"change_label.png"
],
"imageFormat": [
"image/png"
]
}
]
}
]
}
55 changes: 55 additions & 0 deletions pytdml/type/UiT_HCD_California_2017.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
---
type: AI_EOTrainingDataset
id: uit_hcd_california_2017
name: UiT HCD California 2017
description: This dataset is composed of two images and a label image.
license: CC BY-SA 4.0
version: '1.0'
amountOfTrainingData: 1
createdTime: '2017-01-01'
providers:
- LP DAAC
- ESA
classes:
- key: change
value: 1
- key: unchanged
value: 0
numberOfClasses: 2
bands:
- name:
- code: red
- name:
- code: green
- name:
- code: blue
- name:
- code: VH
- name:
- code: VV
- name:
- code: VV/VH
imageSize: 2000x3500
tasks:
- type: AI_EOTask
id: uit_hcd_california_2017-task
description: Multi-source images change detection
taskType: http://demo#change_detection
data:
- type: AI_EOTrainingData
id: '0'
dataTime:
- '2017-01-05'
- '2017-02-18'
dataURL:
- t1_L8.png
- t2_SAR.png
dataSources:
- title: Landsat-8
numberOfLabels: 1
labels:
- type: AI_PixelLabel
imageURL:
- change_label.png
imageFormat:
- image/png
43 changes: 13 additions & 30 deletions pytdml/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,16 @@
# SOFTWARE.
#
# ------------------------------------------------------------------------------
from .basic_types import BaseCamelModel
from .basic_types import KeyValuePair
from .basic_types import MD_ScopeDescription
from .basic_types import MD_Band
from .basic_types import MD_Scope
from .basic_types import CI_Date
from .basic_types import MD_BrowseGraphic
from .basic_types import CI_Citation
from .basic_types import MD_Identifier
from .basic_types import MetricsPair
from .basic_types import MetricsInLiterature
from .basic_types import Task
from .basic_types import Labeler
from .basic_types import LabelingProcedure
from .basic_types import Labeling
from .basic_types import QualityElement
from .basic_types import DataQuality
from .basic_types import Label
from .basic_types import TrainingData
from .basic_types import Changeset
from .basic_types import StatisticsInfo
from .basic_types import TrainingDataset


from .extended_types import EOTrainingDataset
from .extended_types import EOTrainingData
from .extended_types import SceneLabel
from .extended_types import ObjectLabel
from .extended_types import PixelLabel
from .extended_types import EOTask
from .all_types import BaseCamelModel
from .all_types import KeyValuePair
from .all_types import MD_ScopeDescription
from .all_types import MD_Band
from .all_types import MD_Scope
from .all_types import CI_Date
from .all_types import MD_BrowseGraphic
from .all_types import CI_Citation
from .all_types import MD_Identifier
from .all_types import QualityElement
from .all_types import DataQuality
from .all_types import TrainingDataset
from .all_types import EOTrainingDataset
40 changes: 33 additions & 7 deletions pytdml/type/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@
import re





class InvalidDatetimeError(ValueError):
def __init__(self, message):
self.message = message
Expand Down Expand Up @@ -169,17 +166,34 @@ def _validate_image_format(image_format: str):
"application/x-netcdf",
"application/geopackage+sqlite3"
}
return image_format in image_format_list
if image_format in image_format_list:
return image_format
else:
pass


def _valid_methods(labeling_methods: str):
labeling_methods_list = ["manual", "automatic", "semi-automatic", "unknown"]
return labeling_methods in labeling_methods_list
if labeling_methods in labeling_methods_list:
return labeling_methods
else:
pass


def _validate_training_type(training_type: str):
training_type_list = ["training", "validation", "test", "retraining"]
return training_type in training_type_list
if training_type in training_type_list:
return training_type
else:
pass


def _validate_evaluation_method_type(evaluation_method_type: str):
evaluation_method_type_list = ["directInternal", "directExternal", "indirect"]
if evaluation_method_type in evaluation_method_type_list:
return evaluation_method_type
else:
pass


def to_camel(string: str) -> str:
Expand All @@ -191,4 +205,16 @@ def to_camel(string: str) -> str:
Returns:
str: camelCase string
"""
return re.sub(r"_(\w)", lambda match: match.group(1).upper(), string)
return re.sub(r"_(\w)", lambda match: match.group(1).upper(), string)


def to_interior_class(data_dict, name, class_name):
new_dic = data_dict[name]
new_dic = class_name.from_dict(new_dic)
data_dict[name] = new_dic


def list_to_interior_class(data_dict, name, class_name):
new_dic = data_dict[name]
new_dic = [class_name.from_dict(i) for i in new_dic]
data_dict[name] = new_dic
Loading

0 comments on commit 582c8a9

Please sign in to comment.