diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py
index 5f7ac8716..1589ca955 100644
--- a/Compiler/decision_tree.py
+++ b/Compiler/decision_tree.py
@@ -10,6 +10,31 @@
 debug_split = False
 max_leaves = None
 
+def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
+    """
+    Compute and return the secret-shared permutation that stably sorts :py:obj:`keys`.
+    """
+    for k in keys:
+        assert len(k) == len(keys[0])
+    n_bits = n_bits or [None] * len(keys)
+    bs = Matrix.create_from(sum([k.get_vector().bit_decompose(nb)
+                                 for k, nb in reversed(list(zip(keys, n_bits)))], []))
+    get_vec = lambda x: x[:] if isinstance(x, Array) else x
+    res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
+                             for x in to_sort)
+    res = res.transpose()
+    return radix_sort_from_matrix(bs, res)
+
+def ApplyPermutation(perm, x):
+    res = Array.create_from(x)
+    reveal_sort(perm, res, False)
+    return res
+
+def ApplyInversePermutation(perm, x):
+    res = Array.create_from(x)
+    reveal_sort(perm, res, True)
+    return res
+
 def get_type(x):
     if isinstance(x, (Array, SubMultiArray)):
         return x.value_type
@@ -22,6 +47,12 @@ def get_type(x):
     else:
         return type(x)
 
+
+def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2):
+    b = (x_num*y_den) > (x_den*y_num)
+    b = Array.create_from(b).get_vector()
+    return b
+
 def PrefixSum(x):
     return x.get_vector().prefix_sum()
 
@@ -86,17 +117,9 @@ def Sort(keys, *to_sort, n_bits=None, time=False):
 
 def VectMax(key, *data, debug=False):
     def reducer(x, y):
-        b = x[0] > y[0]
-        if debug:
-            print_ln('max b=%s', b.reveal())
+        b = x[0]*y[1] > y[0]*x[1]
         return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
-    if debug:
-        key = list(key)
-        data = [list(x) for x in data]
-        print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
-    res = util.tree_reduce(reducer, zip(key, *data))[1:]
-    if debug:
-        print_ln('vect max res=%s', util.reveal(res))
+    res = util.tree_reduce(reducer, zip(key, *data))
     return res
 
 def GroupSum(g, x):
@@ -119,9 +142,6 @@ def GroupPrefixSum(g, x):
     return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
 
 def GroupMax(g, keys, *x):
-    if debug:
-        print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
-                 util.reveal(keys), util.reveal(x))
     assert len(keys) == len(g)
     for xx in x:
         assert len(xx) == len(g)
@@ -138,51 +158,48 @@ def GroupMax(g, keys, *x):
         vsize = n - w
         g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
             g_old.get_vector(size=vsize, base=w)), base=w)
-        b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
+        b = Custom_GT_Fractions(keys.get_vector(size=vsize), x[0].get_vector(size=vsize), keys.get_vector(size=vsize, base=w), x[0].get_vector(size=vsize, base=w))
+        #b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
         for xx in [keys] + x:
             a = b.if_else(xx.get_vector(size=vsize), xx.get_vector(size=vsize, base=w))
             xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
                 xx.get_vector(size=vsize, base=w), a), base=w)
         break_point()
-        if debug:
-            print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
-                     util.reveal(a), util.reveal(keys),
-                     util.reveal(x), g_new.reveal())
     t = sint.Array(len(g))
     t[-1] = 1
     t.assign_vector(g.get_vector(size=n - 1, base=1))
-    if debug:
-        print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
-                 util.reveal(t), util.reveal(keys), util.reveal(x))
     return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
 
-def ModifiedGini(g, y, debug=False):
+def ComputeGini(g, x, y, notysum, ysum, debug=False):
     assert len(g) == len(y)
     y = [y.get_vector().bit_not(), y]
     u = [GroupPrefixSum(g, yy) for yy in y]
-    s = [GroupSum(g, yy) for yy in y]
+    s = [notysum, ysum]
     w = [ss - uu for ss, uu in zip(s, u)]
     us = sum(u)
     ws = sum(w)
     uqs = u[0] ** 2 + u[1] ** 2
     wqs = w[0] ** 2 + w[1] ** 2
-    res = sfix(uqs) / us + sfix(wqs) / ws
-    if debug:
-        print_ln('g=%s y=%s s=%s',
-                 util.reveal(g), util.reveal(y),
-                 util.reveal(s))
-        print_ln('u0=%s', util.reveal(u[0]))
-        print_ln('u0=%s', util.reveal(u[1]))
-        print_ln('us=%s', util.reveal(us))
-        print_ln('w0=%s', util.reveal(w[0]))
-        print_ln('w1=%s', util.reveal(w[1]))
-        print_ln('ws=%s', util.reveal(ws))
-        print_ln('uqs=%s', util.reveal(uqs))
-        print_ln('wqs=%s', util.reveal(wqs))
-    if debug:
-        print_ln('gini %s %s', type(res), util.reveal(res))
-    return res
+    res_num = ws * uqs + us * wqs
+    res_den = us * ws
+    xx = x
+    t = get_type(x).Array(len(x))
+    t[-1] = MIN_VALUE
+    t.assign_vector(xx.get_vector(size=len(x) - 1) + \
+                    xx.get_vector(size=len(x) - 1, base=1))
+    gg = g
+    p = sint.Array(len(x))
+    p[-1] = 1
+    p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
+        xx.get_vector(size=len(x) - 1) == \
+        xx.get_vector(size=len(x) - 1, base=1)))
+    break_point()
+    res_num = p[:].if_else(MIN_VALUE, res_num)
+    res_den = p[:].if_else(1, res_den)
+    t = p[:].if_else(MIN_VALUE, t[:])
+    return res_num, res_den, t
+
 
 MIN_VALUE = -10000
 
@@ -193,12 +210,8 @@ def FormatLayer_without_crop(g, *a, debug=False):
     for x in a:
         assert len(x) == len(g)
     v = [g.if_else(aa, 0) for aa in a]
-    if debug:
-        print_ln('format in %s', util.reveal(a))
-        print_ln('format mux %s', util.reveal(v))
-    v = Sort([g.bit_not()], *v, n_bits=[1])
-    if debug:
-        print_ln('format sort %s', util.reveal(v))
+    p = SortPerm(g.get_vector().bit_not())
+    v = [p.apply(vv) for vv in v]
     return v
 
 def CropLayer(k, *v):
@@ -208,42 +221,22 @@ def CropLayer(k, *v):
     n = 2 ** k
     return [vv[:min(n, len(vv))] for vv in v]
 
-def TrainLeafNodes(h, g, y, NID):
+def TrainLeafNodes(h, g, y, NID, Label, debug=False):
     assert len(g) == len(y)
     assert len(g) == len(NID)
-    Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
     return FormatLayer(h, g, NID, Label)
 
-def GroupSame(g, y):
-    assert len(g) == len(y)
-    s = GroupSum(g, [sint(1)] * len(g))
-    s0 = GroupSum(g, y.bit_not())
-    s1 = GroupSum(g, y)
-    if debug_split:
-        print_ln('group same g=%s', util.reveal(g))
-        print_ln('group same y=%s', util.reveal(y))
-    return (s == s0).bit_or(s == s1)
-
 def GroupFirstOne(g, b):
     assert len(g) == len(b)
     s = GroupPrefixSum(g, b)
     return s * b == 1
 
 class TreeTrainer:
-    """ Decision tree training by `Hamada et al.`_
-
-    :param x: sample data (by attribute, list or
-      :py:obj:`~Compiler.types.Matrix`)
-    :param y: binary labels (list or sint vector)
-    :param h: height (int)
-    :param binary: binary attributes instead of continuous
-    :param attr_lengths: attribute description for mixed data
-      (list of 0/1 for continuous/binary)
-    :param n_threads: number of threads (default: single thread)
-
-    .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
+    def GetInversePermutation(self, perm):
+        res = Array.create_from(self.identity_permutation)
+        reveal_sort(perm, res)
+        return res
 
-    """
     def ApplyTests(self, x, AID, Threshold):
         m = len(x)
         n = len(AID)
@@ -251,125 +244,96 @@ def ApplyTests(self, x, AID, Threshold):
         assert len(Threshold) == len(AID)
         for xx in x:
             assert len(xx) == len(AID)
         e = sint.Matrix(m, n)
-        AID = Array.create_from(AID)
         @for_range_multithread(self.n_threads, 1, m)
         def _(j):
             e[j][:] = AID[:] == j
         xx = sum(x[j] * e[j] for j in range(m))
-        if self.debug > 1:
-            print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
-            print_ln('threshold %s', util.reveal(Threshold))
-        return 2 * xx < Threshold
+        return 2 * xx.get_vector() < Threshold.get_vector()
 
-    def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
-        assert len(g) == len(x)
-        assert len(g) == len(y)
-        if time:
-            start_timer(2)
-        s = ModifiedGini(g, y, debug=debug or self.debug > 2)
-        if time:
-            stop_timer(2)
-        if debug or self.debug > 1:
-            print_ln('gini %s', s.reveal())
-        xx = x
-        t = get_type(x).Array(len(x))
-        t[-1] = MIN_VALUE
-        t.assign_vector(xx.get_vector(size=len(x) - 1) + \
-                        xx.get_vector(size=len(x) - 1, base=1))
-        gg = g
-        p = sint.Array(len(x))
-        p[-1] = 1
-        p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
-            xx.get_vector(size=len(x) - 1) == \
-            xx.get_vector(size=len(x) - 1, base=1)))
-        break_point()
-        if debug:
-            print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
-        s = p[:].if_else(MIN_VALUE, s)
-        t = p[:].if_else(MIN_VALUE, t[:])
-        if debug:
-            print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
-        if time:
-            start_timer(3)
-        s, t = GroupMax(gg, s, t)
-        if time:
-            stop_timer(3)
-        if debug:
-            print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
-        return t, s
-
-    def GlobalTestSelection(self, x, y, g):
-        assert len(y) == len(g)
+    def TestSelection(self, g, x, y, pis, notysum, ysum, time=False):
         for xx in x:
             assert(len(xx) == len(g))
+        assert len(g) == len(y)
         m = len(x)
         n = len(y)
+        gg = g
         u, t = [get_type(x).Matrix(m, n) for i in range(2)]
         v = get_type(y).Matrix(m, n)
-        s = sfix.Matrix(m, n)
+        s_num = get_type(y).Matrix(m, n)
+        s_den = get_type(y).Matrix(m, n)
+        a = sint.Array(n)
+
+        notysum_arr = Array.create_from(notysum)
+        ysum_arr = Array.create_from(ysum)
+
         @for_range_multithread(self.n_threads, 1, m)
         def _(j):
             single = not self.n_threads or self.n_threads == 1
             time = self.time and single
-            if debug:
+            if self.debug_selection:
                 print_ln('run %s', j)
+            u[j].assign_vector(x[j])
+            v[j].assign_vector(y)
+            reveal_sort(pis[j], u[j])
+            reveal_sort(pis[j], v[j])
+            s_num[j][:], s_den[j][:], t[j][:] = ComputeGini(g, u[j], v[j], notysum_arr, ysum_arr, debug=False)
+
+        ss_num, ss_den, tt, aa = VectMax((s_num[j][:] for j in range(m)), (s_den[j][:] for j in range(m)), (t[j][:] for j in range(m)), range(m), debug=self.debug)
+
+        aaa = get_type(y).Array(n)
+        ttt = get_type(x).Array(n)
+
+        GroupMax_num, GroupMax_den, GroupMax_ttt, GroupMax_aaa = GroupMax(g, ss_num, ss_den, tt, aa)
+
+        f = sint.Array(n)
+        f = (self.zeros.get_vector() == notysum).bit_or(self.zeros.get_vector() == ysum)
+        aaa_vector, ttt_vector = f.if_else(0, GroupMax_aaa), f.if_else(MIN_VALUE, GroupMax_ttt)
+
+        ttt.assign_vector(ttt_vector)
+        aaa.assign_vector(aaa_vector)
+
+        return aaa, ttt
+
+    def SetupPerm(self, g, x, y):
+        m = len(x)
+        n = len(y)
+        pis = get_type(y).Matrix(m, n)
+        @for_range_multithread(self.n_threads, 1, m)
+        def _(j):
             @if_e(self.attr_lengths[j])
             def _():
-                u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
-                                        n_bits=[util.log2(n), 1], time=time)
+                pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y,
+                                                                   n_bits=[1], time=time))
             @else_
             def _():
-                u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
-                                        n_bits=[util.log2(n), None],
-                                        time=time)
-            if self.debug_threading:
-                print_ln('global sort %s %s %s', j, util.reveal(u[j]),
-                         util.reveal(v[j]))
-            t[j][:], s[j][:] = self.AttributeWiseTestSelection(
-                g, u[j], v[j], time=time, debug=self.debug_selection)
-            if self.debug_threading:
-                print_ln('global attribute %s %s %s', j, util.reveal(t[j]),
-                         util.reveal(s[j]))
-        n = len(g)
-        a = sint.Array(n)
-        if self.debug_threading:
-            print_ln('global s=%s', util.reveal(s))
-        if self.debug_gini:
-            print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)),
-                     *(ss[0].reveal() for ss in s))
-        if self.time:
-            start_timer(4)
-        if self.debug > 1:
-            print_ln('s=%s', s.reveal_nested())
-            print_ln('t=%s', t.reveal_nested())
-        a[:], tt = VectMax((s[j][:] for j in range(m)), range(m),
-                           (t[j][:] for j in range(m)), debug=self.debug > 1)
-        tt = Array.create_from(tt)
-        if self.time:
-            stop_timer(4)
-        if self.debug > 1:
-            print_ln('a=%s', util.reveal(a))
-            print_ln('tt=%s', util.reveal(tt))
-        return a[:], tt[:]
-
-    def TrainInternalNodes(self, k, x, y, g, NID):
-        assert len(g) == len(y)
-        for xx in x:
-            assert len(xx) == len(g)
-        AID, Threshold = self.GlobalTestSelection(x, y, g)
-        s = GroupSame(g[:], y[:])
-        if self.debug > 1 or debug_split:
-            print_ln('AID=%s', util.reveal(AID))
-            print_ln('Threshold=%s', util.reveal(Threshold))
-            print_ln('GroupSame=%s', util.reveal(s))
-        AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
-        if self.debug > 1 or debug_split:
-            print_ln('AID=%s', util.reveal(AID))
-            print_ln('Threshold=%s', util.reveal(Threshold))
-        b = self.ApplyTests(x, AID, Threshold)
-        layer = FormatLayer_without_crop(g[:], NID, AID, Threshold,
-                                         debug=self.debug > 1)
-        return *layer, b
+                pis[j][:] = self.GetInversePermutation(GetSortPerm([x[j]], x[j], y,
+                                                                   n_bits=[None],
+                                                                   time=time))
+        return pis
+
+    def UpdateState(self, g, x, y, pis, NID, b, k):
+        m = len(x)
+        n = len(y)
+        q = SortPerm(b)
+
+        y[:] = q.apply(y)
+        NID[:] = 2 ** k * b + NID
+        NID[:] = q.apply(NID)
+        g[:] = GroupFirstOne(g, b.bit_not()) + GroupFirstOne(g, b)
+        g[:] = q.apply(g)
+
+        b_arith = sint.Array(n)
+        b_arith = Array.create_from(b)
+
+        @for_range_multithread(self.n_threads, 1, m)
+        def _(j):
+            x[j][:] = q.apply(x[j])
+            b_permuted = ApplyPermutation(pis[j], b_arith)
+
+            pis[j] = q.apply(pis[j])
+            pis[j] = ApplyInversePermutation(pis[j], SortPerm(b_permuted).perm)
+
+        return [g, x, y, NID, pis]
 
     @method_block
     def train_layer(self, k):
@@ -377,29 +341,39 @@ def train_layer(self, k):
         y = self.y
         g = self.g
         NID = self.NID
-        if self.debug > 1:
-            print_ln('g=%s', g.reveal())
-            print_ln('y=%s', y.reveal())
-            print_ln('x=%s', x.reveal_nested())
-        self.nids[k], self.aids[k], self.thresholds[k], b = \
-            self.TrainInternalNodes(k, x, y, g, NID)
-        if self.debug > 1:
-            print_ln('layer %s:', k)
-            for name, data in zip(('NID', 'AID', 'Thr'),
-                                  (self.nids[k], self.aids[k],
-                                   self.thresholds[k])):
-                print_ln(' %s: %s', name, data.reveal())
-        NID[:] = 2 ** k * b + NID
-        b_not = b.bit_not()
-        if self.debug > 1:
-            print_ln('b_not=%s', b_not.reveal())
-        g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
-        y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
-        for i, xxx in enumerate(xx):
-            x[i] = xxx
+        pis = self.pis
+
+        s0 = GroupSum(g, y.get_vector().bit_not())
+        s1 = GroupSum(g, y.get_vector())
+
+        a, t = self.TestSelection(g, x, y, pis, s0, s1)
+        b = self.ApplyTests(x, a, t)
+        p = SortPerm(g.get_vector().bit_not())
+
+        self.nids[k], self.aids[k], self.thresholds[k] = FormatLayer_without_crop(g[:], NID, a, t, debug=self.debug)
+        self.g, self.x, self.y, self.NID, self.pis = self.UpdateState(g, x, y, pis, NID, b, k)
+
+        @if_(k >= (len(self.nids)-1))
+        def _():
+            self.label = Array.create_from(s0 < s1)
 
     def __init__(self, x, y, h, binary=False, attr_lengths=None,
                  n_threads=None):
+        """ Decision tree training as in "Securely Training Decision Trees Efficiently" by `Bhardwaj et al.`_: https://eprint.iacr.org/2024/1077.pdf
+
+        This protocol has communication complexity O(mN log N + hmN + hN log N), an improvement of roughly min(h, m, log N) over `Hamada et al.`_: https://petsymposium.org/popets/2023/popets-2023-0021.pdf
+
+        To run this protocol, at the root of the MP-SPDZ repo, run: Scripts/compile-run.py -H HOSTS -E ring custom_data_dt $((2**13)) 11 4 -Z 3 -R 128
+
+        :param x: Attribute values
+        :param y: Binary labels
+        :param h: Height of the decision tree
+        :param binary: Binary attributes instead of continuous
+        :param attr_lengths: Attribute description for mixed data
+            (list of 0/1 for continuous/binary)
+        :param n_threads: Number of threads
+
+        """
         assert not (binary and attr_lengths)
         if binary:
             attr_lengths = [1] * len(x)
@@ -412,6 +386,7 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
         Matrix.disable_index_checks()
         for xx in x:
             assert len(xx) == len(y)
+        m = len(x)
         n = len(y)
         self.g = sint.Array(n)
         self.g.assign_all(0)
@@ -420,8 +395,13 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
         self.NID.assign_all(1)
         self.y = Array.create_from(y)
         self.x = Matrix.create_from(x)
+        self.pis = sint.Matrix(m, n)
         self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)]
         self.thresholds = self.x.value_type.Matrix(h, n)
+        self.identity_permutation = sint.Array(n)
+        self.label = sintbit.Array(n)
+        self.zeros = sint.Array(n)
+        self.zeros.assign_all(0)
         self.n_threads = n_threads
         self.debug_selection = False
         self.debug_threading = False
@@ -431,11 +411,19 @@ def __init__(self, x, y, h, binary=False, attr_lengths=None,
 
     def train(self):
         """ Train and return decision tree. """
+        n = len(self.y)
+
+        @for_range(n)
+        def _(i):
+            self.identity_permutation[i] = sint(i)
+
         h = len(self.nids)
+        self.pis = self.SetupPerm(self.g, self.x, self.y)
+
         @for_range(h)
         def _(k):
             self.train_layer(k)
-        return self.get_tree(h)
+        return self.get_tree(h, self.label)
 
     def train_with_testing(self, *test_set, output=False):
         """ Train decision tree and test against test data.
@@ -449,7 +437,7 @@ def train_with_testing(self, *test_set, output=False):
         """
         for k in range(len(self.nids)):
             self.train_layer(k)
-        tree = self.get_tree(k + 1)
+        tree = self.get_tree(k + 1, self.label)
         if output:
             output_decision_tree(tree)
         test_decision_tree('train', tree, self.y, self.x,
@@ -459,12 +447,12 @@ def train_with_testing(self, *test_set, output=False):
                            n_threads=self.n_threads)
         return tree
 
-    def get_tree(self, h):
+    def get_tree(self, h, Label):
         Layer = [None] * (h + 1)
         for k in range(h):
             Layer[k] = CropLayer(k, self.nids[k], self.aids[k],
                                  self.thresholds[k])
-        Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
+        Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID, Label)
         return Layer
 
 def DecisionTreeTraining(x, y, h, binary=False):
diff --git a/Compiler/sorting.py b/Compiler/sorting.py
index 7779c7489..cdf9cf1f6 100644
--- a/Compiler/sorting.py
+++ b/Compiler/sorting.py
@@ -73,3 +73,4 @@ def _():
     @library.else_
    def _():
         reveal_sort(h, D, reverse=True)
+    return h
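Note on the split-scoring change (illustration, not part of the patch): ComputeGini now returns the modified Gini score as a (numerator, denominator) pair, and Custom_GT_Fractions together with the new VectMax reducer compare candidates by cross-multiplication, so the secure fixed-point division used by the old ModifiedGini is avoided. A minimal clear-text sketch follows, using plain Python integers and hypothetical helper names (gini_num_den, frac_gt, argmax_frac) rather than MP-SPDZ types; it assumes both denominators are positive, which holds here because each group is non-empty.

def gini_num_den(u0, u1, w0, w1):
    # u0/u1 and w0/w1 are the label counts on the two sides of a candidate split
    us, ws = u0 + u1, w0 + w1
    num = ws * (u0 ** 2 + u1 ** 2) + us * (w0 ** 2 + w1 ** 2)
    return num, us * ws  # same value as (u0^2+u1^2)/us + (w0^2+w1^2)/ws, kept as a fraction

def frac_gt(x_num, x_den, y_num, y_den):
    # x_num/x_den > y_num/y_den without division (positive denominators assumed)
    return x_num * y_den > y_num * x_den

def argmax_frac(scores):
    # pick the best candidate using only multiplications and comparisons
    best = 0
    for i in range(1, len(scores)):
        if frac_gt(*scores[i], *scores[best]):
            best = i
    return best

splits = [(3, 2, 1, 4), (5, 0, 2, 3), (2, 2, 3, 3)]  # (u0, u1, w0, w1) per candidate
scores = [gini_num_den(*s) for s in splits]
assert argmax_frac(scores) == max(range(len(scores)),
                                  key=lambda i: scores[i][0] / scores[i][1])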
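Note on the permutation bookkeeping (illustration, not part of the patch): SetupPerm computes one stable sorting permutation per attribute up front, and UpdateState re-sorts the samples by the test outcome b and composes the stored permutations with that sort instead of re-sorting every attribute in every layer. The clear-text sketch below uses hypothetical helpers (sort_perm, apply_perm, apply_inverse_perm) standing in for SortPerm, ApplyPermutation and ApplyInversePermutation, with the convention assumed above: perm[i] is the destination of position i under a stable sort of a 0/1 key, and applying the inverse undoes the rearrangement.

def sort_perm(bits):
    # stable permutation that moves all 0s in front of all 1s
    order = sorted(range(len(bits)), key=lambda i: (bits[i], i))
    perm = [0] * len(bits)
    for dest, src in enumerate(order):
        perm[src] = dest
    return perm

def apply_perm(perm, xs):
    # element at position i moves to position perm[i]
    out = [None] * len(xs)
    for i, p in enumerate(perm):
        out[p] = xs[i]
    return out

def apply_inverse_perm(perm, xs):
    # undo apply_perm: fetch each element back from its destination
    return [xs[p] for p in perm]

b = [1, 0, 1, 0]
q = sort_perm(b)
assert apply_perm(q, b) == [0, 0, 1, 1]
assert apply_inverse_perm(q, apply_perm(q, b)) == b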