-
Notifications
You must be signed in to change notification settings - Fork 1
/
ShapeMatchVertex.java
106 lines (90 loc) · 3.93 KB
/
ShapeMatchVertex.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package ode.vertex.conf;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
 * Duplicates the last input to match the shapes of the other inputs. Main use case is for performing merging or element
 * wise operations with current time from an ODE solver. It is possible to override the size of selected dimensions by
 * providing a map from the dimension to override to the wanted size.
 * <p>
 * This class is the configuration; the runtime vertex is created by {@link #instantiate} as an
 * {@code ode.vertex.impl.ShapeMatchVertex} wrapping the instantiated {@link #graphVertex}. For shape inference, the
 * type of the last input is replaced by a type having the shape of the first input, with the batch dimension set to 1
 * and each dimension in {@link #overrideSizeDims} set to the given size. The resulting input types are then handed to
 * the wrapped {@link GraphVertex}.
 *
 * @author Christian Skarby
 */
@Data
@EqualsAndHashCode(callSuper = false)
public class ShapeMatchVertex extends GraphVertex {

    /** Vertex performing the actual operation (e.g. merge) on the inputs once their shapes have been matched. */
    protected GraphVertex graphVertex;

    /** Maps a dimension index to the size that dimension shall have for the shape-matched last input. */
    protected Map<Integer, Long> overrideSizeDims;

    /**
     * Creates a vertex which shape-matches the last input before merging. Dimension 1 of the last input is
     * overridden to size 1.
     * NOTE(review): assumes dimension 1 is the feature/channel dimension of all inputs.
     *
     * @param graphVertex merge vertex to wrap
     */
    public ShapeMatchVertex(MergeVertex graphVertex) {
        this(graphVertex, Collections.singletonMap(1, 1L)); // Might not hold for conv3D layers...
    }

    /**
     * Creates a vertex which shape-matches the last input to the first input without overriding any dimension sizes
     * (other than the batch dimension).
     *
     * @param graphVertex vertex to wrap; must accept at least two inputs
     */
    public ShapeMatchVertex(GraphVertex graphVertex) {
        this(graphVertex, Collections.emptyMap());
    }

    /**
     * Creates a vertex which shape-matches the last input to the first input.
     *
     * @param graphVertex      vertex to wrap; must accept at least two inputs
     * @param overrideSizeDims map from dimension index to wanted size for the shape-matched last input. {@code null}
     *                         is treated as an empty map. The map is copied.
     * @throws IllegalArgumentException if {@code graphVertex} can not take more than one input
     */
    public ShapeMatchVertex(
            @JsonProperty("graphVertex") GraphVertex graphVertex,
            @JsonProperty("overrideSizeDims") Map<Integer, Long> overrideSizeDims) {
        this.graphVertex = graphVertex;
        if (graphVertex.maxVertexInputs() < 2) {
            throw new IllegalArgumentException("Must be able to take more than one input! Got: " + graphVertex);
        }
        // Jackson passes null for a missing creator property (e.g. configs serialized without this field). Copy so
        // that later changes to the caller's (possibly immutable) map do not leak into this vertex.
        this.overrideSizeDims = overrideSizeDims == null
                ? new HashMap<>()
                : new HashMap<>(overrideSizeDims);
    }

    @Override
    public GraphVertex clone() {
        // The constructor copies the map, so the clone does not share it with this instance
        return new ShapeMatchVertex(graphVertex.clone(), effectiveOverrideSizeDims());
    }

    @Override
    public long numParams(boolean backprop) {
        return graphVertex.numParams(backprop);
    }

    @Override
    public int minVertexInputs() {
        // Shape matching needs at least one input to match against plus the input being matched
        return Math.max(2, graphVertex.minVertexInputs());
    }

    @Override
    public int maxVertexInputs() {
        return graphVertex.maxVertexInputs();
    }

    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
        // The wrapped vertex gets the suffixed name "<name>-vertex" but shares index and parameter view with this one
        return new ode.vertex.impl.ShapeMatchVertex(graph, name, idx,
                graphVertex.instantiate(graph, name + "-vertex", idx, paramsView, initializeParams),
                effectiveOverrideSizeDims());
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        final InputType[] afterDuplicate = getInputTypesAfterDuplication(vertexInputs);
        return graphVertex.getOutputType(layerIndex, afterDuplicate);
    }

    /**
     * Returns the size overrides, treating {@code null} (e.g. assigned through the generated setter) as "no
     * overrides".
     */
    private Map<Integer, Long> effectiveOverrideSizeDims() {
        return overrideSizeDims == null ? Collections.<Integer, Long>emptyMap() : overrideSizeDims;
    }

    /**
     * Computes the input types seen by the wrapped vertex: all inputs are kept as is, except the last one which is
     * replaced by a type with the shape of the first input, batch size 1 and the configured size overrides applied.
     *
     * @param vertexInputs types of the inputs to this vertex
     * @return a new array of input types; {@code vertexInputs} is not modified
     * @throws IllegalArgumentException if fewer than two inputs are given or an overridden dimension is out of range
     */
    @NotNull
    private InputType[] getInputTypesAfterDuplication(InputType[] vertexInputs) {
        if (vertexInputs.length < 2) {
            throw new IllegalArgumentException("Must have more than one inputs!! Got: " + Arrays.toString(vertexInputs));
        }
        final InputType[] afterDuplicate = vertexInputs.clone();
        // Shape of the first input including the batch dimension (index 0), which is set to 1
        final long[] shape = afterDuplicate[0].getShape(true);
        shape[0] = 1;
        for (Map.Entry<Integer, Long> sizeDim : effectiveOverrideSizeDims().entrySet()) {
            final int dim = sizeDim.getKey();
            if (dim < 0 || dim >= shape.length) {
                throw new IllegalArgumentException("Can not override size of dimension " + dim
                        + " for input of shape " + Arrays.toString(shape) + "!");
            }
            shape[dim] = sizeDim.getValue();
        }
        // The array is only created so that InputType.inferInputType can derive a type from its shape
        final InputType newType = InputType.inferInputType(Nd4j.createUninitialized(shape));
        afterDuplicate[afterDuplicate.length - 1] = newType;
        return afterDuplicate;
    }

    @Override
    public MemoryReport getMemoryReport(InputType... inputTypes) {
        return graphVertex.getMemoryReport(getInputTypesAfterDuplication(inputTypes));
    }
}