Skip to content

Commit

Permalink
Dynamic cast constraint (#5082)
Browse files Browse the repository at this point in the history
* dynamic cast constraint
* cleanup HistogramWordStringKernel
  • Loading branch information
gf712 authored Jun 29, 2020
1 parent a075792 commit f143eaf
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 117 deletions.
7 changes: 4 additions & 3 deletions src/shogun/base/AnyParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <list>
#include <memory>
#include <string_view>
#include <optional>

namespace shogun
{
Expand Down Expand Up @@ -168,7 +169,7 @@ namespace shogun
}
AnyParameter(
Any&& value, const AnyParameterProperties& properties,
std::function<std::string(Any)> constrain_function)
std::function<std::optional<std::string>(Any)> constrain_function)
: m_value(std::move(value)), m_properties(properties),
m_constrain_function(std::move(constrain_function))
{
Expand Down Expand Up @@ -218,7 +219,7 @@ namespace shogun
return m_init_function;
}

const std::function<std::string(Any)>& get_constrain_function() const
const std::function<std::optional<std::string>(Any)>& get_constrain_function() const
noexcept
{
return m_constrain_function;
Expand Down Expand Up @@ -272,7 +273,7 @@ namespace shogun
Any m_value;
AnyParameterProperties m_properties;
std::shared_ptr<params::AutoInit> m_init_function;
std::function<std::string(Any)> m_constrain_function;
std::function<std::optional<std::string>(Any)> m_constrain_function;
std::vector<std::function<void()>> m_callback_functions;
};
} // namespace shogun
Expand Down
10 changes: 4 additions & 6 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -921,10 +921,8 @@ class SGObject: public std::enable_shared_from_this<SGObject>
BaseTag(name), AnyParameter(
make_any_ref(value), properties,
[constrain_function](const auto& val) {
std::string result;
auto casted_val = any_cast<T1>(val);
constrain_function.run(casted_val, result);
return result;
return constrain_function.check(casted_val);
}));
register_parameter_visitor<T1>();
}
Expand Down Expand Up @@ -1131,12 +1129,12 @@ class SGObject: public std::enable_shared_from_this<SGObject>

if (pprop.has_property(ParameterProperties::CONSTRAIN))
{
auto msg = param.get_constrain_function()(make_any(value));
if (!msg.empty())
const auto& val = param.get_constrain_function()(make_any(value));
if (val)
{
require(!do_checks,
"{}::{} cannot be updated because it must be: {}!",
get_name(), _tag.name().c_str(), msg.c_str());
get_name(), _tag.name().c_str(), *val);
}
}
if constexpr (std::is_same_v<T, Any>)
Expand Down
110 changes: 86 additions & 24 deletions src/shogun/base/constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
#ifndef __CONSTRAINT_H__
#define __CONSTRAINT_H__

#include <shogun/io/SGIO.h>
#include <shogun/util/traits.h>

#include <string>
#include <tuple>

namespace shogun
{
class SGObject;

namespace constraint_detail
{
template <typename T, typename... Args, std::size_t... Idx>
Expand Down Expand Up @@ -86,18 +91,80 @@ namespace shogun
template <typename T>
struct generic_checker
{
public:
generic_checker(T val) : m_val(val){};
bool operator()(T val) const
generic_checker() = default;
bool operator()(const T& val) const
{
return check(val);
};

virtual std::string error_msg() const = 0;

protected:
virtual bool check(const T& val) const = 0;
};

template <typename T>
struct custom_constraint: generic_checker<T>
{
template <typename Functor>
custom_constraint(Functor&& func): m_func(func)
{
}

std::string error_msg() const override
{
return msg;
}

protected:
bool check(const T& val) const override
{
try
{
m_func(val);
}
catch (const std::exception& e)
{
msg = std::string(e.what());
return false;
}

return true;
}

private:
std::string msg;
std::function<void(const T&)> m_func;
};


template <typename DerivedType>
struct castable: generic_checker<std::shared_ptr<SGObject>>
{
castable(): generic_checker<std::shared_ptr<SGObject>>()
{
}

std::string error_msg() const override
{
return "of type " + demangled_type<DerivedType>();
}

protected:
bool check(const std::shared_ptr<SGObject>& ptr) const override
{
return static_cast<bool>(std::dynamic_pointer_cast<DerivedType>(ptr));
}
};

template <typename T>
struct comparisson_checker: generic_checker<T>
{
comparisson_checker(T val): generic_checker<T>() {
m_val = val;
}
protected:
T m_val;
virtual bool check(T val) const = 0;
};

/**
Expand All @@ -106,18 +173,17 @@ namespace shogun
* @tparam T the type of val
*/
template <typename T>
struct less_than : generic_checker<T>
struct less_than : comparisson_checker<T>
{
public:
less_than(T val) : generic_checker<T>(val){};
less_than(T val) : comparisson_checker<T>(val){};

std::string error_msg() const override
{
return "less than " + std::to_string(this->m_val);
}

protected:
bool check(T val) const override
bool check(const T& val) const override
{
return val < this->m_val;
}
Expand All @@ -129,18 +195,17 @@ namespace shogun
* @tparam T the type of val
*/
template <typename T>
struct less_than_or_equal : generic_checker<T>
struct less_than_or_equal : comparisson_checker<T>
{
public:
less_than_or_equal(T val) : generic_checker<T>(val){};
less_than_or_equal(T val) : comparisson_checker<T>(val){};

std::string error_msg() const override
{
return "less than " + std::to_string(this->m_val);
}

protected:
bool check(T val) const override
bool check(const T& val) const override
{
return val <= this->m_val;
}
Expand All @@ -152,17 +217,16 @@ namespace shogun
* @tparam T the type of val
*/
template <typename T>
struct greater_than : generic_checker<T>
struct greater_than : comparisson_checker<T>
{
public:
greater_than(T val) : generic_checker<T>(val){};
greater_than(T val) : comparisson_checker<T>(val){};
std::string error_msg() const override
{
return "greater than " + std::to_string(this->m_val);
}

protected:
bool check(T val) const override
bool check(const T& val) const override
{
return val > this->m_val;
}
Expand All @@ -174,18 +238,17 @@ namespace shogun
* @tparam T the type of val
*/
template <typename T>
struct greater_than_or_equal : generic_checker<T>
struct greater_than_or_equal : comparisson_checker<T>
{
public:
greater_than_or_equal(T val) : generic_checker<T>(val){};
greater_than_or_equal(T val) : comparisson_checker<T>(val){};

std::string error_msg() const override
{
return "less than " + std::to_string(this->m_val);
}

protected:
bool check(T val) const override
bool check(const T& val) const override
{
return val >= this->m_val;
}
Expand Down Expand Up @@ -237,14 +300,13 @@ namespace shogun
}

template <typename T>
bool run(T val, std::string& buffer) const
std::optional<std::string> check(const T& val) const
{
if (!constraint_detail::apply(val, m_funcs))
{
buffer = constraint_detail::get_error(m_funcs);
return false;
return constraint_detail::get_error(m_funcs);
}
return true;
return std::nullopt;
}

private:
Expand Down
Loading

0 comments on commit f143eaf

Please sign in to comment.