Skip to content

Training #2

Latest
Compare
Choose a tag to compare
oovm (@oovm) released this on 05 Sep at 12:06
(4 commits to master since this release)

Tang-RNN

(* Make all relative paths (chars.WXF, ipt.WXF, CheckPoints, result.WXF) resolve next to this code.
   NotebookDirectory[] only works with a front end, so when run headless (e.g. wolframscript -file)
   fall back to the directory of the running script, and finally to the current directory. *)
SetDirectory@If[$Notebooks,
	NotebookDirectory[],
	Replace[DirectoryName[$InputFileName], "" -> Directory[]]
];
(* Character-level encoder. The alphabet is the character list stored in chars.WXF plus a trailing `_`,
   which presumably acts as a catch-all class for characters not in that list
   (NOTE(review): confirm against the NetEncoder "Characters" documentation).
   `chars` is kept as a global so the network can size its layers with Length@chars. *)
encoder = NetEncoder[{"Characters", chars = Append[Import["chars.WXF"], _]}]
(* Encode every item of ipt.WXF into integer indices (presumably a list of strings — TODO confirm).
   `TT` is not defined anywhere in this script, so the former `// TT` only echoed TT[Null];
   AbsoluteTiming reports the elapsed seconds of the load + encode step instead.
   NOTE(review): if TT was a personal timing helper defined in an init file, this is equivalent. *)
trainingData = encoder /@ Import["ipt.WXF"]; // AbsoluteTiming
(* Next-character predictor: one-hot embedding over the alphabet, two stacked 128-unit GRU layers,
   a linear projection back onto the alphabet applied at every sequence step, then a softmax. *)
predictNet = With[{alphabetSize = Length@chars},
	NetChain@Join[
		{UnitVectorLayer[alphabetSize]},
		Table[GatedRecurrentLayer[128], 2],
		{NetMapOperator@LinearLayer[alphabetSize], SoftmaxLayer[]}
	]
];
(* Teacher-forcing training graph: the predictor receives every element of the input sequence except
   the last ("most"), and the cross-entropy loss compares its output against every element except the
   first ("rest"), so each position is trained to predict the following index.
   The input length is pinned to that of the first training sequence
   (NOTE(review): assumes every encoded sequence has the same length). *)
teacherForcingNet = NetInitialize@NetGraph[
	<|
		"predict" -> predictNet,
		"rest" -> SequenceRestLayer[],
		"most" -> SequenceMostLayer[],
		"loss" -> CrossEntropyLossLayer["Index"]
	|>,
	{
		NetPort["Input"] -> "most",
		"most" -> "predict",
		"predict" -> NetPort["loss", "Input"],
		NetPort["Input"] -> "rest",
		"rest" -> NetPort["loss", "Target"]
	},
	"Input" -> {Length@First[trainingData], "Integer"}
]
(* Train the teacher-forcing graph on the GPU with batches of 64, holding out 1% of the data for
   validation. Training is bounded by MaxTrainingRounds -> 100 and TimeGoal -> 3600 seconds
   (NOTE(review): confirm how NetTrain combines these two limits). A checkpoint is written to
   ./CheckPoints every 10000 batches. The `All` property makes NetTrain return the full results
   object rather than only the trained net; that object is then saved size-optimized as WXF. *)
result = NetTrain[
	teacherForcingNet,
	<|"Input" -> trainingData|>,
	All,
	TargetDevice -> "GPU",
	BatchSize -> 64,
	ValidationSet -> Scaled[0.01],
	MaxTrainingRounds -> 100,
	TimeGoal -> 3600,
	TrainingProgressCheckpointing -> {"Directory", "CheckPoints", "Interval" -> Quantity[10000, "Batches"]}
]
Export["result.WXF", result, PerformanceGoal -> "Size"]