From 7f39d5ae8a3e265c62869dc7f3511514f2e6eb27 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Tue, 9 Jul 2024 14:21:56 +0530 Subject: [PATCH 01/13] Added code for https://eprint.iacr.org/2024/1077 --- Compiler/decision_tree_new.py | 480 +++++++++++++++++++++++++++++ Programs/Source/custom_data_dt.mpc | 29 ++ 2 files changed, 509 insertions(+) create mode 100644 Compiler/decision_tree_new.py create mode 100644 Programs/Source/custom_data_dt.mpc diff --git a/Compiler/decision_tree_new.py b/Compiler/decision_tree_new.py new file mode 100644 index 000000000..56e350d66 --- /dev/null +++ b/Compiler/decision_tree_new.py @@ -0,0 +1,480 @@ +from Compiler.types import * +from Compiler.sorting import * +from Compiler.library import * +from Compiler import util, oram + +from itertools import accumulate +import math + +debug = False +debug_split = False +max_leaves = None + +def get_type(x): + if isinstance(x, (Array, SubMultiArray)): + return x.value_type + elif isinstance(x, (tuple, list)): + x = x[0] + x[-1] + if util.is_constant(x): + return cint + else: + return type(x) + else: + return type(x) + +def GetSortPerm(keys, *to_sort, n_bits=None, time=False): + """ + Compute and return secret shared permutation that stably sorts :param keys. + """ + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from(sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + get_vec = lambda x: x[:] if isinstance(x, Array) else x + res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x + for x in to_sort) + res = res.transpose() + return radix_sort_permutation_from_matrix(bs, res) + +def PrefixSum(x): + return x.get_vector().prefix_sum() + +def PrefixSumR(x): + tmp = get_type(x).Array(len(x)) + tmp.assign_vector(x) + break_point() + tmp[:] = tmp.get_reverse_vector().prefix_sum() + break_point() + return tmp.get_reverse_vector() + +def PrefixSum_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x, base=1) + tmp[0] = 0 + return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) + +def PrefixSumR_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x) + tmp[-1] = 0 + return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) + +def ApplyPermutation(perm, x): + res = Array.create_from(x) + reveal_sort(perm, res, False) + return res + +def ApplyInversePermutation(perm, x): + res = Array.create_from(x) + reveal_sort(perm, res, True) + return res + +class SortPerm: + def __init__(self, x): + B = sint.Matrix(len(x), 2) + B.set_column(0, 1 - x.get_vector()) + B.set_column(1, x.get_vector()) + self.perm = Array.create_from(dest_comp(B)) + def apply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, False) + return res + def unapply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, True) + return res + +def Sort(keys, *to_sort, n_bits=None, time=False): + if time: + start_timer(1) + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from( + sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + get_vec = lambda x: x[:] if isinstance(x, Array) else x + res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x + for x in to_sort) + res = res.transpose() + if time: + start_timer(11) + radix_sort_from_matrix(bs, res) + if time: + stop_timer(11) + stop_timer(1) + res = res.transpose() + return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f) + if isinstance(get_vec(y), sfix) + else x for (x, y) in zip(res, to_sort)] + +def VectMax(key, *data, debug=False): + def reducer(x, y): + b = x[0]*y[1] > y[0]*x[1] + return [b.if_else(xx, yy) for xx, yy in zip(x, y)] + res = util.tree_reduce(reducer, zip(key, *data)) + return res + +def GroupSum(g, x): + assert len(g) == len(x) + p = PrefixSumR(x) * g + pi = SortPerm(g.get_vector().bit_not()) + p1 = pi.apply(p) + s1 = PrefixSumR_inv(p1) + d1 = PrefixSum_inv(s1) + d = pi.unapply(d1) * g + return PrefixSum(d) + +def GroupPrefixSum(g, x): + assert len(g) == len(x) + s = get_type(x).Array(len(x) + 1) + s[0] = 0 + s.assign_vector(PrefixSum(x), base=1) + q = get_type(s).Array(len(x)) + q.assign_vector(s.get_vector(size=len(x)) * g) + return s.get_vector(size=len(x), base=1) - GroupSum(g, q) + +def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2): + b = (x_num*y_den) > (x_den*y_num) + b = Array.create_from(b).get_vector() + return b + +def GroupMax(g, keys, *x, debug=False): + assert len(keys) == len(g) + for xx in x: + assert len(xx) == len(g) + n = len(g) + m = int(math.ceil(math.log(n, 2))) + keys = Array.create_from(keys) + x = [Array.create_from(xx) for xx in x] + g_new = Array.create_from(g) + g_old = g_new.same_shape() + for d in range(m): + w = 2 ** d + g_old[:] = g_new[:] + break_point() + vsize = n - w + g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( + g_old.get_vector(size=vsize, base=w)), base=w) + b = Custom_GT_Fractions(keys.get_vector(size=vsize), x[0].get_vector(size=vsize), keys.get_vector(size=vsize, base=w), x[0].get_vector(size=vsize, base=w)) + for xx in [keys] + x: + a = b.if_else(xx.get_vector(size=vsize), + xx.get_vector(size=vsize, base=w)) + xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( + xx.get_vector(size=vsize, base=w), a), base=w) + break_point() + t = sint.Array(len(g)) + t[-1] = 1 + t.assign_vector(g.get_vector(size=n - 1, base=1)) + return [GroupSum(g, t[:] * xx) for xx in [keys] + x] + +def ComputeGini(g, x, y, notysum, ysum, debug=False): + assert len(g) == len(y) + y = [y.get_vector().bit_not(), y] + u = [GroupPrefixSum(g, yy) for yy in y] + s = [notysum, ysum] + w = [ss - uu for ss, uu in zip(s, u)] + us = sum(u) + ws = sum(w) + uqs = u[0] ** 2 + u[1] ** 2 + wqs = w[0] ** 2 + w[1] ** 2 + res_num = ws * uqs + us * wqs + res_den = us * ws + xx = x + t = get_type(x).Array(len(x)) + t[-1] = MIN_VALUE + t.assign_vector(xx.get_vector(size=len(x) - 1) + \ + xx.get_vector(size=len(x) - 1, base=1)) + gg = g + p = sint.Array(len(x)) + p[-1] = 1 + p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( + xx.get_vector(size=len(x) - 1) == \ + xx.get_vector(size=len(x) - 1, base=1))) + break_point() + res_num = p[:].if_else(MIN_VALUE, res_num) + res_den = p[:].if_else(1, res_den) + t = p[:].if_else(MIN_VALUE, t[:]) + return res_num, res_den, t + +MIN_VALUE = -10000 + +def FormatLayer(h, g, *a, debug=False): + return CropLayer(h, *FormatLayer_without_crop(g, *a, debug=debug)) + +def FormatLayer_without_crop(g, *a, debug=False): + for x in a: + assert len(x) == len(g) + v = [g.if_else(aa, 0) for aa in a] + p = SortPerm(g.get_vector().bit_not()) + v = [p.apply(vv) for vv in v] + return v + +def CropLayer(k, *v): + if max_leaves: + n = min(2 ** k, max_leaves) + else: + n = 2 ** k + return [vv[:min(n, len(vv))] for vv in v] + +def TrainLeafNodes(h, g, y, NID, Label, debug=False): + assert len(g) == len(y) + assert len(g) == len(NID) + return FormatLayer(h, g, NID, Label, debug=debug) + +def GroupFirstOne(g, b): + assert len(g) == len(b) + s = GroupPrefixSum(g, b) + return s * b == 1 + +class TreeTrainer: + def GetInversePermutation(self, perm, n_threads=2): + res = Array.create_from(self.identity_permutation) + reveal_sort(perm, res) + return res + + def ApplyTests(self, x, AID, Threshold): + m = len(x) + n = len(AID) + assert len(AID) == len(Threshold) + for xx in x: + assert len(xx) == len(AID) + e = sint.Matrix(m, n) + + @for_range_multithread(self.n_threads, 1, m) + def _(j): + e[j][:] = AID[:] == j + xx = sum(x[j]*e[j] for j in range(m)) + + return 2 * xx.get_vector() < Threshold.get_vector() + + def TestSelection(self, g, x, y, pis, notysum, ysum, time=False): + for xx in x: + assert(len(xx) == len(g)) + assert len(g) == len(y) + m = len(x) + n = len(y) + gg = g + u, t = [get_type(x).Matrix(m, n) for i in range(2)] + v = get_type(y).Matrix(m, n) + s_num = get_type(y).Matrix(m, n) + s_den = get_type(y).Matrix(m, n) + a = sint.Array(n) + + notysum_arr = Array.create_from(notysum) + ysum_arr = Array.create_from(ysum) + + @for_range_multithread(self.n_threads, 1, m) + def _(j): + single = not self.n_threads or self.n_threads == 1 + time = self.time and single + if self.debug_selection: + print_ln('run %s', j) + u[j].assign_vector(x[j]) + v[j].assign_vector(y) + reveal_sort(pis[j], u[j]) + reveal_sort(pis[j], v[j]) + s_num[j][:], s_den[j][:], t[j][:] = ComputeGini(g, u[j], v[j], notysum_arr, ysum_arr, debug=False) + + ss_num, ss_den, tt, aa = VectMax((s_num[j][:] for j in range(m)), (s_den[j][:] for j in range(m)), (t[j][:] for j in range(m)), range(m), debug=self.debug) + + aaa = get_type(y).Array(n) + ttt = get_type(x).Array(n) + + GroupMax_num, GroupMax_den, GroupMax_ttt, GroupMax_aaa = GroupMax(g, ss_num, ss_den, tt, aa) + + f = sint.Array(n) + f = (self.zeros.get_vector() == notysum).bit_or(self.zeros.get_vector() == ysum) + aaa_vector, ttt_vector = f.if_else(0, GroupMax_aaa), f.if_else(MIN_VALUE, GroupMax_ttt) + + ttt.assign_vector(ttt_vector) + aaa.assign_vector(aaa_vector) + + return aaa, ttt + + def SetupPerm(self, g, x, y): + m = len(x) + n = len(y) + pis = get_type(y).Matrix(m, n) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + @if_e(self.attr_lengths[j]) + def _(): + pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, + n_bits=[1], time=time)) + @else_ + def _(): + pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, + n_bits=[None], + time=time)) + return pis + + def UpdateState(self, g, x, y, pis, NID, b, k): + m = len(x) + n = len(y) + q = SortPerm(b) + + y[:] = q.apply(y) + NID[:] = 2 ** k * b + NID + NID[:] = q.apply(NID) + g[:] = GroupFirstOne(g, b.bit_not()) + GroupFirstOne(g, b) + g[:] = q.apply(g) + + b_arith = sint.Array(n) + b_arith = Array.create_from(b) + + @for_range_multithread(self.n_threads, 1, m) + def _(j): + x[j][:] = q.apply(x[j]) + b_permuted = ApplyPermutation(pis[j], b_arith) + + pis[j] = q.apply(pis[j]) + pis[j] = ApplyInversePermutation(pis[j], SortPerm(b_permuted).perm) + + return [g, x, y, NID, pis] + + @method_block + def train_layer(self, k): + g = self.g + x = self.x + y = self.y + NID = self.NID + pis = self.pis + s0 = GroupSum(g, y.get_vector().bit_not()) + s1 = GroupSum(g, y.get_vector()) + a, t = self.TestSelection(g, x, y, pis, s0, s1) + b = self.ApplyTests(x, a, t) + p = SortPerm(g.get_vector().bit_not()) + self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug) + + @if_e(k < (len(self.nids)-1)) + def _(): + self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k) + @else_ + def _(): + self.label = Array.create_from(s0 < s1) + + def __init__(self, x, y, h, binary=False, attr_lengths=None, + n_threads=None): + """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf + + :param x: Attribute values + :param y: Binary labels + :param h: Height of the decision tree + :param binary: Binary attributes instead of continuous + :param attr_lengths: Attribute description for mixed data + (list of 0/1 for continuous/binary) + :param n_threads: Number of threads + + """ + assert not (binary and attr_lengths) + if binary: + attr_lengths = [1] * len(x) + else: + attr_lengths = attr_lengths or ([0] * len(x)) + for l in attr_lengths: + assert l in (0, 1) + self.attr_lengths = Array.create_from(regint(attr_lengths)) + Array.check_indices = False + Matrix.disable_index_checks() + for xx in x: + assert len(xx) == len(y) + m = len(x) + n = len(y) + self.g = sint.Array(n) + self.g.assign_all(0) + self.g[0] = 1 + self.NID = sint.Array(n) + self.NID.assign_all(1) + self.y = Array.create_from(y) + self.x = Matrix.create_from(x) + self.pis = sint.Matrix(m, n) + self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] + self.thresholds = self.x.value_type.Matrix(h, n) + self.identity_permutation = sint.Array(n) + self.label = sintbit.Array(n) + self.zeros = sint.Array(n) + self.zeros.assign_all(0) + self.n_threads = n_threads + self.debug_selection = False + self.debug_threading = True + self.debug_gini = False + self.debug_init = False + self.debug_vectmax = False + self.debug = False + self.time = False + + def train(self): + """ Train and return decision tree. """ + n = len(self.y) + + @for_range(n) + def _(i): + self.identity_permutation[i] = sint(i) + + h = len(self.nids) + + self.pis = self.SetupPerm(self.g, self.x, self.y) + + @for_range(h) + def _(k): + self.train_layer(k) + return self.get_tree(h, self.label) + + def get_tree(self, h, Label): + Layer = [None] * (h + 1) + for k in range(h): + Layer[k] = CropLayer(k, self.nids[k], self.aids[k], + self.thresholds[k]) + Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID, Label, debug=self.debug) + return Layer + +def DecisionTreeTraining(x, y, h, binary=False): + return TreeTrainer(x, y, h, binary=binary).train() + +def output_decision_tree(layers): + """ Print decision tree output by :py:class:`TreeTrainer`. """ + + print_ln('full model %s', util.reveal(layers)) + for i, layer in enumerate(layers[:-1]): + print_ln('level %s:', i) + for j, x in enumerate(('NID', 'AID', 'Thr')): + print_ln(' %s: %s', x, util.reveal(layer[j])) + print_ln('leaves:') + for j, x in enumerate(('NID', 'result')): + print_ln(' %s: %s', x, util.reveal(layers[-1][j])) + +class TreeClassifier: + """ Tree classification that uses + :py:class:`TreeTrainer` internally. + + :param max_depth: Depth of decision tree + :param n_threads: Number of threads used + + """ + def __init__(self, max_depth, n_threads=None): + self.max_depth = max_depth + self.n_threads = n_threads + + @staticmethod + def get_attr_lengths(attr_types): + if attr_types == None: + return None + else: + return [1 if x == 'b' else 0 for x in attr_types] + + def fit(self, X, y, attr_types=None): + """ Train tree. + + :param X: Attribute values + :param y: Binary labels + + """ + self.tree = TreeTrainer( + X.transpose(), y, self.max_depth, + attr_lengths=self.get_attr_lengths(attr_types), + n_threads=self.n_threads).train() + + def output(self): + output_decision_tree(self.tree) diff --git a/Programs/Source/custom_data_dt.mpc b/Programs/Source/custom_data_dt.mpc new file mode 100644 index 000000000..2676ac150 --- /dev/null +++ b/Programs/Source/custom_data_dt.mpc @@ -0,0 +1,29 @@ +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split +import pandas as pd +import random +import numpy as np + +m = int(program.args[2]) +n = int(program.args[1]) + +data_x = np.random.uniform(0, 10, (n, m)) +data_y = np.random.randint(2, size=(1, n)) +df_x = pd.DataFrame(data_x) +df_y = pd.DataFrame(data_y) + +df_x = sfix.input_tensor_via(0, df_x) +df_y = sint.input_tensor_via(0, df_y) +df_y = Array.create_from(df_y[0]) + +program.set_bit_length(32) +sfix.set_precision(16, 31) + +from Compiler.decision_tree_new import TreeClassifier + +tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4) + +tree.fit(df_x, df_y) + +# output tree +tree.output() From 15c616cfc319b49249418a5a58037a8b71efc9d4 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Tue, 9 Jul 2024 14:34:55 +0530 Subject: [PATCH 02/13] Added description to run protocol --- Compiler/decision_tree_new.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Compiler/decision_tree_new.py b/Compiler/decision_tree_new.py index 56e350d66..bf5eeb01c 100644 --- a/Compiler/decision_tree_new.py +++ b/Compiler/decision_tree_new.py @@ -359,6 +359,10 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf + This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf + + To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128 + :param x: Attribute values :param y: Binary labels :param h: Height of the decision tree From 92f84196e116dd50a18e7f787b62693d64d79fe1 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Thu, 11 Jul 2024 11:49:59 +0530 Subject: [PATCH 03/13] Fixed compilation errors --- Compiler/decision_tree_new.py | 20 ++++++++++---------- Compiler/sorting.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/Compiler/decision_tree_new.py b/Compiler/decision_tree_new.py index bf5eeb01c..f0376b688 100644 --- a/Compiler/decision_tree_new.py +++ b/Compiler/decision_tree_new.py @@ -357,21 +357,21 @@ def _(): def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): - """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf + """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf - This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf + This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf - To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128 + To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128 - :param x: Attribute values - :param y: Binary labels - :param h: Height of the decision tree - :param binary: Binary attributes instead of continuous - :param attr_lengths: Attribute description for mixed data + :param x: Attribute values + :param y: Binary labels + :param h: Height of the decision tree + :param binary: Binary attributes instead of continuous + :param attr_lengths: Attribute description for mixed data (list of 0/1 for continuous/binary) - :param n_threads: Number of threads + :param n_threads: Number of threads - """ + """ assert not (binary and attr_lengths) if binary: attr_lengths = [1] * len(x) diff --git a/Compiler/sorting.py b/Compiler/sorting.py index 7779c7489..5a37f80eb 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -73,3 +73,24 @@ def _(): @library.else_ def _(): reveal_sort(h, D, reverse=True) + +def radix_sort_permutation_from_matrix(bs, D): + n = len(D) + for b in bs: + assert(len(b) == n) + B = types.sint.Matrix(n, 2) + h = types.Array.create_from(types.sint(types.regint.inc(n))) + @library.for_range(len(bs)) + def _(i): + b = bs[i] + B.set_column(0, 1 - b.get_vector()) + B.set_column(1, b.get_vector()) + c = types.Array.create_from(dest_comp(B)) + reveal_sort(c, h, reverse=False) + @library.if_e(i < len(bs) - 1) + def _(): + reveal_sort(h, bs[i + 1], reverse=True) + @library.else_ + def _(): + reveal_sort(h, D, reverse=True) + return h From 6d0469128c57dbad6d6c1cab4b0fc1a4e28db603 Mon Sep 17 00:00:00 2001 From: Sandhya Date: Thu, 11 Jul 2024 06:35:18 +0000 Subject: [PATCH 04/13] Changed decision_tree_new to decision_tree_optimized --- ...tree_new.py => decision_tree_optimized.py} | 121 ++++++++++++++++++ Programs/Source/breast_tree.mpc | 2 +- Programs/Source/custom_data_dt.mpc | 2 +- 3 files changed, 123 insertions(+), 2 deletions(-) rename Compiler/{decision_tree_new.py => decision_tree_optimized.py} (78%) diff --git a/Compiler/decision_tree_new.py b/Compiler/decision_tree_optimized.py similarity index 78% rename from Compiler/decision_tree_new.py rename to Compiler/decision_tree_optimized.py index f0376b688..bc001d46f 100644 --- a/Compiler/decision_tree_new.py +++ b/Compiler/decision_tree_optimized.py @@ -425,6 +425,28 @@ def _(i): def _(k): self.train_layer(k) return self.get_tree(h, self.label) + + def train_with_testing(self, *test_set, output=False): + """ Train decision tree and test against test data. + + :param y: binary labels (list or sint vector) + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :param output: output tree after every level + :returns: tree + + """ + for k in range(len(self.nids)): + self.train_layer(k) + tree = self.get_tree(k + 1, self.label) + if output: + output_decision_tree(tree) + test_decision_tree('train', tree, self.y, self.x, + n_threads=self.n_threads) + if test_set: + test_decision_tree('test', tree, *test_set, + n_threads=self.n_threads) + return tree def get_tree(self, h, Label): Layer = [None] * (h + 1) @@ -449,6 +471,69 @@ def output_decision_tree(layers): for j, x in enumerate(('NID', 'result')): print_ln(' %s: %s', x, util.reveal(layers[-1][j])) +def pick(bits, x): + if len(bits) == 1: + return bits[0] * x[0] + else: + try: + return x[0].dot_product(bits, x) + except: + return sum(aa * bb for aa, bb in zip(bits, x)) + +def run_decision_tree(layers, data): + """ Run decision tree against sample data. + + :param layers: tree output by :py:class:`TreeTrainer` + :param data: sample data (:py:class:`~Compiler.types.Array`) + :returns: binary label + + """ + h = len(layers) - 1 + index = 1 + for k, layer in enumerate(layers[:-1]): + assert len(layer) == 3 + for x in layer: + assert len(x) <= 2 ** k + bits = layer[0].equal(index, k) + threshold = pick(bits, layer[2]) + key_index = pick(bits, layer[1]) + if key_index.is_clear: + key = data[key_index] + else: + key = pick( + oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) + child = 2 * key < threshold + index += child * 2 ** k + bits = layers[h][0].equal(index, h) + return pick(bits, layers[h][1]) + +def test_decision_tree(name, layers, y, x, n_threads=None, time=False): + if time: + start_timer(100) + n = len(y) + x = x.transpose().reveal() + y = y.reveal() + guess = regint.Array(n) + truth = regint.Array(n) + correct = regint.Array(2) + parts = regint.Array(2) + layers = [[Array.create_from(util.reveal(x)) for x in layer] + for layer in layers] + @for_range_multithread(n_threads, 1, n) + def _(i): + guess[i] = run_decision_tree([[part[:] for part in layer] + for layer in layers], x[i]).reveal() + truth[i] = y[i].reveal() + @for_range(n) + def _(i): + parts[truth[i]] += 1 + c = (guess[i].bit_xor(truth[i]).bit_not()) + correct[truth[i]] += c + print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, + sum(correct), n, correct[0], parts[0], correct[1], parts[1]) + if time: + stop_timer(100) + class TreeClassifier: """ Tree classification that uses :py:class:`TreeTrainer` internally. @@ -482,3 +567,39 @@ def fit(self, X, y, attr_types=None): def output(self): output_decision_tree(self.tree) + + def fit_with_testing(self, X_train, y_train, X_test, y_test, + attr_types=None, output_trees=False, debug=False): + """ Train tree with accuracy output after every level. + + :param X_train: training data with row-wise samples (sint/sfix matrix) + :param y_train: training binary labels (sint list/array) + :param X_test: testing data with row-wise samples (sint/sfix matrix) + :param y_test: testing binary labels (sint list/array) + :param attr_types: attributes types (list of 'b'/'c' for + binary/continuous; default is all continuous) + :param output_trees: output tree after every level + :param debug: output debugging information + + """ + trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth, + attr_lengths=self.get_attr_lengths(attr_types), + n_threads=self.n_threads) + trainer.debug = debug + trainer.debug_gini = debug + trainer.debug_threading = debug > 1 + self.tree = trainer.train_with_testing(y_test, X_test.transpose(), + output=output_trees) + + def predict(self, X): + """ Use tree for prediction. + + :param X: sample data with row-wise samples (sint/sfix matrix) + :returns: sint array + + """ + res = sint.Array(len(X)) + @for_range(len(X)) + def _(i): + res[i] = run_decision_tree(self.tree, X[i]) + return res diff --git a/Programs/Source/breast_tree.mpc b/Programs/Source/breast_tree.mpc index 9214584bb..3401e6a58 100644 --- a/Programs/Source/breast_tree.mpc +++ b/Programs/Source/breast_tree.mpc @@ -16,7 +16,7 @@ y_test = sint.input_tensor_via(0, y_test) sfix.set_precision_from_args(program) -from Compiler.decision_tree import TreeClassifier +from Compiler.decision_tree_optimized import TreeClassifier tree = TreeClassifier(max_depth=5, n_threads=2) diff --git a/Programs/Source/custom_data_dt.mpc b/Programs/Source/custom_data_dt.mpc index 2676ac150..77c9bc3b6 100644 --- a/Programs/Source/custom_data_dt.mpc +++ b/Programs/Source/custom_data_dt.mpc @@ -19,7 +19,7 @@ df_y = Array.create_from(df_y[0]) program.set_bit_length(32) sfix.set_precision(16, 31) -from Compiler.decision_tree_new import TreeClassifier +from Compiler.decision_tree_optimized import TreeClassifier tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4) From cb245d6164b8a014d6304e3efda3f82c71ac33f8 Mon Sep 17 00:00:00 2001 From: Sandhya Date: Tue, 16 Jul 2024 06:11:41 +0000 Subject: [PATCH 05/13] Removed some redundant blocks --- Compiler/decision_tree_optimized.py | 87 +---------------------------- 1 file changed, 1 insertion(+), 86 deletions(-) diff --git a/Compiler/decision_tree_optimized.py b/Compiler/decision_tree_optimized.py index bc001d46f..4e6fdbf2f 100644 --- a/Compiler/decision_tree_optimized.py +++ b/Compiler/decision_tree_optimized.py @@ -1,6 +1,7 @@ from Compiler.types import * from Compiler.sorting import * from Compiler.library import * +from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne from Compiler import util, oram from itertools import accumulate @@ -37,29 +38,6 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False): res = res.transpose() return radix_sort_permutation_from_matrix(bs, res) -def PrefixSum(x): - return x.get_vector().prefix_sum() - -def PrefixSumR(x): - tmp = get_type(x).Array(len(x)) - tmp.assign_vector(x) - break_point() - tmp[:] = tmp.get_reverse_vector().prefix_sum() - break_point() - return tmp.get_reverse_vector() - -def PrefixSum_inv(x): - tmp = get_type(x).Array(len(x) + 1) - tmp.assign_vector(x, base=1) - tmp[0] = 0 - return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) - -def PrefixSumR_inv(x): - tmp = get_type(x).Array(len(x) + 1) - tmp.assign_vector(x) - tmp[-1] = 0 - return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) - def ApplyPermutation(perm, x): res = Array.create_from(x) reveal_sort(perm, res, False) @@ -70,45 +48,6 @@ def ApplyInversePermutation(perm, x): reveal_sort(perm, res, True) return res -class SortPerm: - def __init__(self, x): - B = sint.Matrix(len(x), 2) - B.set_column(0, 1 - x.get_vector()) - B.set_column(1, x.get_vector()) - self.perm = Array.create_from(dest_comp(B)) - def apply(self, x): - res = Array.create_from(x) - reveal_sort(self.perm, res, False) - return res - def unapply(self, x): - res = Array.create_from(x) - reveal_sort(self.perm, res, True) - return res - -def Sort(keys, *to_sort, n_bits=None, time=False): - if time: - start_timer(1) - for k in keys: - assert len(k) == len(keys[0]) - n_bits = n_bits or [None] * len(keys) - bs = Matrix.create_from( - sum([k.get_vector().bit_decompose(nb) - for k, nb in reversed(list(zip(keys, n_bits)))], [])) - get_vec = lambda x: x[:] if isinstance(x, Array) else x - res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x - for x in to_sort) - res = res.transpose() - if time: - start_timer(11) - radix_sort_from_matrix(bs, res) - if time: - stop_timer(11) - stop_timer(1) - res = res.transpose() - return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f) - if isinstance(get_vec(y), sfix) - else x for (x, y) in zip(res, to_sort)] - def VectMax(key, *data, debug=False): def reducer(x, y): b = x[0]*y[1] > y[0]*x[1] @@ -116,25 +55,6 @@ def reducer(x, y): res = util.tree_reduce(reducer, zip(key, *data)) return res -def GroupSum(g, x): - assert len(g) == len(x) - p = PrefixSumR(x) * g - pi = SortPerm(g.get_vector().bit_not()) - p1 = pi.apply(p) - s1 = PrefixSumR_inv(p1) - d1 = PrefixSum_inv(s1) - d = pi.unapply(d1) * g - return PrefixSum(d) - -def GroupPrefixSum(g, x): - assert len(g) == len(x) - s = get_type(x).Array(len(x) + 1) - s[0] = 0 - s.assign_vector(PrefixSum(x), base=1) - q = get_type(s).Array(len(x)) - q.assign_vector(s.get_vector(size=len(x)) * g) - return s.get_vector(size=len(x), base=1) - GroupSum(g, q) - def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2): b = (x_num*y_den) > (x_den*y_num) b = Array.create_from(b).get_vector() @@ -223,11 +143,6 @@ def TrainLeafNodes(h, g, y, NID, Label, debug=False): assert len(g) == len(NID) return FormatLayer(h, g, NID, Label, debug=debug) -def GroupFirstOne(g, b): - assert len(g) == len(b) - s = GroupPrefixSum(g, b) - return s * b == 1 - class TreeTrainer: def GetInversePermutation(self, perm, n_threads=2): res = Array.create_from(self.identity_permutation) From ecdab0fd71bd6a14f2be53acb6033e525dfc7f0a Mon Sep 17 00:00:00 2001 From: Sandhya Date: Tue, 16 Jul 2024 06:44:16 +0000 Subject: [PATCH 06/13] Imported existing methods from decision_tree --- Compiler/decision_tree_optimized.py | 91 +---------------------------- Compiler/sorting.py | 20 ------- 2 files changed, 2 insertions(+), 109 deletions(-) diff --git a/Compiler/decision_tree_optimized.py b/Compiler/decision_tree_optimized.py index 4e6fdbf2f..4ac9e9f39 100644 --- a/Compiler/decision_tree_optimized.py +++ b/Compiler/decision_tree_optimized.py @@ -1,7 +1,7 @@ from Compiler.types import * from Compiler.sorting import * from Compiler.library import * -from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne +from Compiler.decision_tree import get_type, PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne, output_decision_tree, pick, run_decision_tree, test_decision_tree from Compiler import util, oram from itertools import accumulate @@ -11,18 +11,6 @@ debug_split = False max_leaves = None -def get_type(x): - if isinstance(x, (Array, SubMultiArray)): - return x.value_type - elif isinstance(x, (tuple, list)): - x = x[0] + x[-1] - if util.is_constant(x): - return cint - else: - return type(x) - else: - return type(x) - def GetSortPerm(keys, *to_sort, n_bits=None, time=False): """ Compute and return secret shared permutation that stably sorts :param keys. @@ -36,7 +24,7 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False): res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x for x in to_sort) res = res.transpose() - return radix_sort_permutation_from_matrix(bs, res) + return radix_sort_from_matrix(bs, res) def ApplyPermutation(perm, x): res = Array.create_from(x) @@ -374,81 +362,6 @@ def get_tree(self, h, Label): def DecisionTreeTraining(x, y, h, binary=False): return TreeTrainer(x, y, h, binary=binary).train() -def output_decision_tree(layers): - """ Print decision tree output by :py:class:`TreeTrainer`. """ - - print_ln('full model %s', util.reveal(layers)) - for i, layer in enumerate(layers[:-1]): - print_ln('level %s:', i) - for j, x in enumerate(('NID', 'AID', 'Thr')): - print_ln(' %s: %s', x, util.reveal(layer[j])) - print_ln('leaves:') - for j, x in enumerate(('NID', 'result')): - print_ln(' %s: %s', x, util.reveal(layers[-1][j])) - -def pick(bits, x): - if len(bits) == 1: - return bits[0] * x[0] - else: - try: - return x[0].dot_product(bits, x) - except: - return sum(aa * bb for aa, bb in zip(bits, x)) - -def run_decision_tree(layers, data): - """ Run decision tree against sample data. - - :param layers: tree output by :py:class:`TreeTrainer` - :param data: sample data (:py:class:`~Compiler.types.Array`) - :returns: binary label - - """ - h = len(layers) - 1 - index = 1 - for k, layer in enumerate(layers[:-1]): - assert len(layer) == 3 - for x in layer: - assert len(x) <= 2 ** k - bits = layer[0].equal(index, k) - threshold = pick(bits, layer[2]) - key_index = pick(bits, layer[1]) - if key_index.is_clear: - key = data[key_index] - else: - key = pick( - oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) - child = 2 * key < threshold - index += child * 2 ** k - bits = layers[h][0].equal(index, h) - return pick(bits, layers[h][1]) - -def test_decision_tree(name, layers, y, x, n_threads=None, time=False): - if time: - start_timer(100) - n = len(y) - x = x.transpose().reveal() - y = y.reveal() - guess = regint.Array(n) - truth = regint.Array(n) - correct = regint.Array(2) - parts = regint.Array(2) - layers = [[Array.create_from(util.reveal(x)) for x in layer] - for layer in layers] - @for_range_multithread(n_threads, 1, n) - def _(i): - guess[i] = run_decision_tree([[part[:] for part in layer] - for layer in layers], x[i]).reveal() - truth[i] = y[i].reveal() - @for_range(n) - def _(i): - parts[truth[i]] += 1 - c = (guess[i].bit_xor(truth[i]).bit_not()) - correct[truth[i]] += c - print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, - sum(correct), n, correct[0], parts[0], correct[1], parts[1]) - if time: - stop_timer(100) - class TreeClassifier: """ Tree classification that uses :py:class:`TreeTrainer` internally. diff --git a/Compiler/sorting.py b/Compiler/sorting.py index 5a37f80eb..cdf9cf1f6 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -73,24 +73,4 @@ def _(): @library.else_ def _(): reveal_sort(h, D, reverse=True) - -def radix_sort_permutation_from_matrix(bs, D): - n = len(D) - for b in bs: - assert(len(b) == n) - B = types.sint.Matrix(n, 2) - h = types.Array.create_from(types.sint(types.regint.inc(n))) - @library.for_range(len(bs)) - def _(i): - b = bs[i] - B.set_column(0, 1 - b.get_vector()) - B.set_column(1, b.get_vector()) - c = types.Array.create_from(dest_comp(B)) - reveal_sort(c, h, reverse=False) - @library.if_e(i < len(bs) - 1) - def _(): - reveal_sort(h, bs[i + 1], reverse=True) - @library.else_ - def _(): - reveal_sort(h, D, reverse=True) return h From b6499cf02eb73ca744ec941e4fd5b7f38dcb1986 Mon Sep 17 00:00:00 2001 From: Sandhya Date: Wed, 31 Jul 2024 05:36:34 +0000 Subject: [PATCH 07/13] Fixed issue with NID --- Compiler/decision_tree_optimized.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Compiler/decision_tree_optimized.py b/Compiler/decision_tree_optimized.py index 4ac9e9f39..b7566906d 100644 --- a/Compiler/decision_tree_optimized.py +++ b/Compiler/decision_tree_optimized.py @@ -250,11 +250,9 @@ def train_layer(self, k): b = self.ApplyTests(x, a, t) p = SortPerm(g.get_vector().bit_not()) self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug) + self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k) - @if_e(k < (len(self.nids)-1)) - def _(): - self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k) - @else_ + @if_(k >= (len(self.nids)-1)) def _(): self.label = Array.create_from(s0 < s1) From 79407d8fc874a0d36e52fcdd15ff1613c0aff8fc Mon Sep 17 00:00:00 2001 From: Sandhya Date: Fri, 2 Aug 2024 10:35:45 +0000 Subject: [PATCH 08/13] Modifying decision_tree.py --- Compiler/decision_tree.py | 47 ++++++++++++++++++++++++++---- Programs/Source/custom_data_dt.mpc | 2 +- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index 5f7ac8716..26ec2e7a2 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -211,7 +211,6 @@ def CropLayer(k, *v): def TrainLeafNodes(h, g, y, NID): assert len(g) == len(y) assert len(g) == len(NID) - Label = GroupSum(g, y.bit_not()) < GroupSum(g, y) return FormatLayer(h, g, NID, Label) def GroupSame(g, y): @@ -352,6 +351,23 @@ def _(): print_ln('tt=%s', util.reveal(tt)) return a[:], tt[:] + def SetupPerm(self, g, x, y): + m = len(x) + n = len(y) + pis = get_type(y).Matrix(m, n) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + @if_e(self.attr_lengths[j]) + def _(): + pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, + n_bits=[1], time=time)) + @else_ + def _(): + pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, + n_bits=[None], + time=time)) + return pis + def TrainInternalNodes(self, k, x, y, g, NID): assert len(g) == len(y) for xx in x: @@ -377,12 +393,21 @@ def train_layer(self, k): y = self.y g = self.g NID = self.NID + pis = self.pis if self.debug > 1: print_ln('g=%s', g.reveal()) print_ln('y=%s', y.reveal()) print_ln('x=%s', x.reveal_nested()) - self.nids[k], self.aids[k], self.thresholds[k], b = \ - self.TrainInternalNodes(k, x, y, g, NID) + + s0 = GroupSum(g, y.get_vector().bit_not()) + s1 = GroupSum(g, y.get_vector()) + + a, t = self.TestSelection(g, x, y, pis, s0, s1) + b = self.ApplyTests(x, a, t) + p = SortPerm(g.get_vector().bit_not()) + + self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug) + if self.debug > 1: print_ln('layer %s:', k) for name, data in zip(('NID', 'AID', 'Thr'), @@ -422,6 +447,8 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None, self.x = Matrix.create_from(x) self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] self.thresholds = self.x.value_type.Matrix(h, n) + self.identity_permutation = sint.Array(n) + self.label = sintbit.Array(n) self.n_threads = n_threads self.debug_selection = False self.debug_threading = False @@ -431,11 +458,19 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None, def train(self): """ Train and return decision tree. """ + n = len(self.y) + + @for_range(n) + def _(i): + self.identity_permutation[i] = sint(i) + h = len(self.nids) + self.pis = self.SetupPerm(self.g, self.x, self.y) + @for_range(h) def _(k): self.train_layer(k) - return self.get_tree(h) + return self.get_tree(h, self.label) def train_with_testing(self, *test_set, output=False): """ Train decision tree and test against test data. @@ -459,12 +494,12 @@ def train_with_testing(self, *test_set, output=False): n_threads=self.n_threads) return tree - def get_tree(self, h): + def get_tree(self, h, Label): Layer = [None] * (h + 1) for k in range(h): Layer[k] = CropLayer(k, self.nids[k], self.aids[k], self.thresholds[k]) - Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) + Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID, Label) return Layer def DecisionTreeTraining(x, y, h, binary=False): diff --git a/Programs/Source/custom_data_dt.mpc b/Programs/Source/custom_data_dt.mpc index 77c9bc3b6..ba7915c65 100644 --- a/Programs/Source/custom_data_dt.mpc +++ b/Programs/Source/custom_data_dt.mpc @@ -19,7 +19,7 @@ df_y = Array.create_from(df_y[0]) program.set_bit_length(32) sfix.set_precision(16, 31) -from Compiler.decision_tree_optimized import TreeClassifier +from Compiler.decision_tree import TreeClassifier tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4) From 0a9d5e8aaf0e792ff7caec19853cc0f1e0d56084 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Sat, 3 Aug 2024 21:52:12 +0530 Subject: [PATCH 09/13] Unification of decision_tree and decision_tree_optimized in progress --- Compiler/decision_tree.py | 199 +++++++++++++++++--------------------- 1 file changed, 87 insertions(+), 112 deletions(-) diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index 26ec2e7a2..c5ffb7ab8 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -157,32 +157,35 @@ def GroupMax(g, keys, *x): util.reveal(t), util.reveal(keys), util.reveal(x)) return [GroupSum(g, t[:] * xx) for xx in [keys] + x] -def ModifiedGini(g, y, debug=False): +def ComputeGini(g, x, y, notysum, ysum, debug=False): assert len(g) == len(y) y = [y.get_vector().bit_not(), y] u = [GroupPrefixSum(g, yy) for yy in y] - s = [GroupSum(g, yy) for yy in y] + s = [notysum, ysum] w = [ss - uu for ss, uu in zip(s, u)] us = sum(u) ws = sum(w) uqs = u[0] ** 2 + u[1] ** 2 wqs = w[0] ** 2 + w[1] ** 2 - res = sfix(uqs) / us + sfix(wqs) / ws - if debug: - print_ln('g=%s y=%s s=%s', - util.reveal(g), util.reveal(y), - util.reveal(s)) - print_ln('u0=%s', util.reveal(u[0])) - print_ln('u0=%s', util.reveal(u[1])) - print_ln('us=%s', util.reveal(us)) - print_ln('w0=%s', util.reveal(w[0])) - print_ln('w1=%s', util.reveal(w[1])) - print_ln('ws=%s', util.reveal(ws)) - print_ln('uqs=%s', util.reveal(uqs)) - print_ln('wqs=%s', util.reveal(wqs)) - if debug: - print_ln('gini %s %s', type(res), util.reveal(res)) - return res + res_num = ws * uqs + us * wqs + res_den = us * ws + xx = x + t = get_type(x).Array(len(x)) + t[-1] = MIN_VALUE + t.assign_vector(xx.get_vector(size=len(x) - 1) + \ + xx.get_vector(size=len(x) - 1, base=1)) + gg = g + p = sint.Array(len(x)) + p[-1] = 1 + p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( + xx.get_vector(size=len(x) - 1) == \ + xx.get_vector(size=len(x) - 1, base=1))) + break_point() + res_num = p[:].if_else(MIN_VALUE, res_num) + res_den = p[:].if_else(1, res_den) + t = p[:].if_else(MIN_VALUE, t[:]) + return res_num, res_den, t + MIN_VALUE = -10000 @@ -243,6 +246,11 @@ class TreeTrainer: .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 """ + def GetInversePermutation(self, perm): + res = Array.create_from(self.identity_permutation) + reveal_sort(perm, res) + return res + def ApplyTests(self, x, AID, Threshold): m = len(x) n = len(AID) @@ -260,96 +268,49 @@ def _(j): print_ln('threshold %s', util.reveal(Threshold)) return 2 * xx < Threshold - def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): - assert len(g) == len(x) - assert len(g) == len(y) - if time: - start_timer(2) - s = ModifiedGini(g, y, debug=debug or self.debug > 2) - if time: - stop_timer(2) - if debug or self.debug > 1: - print_ln('gini %s', s.reveal()) - xx = x - t = get_type(x).Array(len(x)) - t[-1] = MIN_VALUE - t.assign_vector(xx.get_vector(size=len(x) - 1) + \ - xx.get_vector(size=len(x) - 1, base=1)) - gg = g - p = sint.Array(len(x)) - p[-1] = 1 - p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( - xx.get_vector(size=len(x) - 1) == \ - xx.get_vector(size=len(x) - 1, base=1))) - break_point() - if debug: - print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) - s = p[:].if_else(MIN_VALUE, s) - t = p[:].if_else(MIN_VALUE, t[:]) - if debug: - print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) - if time: - start_timer(3) - s, t = GroupMax(gg, s, t) - if time: - stop_timer(3) - if debug: - print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) - return t, s - - def GlobalTestSelection(self, x, y, g): - assert len(y) == len(g) + def TestSelection(self, g, x, y, pis, notysum, ysum, time=False): for xx in x: assert(len(xx) == len(g)) + assert len(g) == len(y) m = len(x) n = len(y) + gg = g u, t = [get_type(x).Matrix(m, n) for i in range(2)] v = get_type(y).Matrix(m, n) - s = sfix.Matrix(m, n) + s_num = get_type(y).Matrix(m, n) + s_den = get_type(y).Matrix(m, n) + a = sint.Array(n) + + notysum_arr = Array.create_from(notysum) + ysum_arr = Array.create_from(ysum) + @for_range_multithread(self.n_threads, 1, m) def _(j): single = not self.n_threads or self.n_threads == 1 time = self.time and single - if debug: + if self.debug_selection: print_ln('run %s', j) - @if_e(self.attr_lengths[j]) - def _(): - u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, - n_bits=[util.log2(n), 1], time=time) - @else_ - def _(): - u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, - n_bits=[util.log2(n), None], - time=time) - if self.debug_threading: - print_ln('global sort %s %s %s', j, util.reveal(u[j]), - util.reveal(v[j])) - t[j][:], s[j][:] = self.AttributeWiseTestSelection( - g, u[j], v[j], time=time, debug=self.debug_selection) - if self.debug_threading: - print_ln('global attribute %s %s %s', j, util.reveal(t[j]), - util.reveal(s[j])) - n = len(g) - a = sint.Array(n) - if self.debug_threading: - print_ln('global s=%s', util.reveal(s)) - if self.debug_gini: - print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), - *(ss[0].reveal() for ss in s)) - if self.time: - start_timer(4) - if self.debug > 1: - print_ln('s=%s', s.reveal_nested()) - print_ln('t=%s', t.reveal_nested()) - a[:], tt = VectMax((s[j][:] for j in range(m)), range(m), - (t[j][:] for j in range(m)), debug=self.debug > 1) - tt = Array.create_from(tt) - if self.time: - stop_timer(4) - if self.debug > 1: - print_ln('a=%s', util.reveal(a)) - print_ln('tt=%s', util.reveal(tt)) - return a[:], tt[:] + u[j].assign_vector(x[j]) + v[j].assign_vector(y) + reveal_sort(pis[j], u[j]) + reveal_sort(pis[j], v[j]) + s_num[j][:], s_den[j][:], t[j][:] = ComputeGini(g, u[j], v[j], notysum_arr, ysum_arr, debug=False) + + ss_num, ss_den, tt, aa = VectMax((s_num[j][:] for j in range(m)), (s_den[j][:] for j in range(m)), (t[j][:] for j in range(m)), range(m), debug=self.debug) + + aaa = get_type(y).Array(n) + ttt = get_type(x).Array(n) + + GroupMax_num, GroupMax_den, GroupMax_ttt, GroupMax_aaa = GroupMax(g, ss_num, ss_den, tt, aa) + + f = sint.Array(n) + f = (self.zeros.get_vector() == notysum).bit_or(self.zeros.get_vector() == ysum) + aaa_vector, ttt_vector = f.if_else(0, GroupMax_aaa), f.if_else(MIN_VALUE, GroupMax_ttt) + + ttt.assign_vector(ttt_vector) + aaa.assign_vector(aaa_vector) + + return aaa, ttt def SetupPerm(self, g, x, y): m = len(x) @@ -368,6 +329,30 @@ def _(): time=time)) return pis + def UpdateState(self, g, x, y, pis, NID, b, k): + m = len(x) + n = len(y) + q = SortPerm(b) + + y[:] = q.apply(y) + NID[:] = 2 ** k * b + NID + NID[:] = q.apply(NID) + g[:] = GroupFirstOne(g, b.bit_not()) + GroupFirstOne(g, b) + g[:] = q.apply(g) + + b_arith = sint.Array(n) + b_arith = Array.create_from(b) + + @for_range_multithread(self.n_threads, 1, m) + def _(j): + x[j][:] = q.apply(x[j]) + b_permuted = ApplyPermutation(pis[j], b_arith) + + pis[j] = q.apply(pis[j]) + pis[j] = ApplyInversePermutation(pis[j], SortPerm(b_permuted).perm) + + return [g, x, y, NID, pis] + def TrainInternalNodes(self, k, x, y, g, NID): assert len(g) == len(y) for xx in x: @@ -407,21 +392,11 @@ def train_layer(self, k): p = SortPerm(g.get_vector().bit_not()) self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug) - - if self.debug > 1: - print_ln('layer %s:', k) - for name, data in zip(('NID', 'AID', 'Thr'), - (self.nids[k], self.aids[k], - self.thresholds[k])): - print_ln(' %s: %s', name, data.reveal()) - NID[:] = 2 ** k * b + NID - b_not = b.bit_not() - if self.debug > 1: - print_ln('b_not=%s', b_not.reveal()) - g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b) - y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) - for i, xxx in enumerate(xx): - x[i] = xxx + self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k) + + @if_(k >= (len(self.nids)-1)) + def _(): + self.label = Array.create_from(s0 < s1) def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): From 0a4417d632d3c4ed58ca51f159d3d8565fef589e Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Mon, 19 Aug 2024 12:37:37 +0530 Subject: [PATCH 10/13] Unified decision_tree and decision_tree_optimized, testing due --- Compiler/decision_tree.py | 131 +++++++++++++++----------------------- 1 file changed, 53 insertions(+), 78 deletions(-) diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index c5ffb7ab8..9b1a8b938 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -10,6 +10,31 @@ debug_split = False max_leaves = None +def GetSortPerm(keys, *to_sort, n_bits=None, time=False): + """ + Compute and return secret shared permutation that stably sorts :param keys. + """ + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from(sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + get_vec = lambda x: x[:] if isinstance(x, Array) else x + res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x + for x in to_sort) + res = res.transpose() + return radix_sort_from_matrix(bs, res) + +def ApplyPermutation(perm, x): + res = Array.create_from(x) + reveal_sort(perm, res, False) + return res + +def ApplyInversePermutation(perm, x): + res = Array.create_from(x) + reveal_sort(perm, res, True) + return res + def get_type(x): if isinstance(x, (Array, SubMultiArray)): return x.value_type @@ -22,6 +47,12 @@ def get_type(x): else: return type(x) + +def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2): + b = (x_num*y_den) > (x_den*y_num) + b = Array.create_from(b).get_vector() + return b + def PrefixSum(x): return x.get_vector().prefix_sum() @@ -86,17 +117,9 @@ def Sort(keys, *to_sort, n_bits=None, time=False): def VectMax(key, *data, debug=False): def reducer(x, y): - b = x[0] > y[0] - if debug: - print_ln('max b=%s', b.reveal()) + b = x[0]*y[1] > y[0]*x[1] return [b.if_else(xx, yy) for xx, yy in zip(x, y)] - if debug: - key = list(key) - data = [list(x) for x in data] - print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data)) res = util.tree_reduce(reducer, zip(key, *data))[1:] - if debug: - print_ln('vect max res=%s', util.reveal(res)) return res def GroupSum(g, x): @@ -119,9 +142,6 @@ def GroupPrefixSum(g, x): return s.get_vector(size=len(x), base=1) - GroupSum(g, q) def GroupMax(g, keys, *x): - if debug: - print_ln('group max input g=%s keys=%s x=%s', util.reveal(g), - util.reveal(keys), util.reveal(x)) assert len(keys) == len(g) for xx in x: assert len(xx) == len(g) @@ -138,23 +158,17 @@ def GroupMax(g, keys, *x): vsize = n - w g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( g_old.get_vector(size=vsize, base=w)), base=w) - b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) + b = Custom_GT_Fractions(keys.get_vector(size=vsize), x[0].get_vector(size=vsize), keys.get_vector(size=vsize, base=w), x[0].get_vector(size=vsize, base=w)) + #b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) for xx in [keys] + x: a = b.if_else(xx.get_vector(size=vsize), xx.get_vector(size=vsize, base=w)) xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( xx.get_vector(size=vsize, base=w), a), base=w) break_point() - if debug: - print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(), - util.reveal(a), util.reveal(keys), - util.reveal(x), g_new.reveal()) t = sint.Array(len(g)) t[-1] = 1 t.assign_vector(g.get_vector(size=n - 1, base=1)) - if debug: - print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g), - util.reveal(t), util.reveal(keys), util.reveal(x)) return [GroupSum(g, t[:] * xx) for xx in [keys] + x] def ComputeGini(g, x, y, notysum, ysum, debug=False): @@ -196,12 +210,8 @@ def FormatLayer_without_crop(g, *a, debug=False): for x in a: assert len(x) == len(g) v = [g.if_else(aa, 0) for aa in a] - if debug: - print_ln('format in %s', util.reveal(a)) - print_ln('format mux %s', util.reveal(v)) - v = Sort([g.bit_not()], *v, n_bits=[1]) - if debug: - print_ln('format sort %s', util.reveal(v)) + p = SortPerm(g.get_vector().bit_not()) + v = [p.apply(vv) for vv in v] return v def CropLayer(k, *v): @@ -216,36 +226,12 @@ def TrainLeafNodes(h, g, y, NID): assert len(g) == len(NID) return FormatLayer(h, g, NID, Label) -def GroupSame(g, y): - assert len(g) == len(y) - s = GroupSum(g, [sint(1)] * len(g)) - s0 = GroupSum(g, y.bit_not()) - s1 = GroupSum(g, y) - if debug_split: - print_ln('group same g=%s', util.reveal(g)) - print_ln('group same y=%s', util.reveal(y)) - return (s == s0).bit_or(s == s1) - def GroupFirstOne(g, b): assert len(g) == len(b) s = GroupPrefixSum(g, b) return s * b == 1 class TreeTrainer: - """ Decision tree training by `Hamada et al.`_ - - :param x: sample data (by attribute, list or - :py:obj:`~Compiler.types.Matrix`) - :param y: binary labels (list or sint vector) - :param h: height (int) - :param binary: binary attributes instead of continuous - :param attr_lengths: attribute description for mixed data - (list of 0/1 for continuous/binary) - :param n_threads: number of threads (default: single thread) - - .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 - - """ def GetInversePermutation(self, perm): res = Array.create_from(self.identity_permutation) reveal_sort(perm, res) @@ -258,14 +244,10 @@ def ApplyTests(self, x, AID, Threshold): for xx in x: assert len(xx) == len(AID) e = sint.Matrix(m, n) - AID = Array.create_from(AID) @for_range_multithread(self.n_threads, 1, m) def _(j): e[j][:] = AID[:] == j xx = sum(x[j] * e[j] for j in range(m)) - if self.debug > 1: - print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) - print_ln('threshold %s', util.reveal(Threshold)) return 2 * xx < Threshold def TestSelection(self, g, x, y, pis, notysum, ysum, time=False): @@ -353,25 +335,6 @@ def _(j): return [g, x, y, NID, pis] - def TrainInternalNodes(self, k, x, y, g, NID): - assert len(g) == len(y) - for xx in x: - assert len(xx) == len(g) - AID, Threshold = self.GlobalTestSelection(x, y, g) - s = GroupSame(g[:], y[:]) - if self.debug > 1 or debug_split: - print_ln('AID=%s', util.reveal(AID)) - print_ln('Threshold=%s', util.reveal(Threshold)) - print_ln('GroupSame=%s', util.reveal(s)) - AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) - if self.debug > 1 or debug_split: - print_ln('AID=%s', util.reveal(AID)) - print_ln('Threshold=%s', util.reveal(Threshold)) - b = self.ApplyTests(x, AID, Threshold) - layer = FormatLayer_without_crop(g[:], NID, AID, Threshold, - debug=self.debug > 1) - return *layer, b - @method_block def train_layer(self, k): x = self.x @@ -379,10 +342,6 @@ def train_layer(self, k): g = self.g NID = self.NID pis = self.pis - if self.debug > 1: - print_ln('g=%s', g.reveal()) - print_ln('y=%s', y.reveal()) - print_ln('x=%s', x.reveal_nested()) s0 = GroupSum(g, y.get_vector().bit_not()) s1 = GroupSum(g, y.get_vector()) @@ -400,6 +359,21 @@ def _(): def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): + """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf + + This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf + + To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128 + + :param x: Attribute values + :param y: Binary labels + :param h: Height of the decision tree + :param binary: Binary attributes instead of continuous + :param attr_lengths: Attribute description for mixed data + (list of 0/1 for continuous/binary) + :param n_threads: Number of threads + + """ assert not (binary and attr_lengths) if binary: attr_lengths = [1] * len(x) @@ -412,6 +386,7 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None, Matrix.disable_index_checks() for xx in x: assert len(xx) == len(y) + m = len(x) n = len(y) self.g = sint.Array(n) self.g.assign_all(0) @@ -459,7 +434,7 @@ def train_with_testing(self, *test_set, output=False): """ for k in range(len(self.nids)): self.train_layer(k) - tree = self.get_tree(k + 1) + tree = self.get_tree(k + 1, self.label) if output: output_decision_tree(tree) test_decision_tree('train', tree, self.y, self.x, From a49bba9ab8f36339e27f1eefd8608aa5199e1f6f Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Mon, 23 Dec 2024 12:09:33 +0000 Subject: [PATCH 11/13] Fixed bugs in DT integration --- Compiler/decision_tree.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index 9b1a8b938..1589ca955 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -119,7 +119,7 @@ def VectMax(key, *data, debug=False): def reducer(x, y): b = x[0]*y[1] > y[0]*x[1] return [b.if_else(xx, yy) for xx, yy in zip(x, y)] - res = util.tree_reduce(reducer, zip(key, *data))[1:] + res = util.tree_reduce(reducer, zip(key, *data)) return res def GroupSum(g, x): @@ -221,7 +221,7 @@ def CropLayer(k, *v): n = 2 ** k return [vv[:min(n, len(vv))] for vv in v] -def TrainLeafNodes(h, g, y, NID): +def TrainLeafNodes(h, g, y, NID, Label, debug=False): assert len(g) == len(y) assert len(g) == len(NID) return FormatLayer(h, g, NID, Label) @@ -248,7 +248,7 @@ def ApplyTests(self, x, AID, Threshold): def _(j): e[j][:] = AID[:] == j xx = sum(x[j] * e[j] for j in range(m)) - return 2 * xx < Threshold + return 2 * xx.get_vector() < Threshold.get_vector() def TestSelection(self, g, x, y, pis, notysum, ysum, time=False): for xx in x: @@ -395,10 +395,13 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None, self.NID.assign_all(1) self.y = Array.create_from(y) self.x = Matrix.create_from(x) + self.pis = sint.Matrix(m, n) self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] self.thresholds = self.x.value_type.Matrix(h, n) self.identity_permutation = sint.Array(n) self.label = sintbit.Array(n) + self.zeros = sint.Array(n) + self.zeros.assign_all(0) self.n_threads = n_threads self.debug_selection = False self.debug_threading = False From 774bad610a21e81a701446ec2eedfc35ad9f3e80 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Mon, 23 Dec 2024 12:44:50 +0000 Subject: [PATCH 12/13] Removed extraneous files --- Compiler/decision_tree_optimized.py | 431 ---------------------------- Programs/Source/custom_data_dt.mpc | 29 -- 2 files changed, 460 deletions(-) delete mode 100644 Compiler/decision_tree_optimized.py delete mode 100644 Programs/Source/custom_data_dt.mpc diff --git a/Compiler/decision_tree_optimized.py b/Compiler/decision_tree_optimized.py deleted file mode 100644 index b7566906d..000000000 --- a/Compiler/decision_tree_optimized.py +++ /dev/null @@ -1,431 +0,0 @@ -from Compiler.types import * -from Compiler.sorting import * -from Compiler.library import * -from Compiler.decision_tree import get_type, PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne, output_decision_tree, pick, run_decision_tree, test_decision_tree -from Compiler import util, oram - -from itertools import accumulate -import math - -debug = False -debug_split = False -max_leaves = None - -def GetSortPerm(keys, *to_sort, n_bits=None, time=False): - """ - Compute and return secret shared permutation that stably sorts :param keys. - """ - for k in keys: - assert len(k) == len(keys[0]) - n_bits = n_bits or [None] * len(keys) - bs = Matrix.create_from(sum([k.get_vector().bit_decompose(nb) - for k, nb in reversed(list(zip(keys, n_bits)))], [])) - get_vec = lambda x: x[:] if isinstance(x, Array) else x - res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x - for x in to_sort) - res = res.transpose() - return radix_sort_from_matrix(bs, res) - -def ApplyPermutation(perm, x): - res = Array.create_from(x) - reveal_sort(perm, res, False) - return res - -def ApplyInversePermutation(perm, x): - res = Array.create_from(x) - reveal_sort(perm, res, True) - return res - -def VectMax(key, *data, debug=False): - def reducer(x, y): - b = x[0]*y[1] > y[0]*x[1] - return [b.if_else(xx, yy) for xx, yy in zip(x, y)] - res = util.tree_reduce(reducer, zip(key, *data)) - return res - -def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2): - b = (x_num*y_den) > (x_den*y_num) - b = Array.create_from(b).get_vector() - return b - -def GroupMax(g, keys, *x, debug=False): - assert len(keys) == len(g) - for xx in x: - assert len(xx) == len(g) - n = len(g) - m = int(math.ceil(math.log(n, 2))) - keys = Array.create_from(keys) - x = [Array.create_from(xx) for xx in x] - g_new = Array.create_from(g) - g_old = g_new.same_shape() - for d in range(m): - w = 2 ** d - g_old[:] = g_new[:] - break_point() - vsize = n - w - g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( - g_old.get_vector(size=vsize, base=w)), base=w) - b = Custom_GT_Fractions(keys.get_vector(size=vsize), x[0].get_vector(size=vsize), keys.get_vector(size=vsize, base=w), x[0].get_vector(size=vsize, base=w)) - for xx in [keys] + x: - a = b.if_else(xx.get_vector(size=vsize), - xx.get_vector(size=vsize, base=w)) - xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( - xx.get_vector(size=vsize, base=w), a), base=w) - break_point() - t = sint.Array(len(g)) - t[-1] = 1 - t.assign_vector(g.get_vector(size=n - 1, base=1)) - return [GroupSum(g, t[:] * xx) for xx in [keys] + x] - -def ComputeGini(g, x, y, notysum, ysum, debug=False): - assert len(g) == len(y) - y = [y.get_vector().bit_not(), y] - u = [GroupPrefixSum(g, yy) for yy in y] - s = [notysum, ysum] - w = [ss - uu for ss, uu in zip(s, u)] - us = sum(u) - ws = sum(w) - uqs = u[0] ** 2 + u[1] ** 2 - wqs = w[0] ** 2 + w[1] ** 2 - res_num = ws * uqs + us * wqs - res_den = us * ws - xx = x - t = get_type(x).Array(len(x)) - t[-1] = MIN_VALUE - t.assign_vector(xx.get_vector(size=len(x) - 1) + \ - xx.get_vector(size=len(x) - 1, base=1)) - gg = g - p = sint.Array(len(x)) - p[-1] = 1 - p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( - xx.get_vector(size=len(x) - 1) == \ - xx.get_vector(size=len(x) - 1, base=1))) - break_point() - res_num = p[:].if_else(MIN_VALUE, res_num) - res_den = p[:].if_else(1, res_den) - t = p[:].if_else(MIN_VALUE, t[:]) - return res_num, res_den, t - -MIN_VALUE = -10000 - -def FormatLayer(h, g, *a, debug=False): - return CropLayer(h, *FormatLayer_without_crop(g, *a, debug=debug)) - -def FormatLayer_without_crop(g, *a, debug=False): - for x in a: - assert len(x) == len(g) - v = [g.if_else(aa, 0) for aa in a] - p = SortPerm(g.get_vector().bit_not()) - v = [p.apply(vv) for vv in v] - return v - -def CropLayer(k, *v): - if max_leaves: - n = min(2 ** k, max_leaves) - else: - n = 2 ** k - return [vv[:min(n, len(vv))] for vv in v] - -def TrainLeafNodes(h, g, y, NID, Label, debug=False): - assert len(g) == len(y) - assert len(g) == len(NID) - return FormatLayer(h, g, NID, Label, debug=debug) - -class TreeTrainer: - def GetInversePermutation(self, perm, n_threads=2): - res = Array.create_from(self.identity_permutation) - reveal_sort(perm, res) - return res - - def ApplyTests(self, x, AID, Threshold): - m = len(x) - n = len(AID) - assert len(AID) == len(Threshold) - for xx in x: - assert len(xx) == len(AID) - e = sint.Matrix(m, n) - - @for_range_multithread(self.n_threads, 1, m) - def _(j): - e[j][:] = AID[:] == j - xx = sum(x[j]*e[j] for j in range(m)) - - return 2 * xx.get_vector() < Threshold.get_vector() - - def TestSelection(self, g, x, y, pis, notysum, ysum, time=False): - for xx in x: - assert(len(xx) == len(g)) - assert len(g) == len(y) - m = len(x) - n = len(y) - gg = g - u, t = [get_type(x).Matrix(m, n) for i in range(2)] - v = get_type(y).Matrix(m, n) - s_num = get_type(y).Matrix(m, n) - s_den = get_type(y).Matrix(m, n) - a = sint.Array(n) - - notysum_arr = Array.create_from(notysum) - ysum_arr = Array.create_from(ysum) - - @for_range_multithread(self.n_threads, 1, m) - def _(j): - single = not self.n_threads or self.n_threads == 1 - time = self.time and single - if self.debug_selection: - print_ln('run %s', j) - u[j].assign_vector(x[j]) - v[j].assign_vector(y) - reveal_sort(pis[j], u[j]) - reveal_sort(pis[j], v[j]) - s_num[j][:], s_den[j][:], t[j][:] = ComputeGini(g, u[j], v[j], notysum_arr, ysum_arr, debug=False) - - ss_num, ss_den, tt, aa = VectMax((s_num[j][:] for j in range(m)), (s_den[j][:] for j in range(m)), (t[j][:] for j in range(m)), range(m), debug=self.debug) - - aaa = get_type(y).Array(n) - ttt = get_type(x).Array(n) - - GroupMax_num, GroupMax_den, GroupMax_ttt, GroupMax_aaa = GroupMax(g, ss_num, ss_den, tt, aa) - - f = sint.Array(n) - f = (self.zeros.get_vector() == notysum).bit_or(self.zeros.get_vector() == ysum) - aaa_vector, ttt_vector = f.if_else(0, GroupMax_aaa), f.if_else(MIN_VALUE, GroupMax_ttt) - - ttt.assign_vector(ttt_vector) - aaa.assign_vector(aaa_vector) - - return aaa, ttt - - def SetupPerm(self, g, x, y): - m = len(x) - n = len(y) - pis = get_type(y).Matrix(m, n) - @for_range_multithread(self.n_threads, 1, m) - def _(j): - @if_e(self.attr_lengths[j]) - def _(): - pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, - n_bits=[1], time=time)) - @else_ - def _(): - pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y, - n_bits=[None], - time=time)) - return pis - - def UpdateState(self, g, x, y, pis, NID, b, k): - m = len(x) - n = len(y) - q = SortPerm(b) - - y[:] = q.apply(y) - NID[:] = 2 ** k * b + NID - NID[:] = q.apply(NID) - g[:] = GroupFirstOne(g, b.bit_not()) + GroupFirstOne(g, b) - g[:] = q.apply(g) - - b_arith = sint.Array(n) - b_arith = Array.create_from(b) - - @for_range_multithread(self.n_threads, 1, m) - def _(j): - x[j][:] = q.apply(x[j]) - b_permuted = ApplyPermutation(pis[j], b_arith) - - pis[j] = q.apply(pis[j]) - pis[j] = ApplyInversePermutation(pis[j], SortPerm(b_permuted).perm) - - return [g, x, y, NID, pis] - - @method_block - def train_layer(self, k): - g = self.g - x = self.x - y = self.y - NID = self.NID - pis = self.pis - s0 = GroupSum(g, y.get_vector().bit_not()) - s1 = GroupSum(g, y.get_vector()) - a, t = self.TestSelection(g, x, y, pis, s0, s1) - b = self.ApplyTests(x, a, t) - p = SortPerm(g.get_vector().bit_not()) - self.nids[k], self.aids[k], self.thresholds[k]= FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug) - self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k) - - @if_(k >= (len(self.nids)-1)) - def _(): - self.label = Array.create_from(s0 < s1) - - def __init__(self, x, y, h, binary=False, attr_lengths=None, - n_threads=None): - """ Securely Training Decision Trees Efficiently by `Bhardwaj et al.`_ : https://eprint.iacr.org/2024/1077.pdf - - This protocol has communication complexity O( mN logN + hmN + hN log N) which is an improvement of ~min(h, m, log N) over `Hamada et al.`_ : https://petsymposium.org/popets/2023/popets-2023-0021.pdf - - To run this protocol, at the root of the MP-SPDZ repo, run Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128 - - :param x: Attribute values - :param y: Binary labels - :param h: Height of the decision tree - :param binary: Binary attributes instead of continuous - :param attr_lengths: Attribute description for mixed data - (list of 0/1 for continuous/binary) - :param n_threads: Number of threads - - """ - assert not (binary and attr_lengths) - if binary: - attr_lengths = [1] * len(x) - else: - attr_lengths = attr_lengths or ([0] * len(x)) - for l in attr_lengths: - assert l in (0, 1) - self.attr_lengths = Array.create_from(regint(attr_lengths)) - Array.check_indices = False - Matrix.disable_index_checks() - for xx in x: - assert len(xx) == len(y) - m = len(x) - n = len(y) - self.g = sint.Array(n) - self.g.assign_all(0) - self.g[0] = 1 - self.NID = sint.Array(n) - self.NID.assign_all(1) - self.y = Array.create_from(y) - self.x = Matrix.create_from(x) - self.pis = sint.Matrix(m, n) - self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] - self.thresholds = self.x.value_type.Matrix(h, n) - self.identity_permutation = sint.Array(n) - self.label = sintbit.Array(n) - self.zeros = sint.Array(n) - self.zeros.assign_all(0) - self.n_threads = n_threads - self.debug_selection = False - self.debug_threading = True - self.debug_gini = False - self.debug_init = False - self.debug_vectmax = False - self.debug = False - self.time = False - - def train(self): - """ Train and return decision tree. """ - n = len(self.y) - - @for_range(n) - def _(i): - self.identity_permutation[i] = sint(i) - - h = len(self.nids) - - self.pis = self.SetupPerm(self.g, self.x, self.y) - - @for_range(h) - def _(k): - self.train_layer(k) - return self.get_tree(h, self.label) - - def train_with_testing(self, *test_set, output=False): - """ Train decision tree and test against test data. - - :param y: binary labels (list or sint vector) - :param x: sample data (by attribute, list or - :py:obj:`~Compiler.types.Matrix`) - :param output: output tree after every level - :returns: tree - - """ - for k in range(len(self.nids)): - self.train_layer(k) - tree = self.get_tree(k + 1, self.label) - if output: - output_decision_tree(tree) - test_decision_tree('train', tree, self.y, self.x, - n_threads=self.n_threads) - if test_set: - test_decision_tree('test', tree, *test_set, - n_threads=self.n_threads) - return tree - - def get_tree(self, h, Label): - Layer = [None] * (h + 1) - for k in range(h): - Layer[k] = CropLayer(k, self.nids[k], self.aids[k], - self.thresholds[k]) - Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID, Label, debug=self.debug) - return Layer - -def DecisionTreeTraining(x, y, h, binary=False): - return TreeTrainer(x, y, h, binary=binary).train() - -class TreeClassifier: - """ Tree classification that uses - :py:class:`TreeTrainer` internally. - - :param max_depth: Depth of decision tree - :param n_threads: Number of threads used - - """ - def __init__(self, max_depth, n_threads=None): - self.max_depth = max_depth - self.n_threads = n_threads - - @staticmethod - def get_attr_lengths(attr_types): - if attr_types == None: - return None - else: - return [1 if x == 'b' else 0 for x in attr_types] - - def fit(self, X, y, attr_types=None): - """ Train tree. - - :param X: Attribute values - :param y: Binary labels - - """ - self.tree = TreeTrainer( - X.transpose(), y, self.max_depth, - attr_lengths=self.get_attr_lengths(attr_types), - n_threads=self.n_threads).train() - - def output(self): - output_decision_tree(self.tree) - - def fit_with_testing(self, X_train, y_train, X_test, y_test, - attr_types=None, output_trees=False, debug=False): - """ Train tree with accuracy output after every level. - - :param X_train: training data with row-wise samples (sint/sfix matrix) - :param y_train: training binary labels (sint list/array) - :param X_test: testing data with row-wise samples (sint/sfix matrix) - :param y_test: testing binary labels (sint list/array) - :param attr_types: attributes types (list of 'b'/'c' for - binary/continuous; default is all continuous) - :param output_trees: output tree after every level - :param debug: output debugging information - - """ - trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth, - attr_lengths=self.get_attr_lengths(attr_types), - n_threads=self.n_threads) - trainer.debug = debug - trainer.debug_gini = debug - trainer.debug_threading = debug > 1 - self.tree = trainer.train_with_testing(y_test, X_test.transpose(), - output=output_trees) - - def predict(self, X): - """ Use tree for prediction. - - :param X: sample data with row-wise samples (sint/sfix matrix) - :returns: sint array - - """ - res = sint.Array(len(X)) - @for_range(len(X)) - def _(i): - res[i] = run_decision_tree(self.tree, X[i]) - return res diff --git a/Programs/Source/custom_data_dt.mpc b/Programs/Source/custom_data_dt.mpc deleted file mode 100644 index ba7915c65..000000000 --- a/Programs/Source/custom_data_dt.mpc +++ /dev/null @@ -1,29 +0,0 @@ -from sklearn.datasets import load_breast_cancer -from sklearn.model_selection import train_test_split -import pandas as pd -import random -import numpy as np - -m = int(program.args[2]) -n = int(program.args[1]) - -data_x = np.random.uniform(0, 10, (n, m)) -data_y = np.random.randint(2, size=(1, n)) -df_x = pd.DataFrame(data_x) -df_y = pd.DataFrame(data_y) - -df_x = sfix.input_tensor_via(0, df_x) -df_y = sint.input_tensor_via(0, df_y) -df_y = Array.create_from(df_y[0]) - -program.set_bit_length(32) -sfix.set_precision(16, 31) - -from Compiler.decision_tree import TreeClassifier - -tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4) - -tree.fit(df_x, df_y) - -# output tree -tree.output() From 267bc113905526ff56766c03d1291ca5b2a04422 Mon Sep 17 00:00:00 2001 From: sandy9999 Date: Mon, 23 Dec 2024 12:46:24 +0000 Subject: [PATCH 13/13] Reverted breast_tree to original --- Programs/Source/breast_tree.mpc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Programs/Source/breast_tree.mpc b/Programs/Source/breast_tree.mpc index 3401e6a58..9214584bb 100644 --- a/Programs/Source/breast_tree.mpc +++ b/Programs/Source/breast_tree.mpc @@ -16,7 +16,7 @@ y_test = sint.input_tensor_via(0, y_test) sfix.set_precision_from_args(program) -from Compiler.decision_tree_optimized import TreeClassifier +from Compiler.decision_tree import TreeClassifier tree = TreeClassifier(max_depth=5, n_threads=2)