Add options for text line height and normalization style #66

Open · wants to merge 5 commits into master
7 changes: 4 additions & 3 deletions clstmhl.h
@@ -146,8 +146,9 @@ struct CLSTMText {
struct CLSTMOCR {
shared_ptr<INormalizer> normalizer;
Network net;
int target_height = 48;
int target_height; // no default (was 48), to avoid silently using an unintended value
int nclasses = -1;
string dewarp; // Option for text-line normalization
Sequence aligned, targets;
Tensor2 image;
void setLearningRate(float lr, float mom) { net->setLearningRate(lr, mom); }
@@ -161,7 +162,7 @@ struct CLSTMOCR {
return false;
}
nclasses = net->codec.size();
normalizer.reset(make_CenterNormalizer());
normalizer.reset(make_Normalizer(dewarp));
normalizer->target_height = target_height;
return true;
}
@@ -194,7 +195,7 @@ struct CLSTMOCR {
{"nhidden", nhidden}});
net->initialize();
net->codec.set(codec);
normalizer.reset(make_CenterNormalizer());
normalizer.reset(make_Normalizer(dewarp));
normalizer->target_height = target_height;
}
std::wstring fwdbwd(TensorMap2 raw, const std::wstring &target) {
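Both load() and createBidi() now build the normalizer from the dewarp string instead of hard-coding make_CenterNormalizer(). A minimal sketch of what a string-dispatching make_Normalizer() factory could look like; make_CenterNormalizer() appears in this diff, while make_NoNormalizer(), make_MeanNormalizer(), and the accepted names are assumptions ("none" is taken from the default used below):

// Hypothetical factory: map a normalization-style name to one of the
// INormalizer constructors. Only make_CenterNormalizer() is shown in this PR;
// the other constructors and the accepted names are assumptions.
INormalizer *make_Normalizer(const string &name) {
  if (name == "none") return make_NoNormalizer();
  if (name == "mean") return make_MeanNormalizer();
  if (name == "center") return make_CenterNormalizer();
  THROW("unknown normalizer");
  return nullptr;  // not reached
}
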
4 changes: 3 additions & 1 deletion clstmocr.cc
@@ -46,8 +46,10 @@ int main1(int argc, char **argv) {
string load_name = getsenv("load", "");
if (load_name == "") THROW("must give load= parameter");
CLSTMOCR clstm;
clstm.target_height = int(getrenv("target_height", 45));
clstm.dewarp = getsenv("dewarp", "none");
clstm.load(load_name);

bool conf = getienv("conf", 0);
string output = getsenv("output", "text");
bool save_text = getienv("save_text", 1);
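The ordering above matters: CLSTMOCR::load() constructs the normalizer from dewarp and copies target_height into it, so both fields must be filled in before the model is loaded. A minimal usage sketch, with the model filename and the "center" value as placeholders rather than values taken from this diff:

// Sketch: configure normalization before CLSTMOCR::load(), which
// builds the normalizer from these two fields.
CLSTMOCR ocr;
ocr.target_height = 48;  // text-line height after normalization
ocr.dewarp = "center";   // normalization style handed to make_Normalizer()
if (!ocr.load("model.clstm")) THROW("cannot load model");

On the command line the same values arrive through the environment via getsenv/getrenv, e.g. something like load=model.clstm dewarp=none target_height=45 ./clstmocr line.png; the exact invocation is an assumption, since the rest of the argument handling is not shown in this diff.
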
42 changes: 39 additions & 3 deletions clstmocrtrain.cc
@@ -65,6 +65,19 @@ struct Dataset {
for (auto s : fnames) gtnames.push_back(basename(s) + ".gt.txt");
codec.build(gtnames, charsep);
}
void getCodec(Codec &codec, vector<string> file_lists) {
// Build the codec from several file lists (training, validation, and possibly
// test files) so that no ground-truth character is missing from the codec.
vector<string> gts;
for (size_t i = 0; i < file_lists.size(); i++) {
vector<string> temp_names;
read_lines(temp_names, file_lists[i]);
for (auto s : temp_names) gts.push_back(basename(s) + ".gt.txt");
}
// build the codec
codec.build(gts, charsep);
}

void readSample(Tensor2 &raw, wstring &gt, int index) {
string fname = fnames[index];
string base = basename(fname);
@@ -92,12 +105,19 @@ int main1(int argc, char **argv) {
int ntrain = getienv("ntrain", 10000000);
string save_name = getsenv("save_name", "_ocr");
int report_time = getienv("report_time", 0);
// file lists (training and, optionally, testing) used to build the codec
vector<string> file_lists;

if (argc < 2 || argc > 3) THROW("... training [testing]");
Dataset trainingset(argv[1]);
file_lists.push_back(argv[1]);
assert(trainingset.size() > 0);
Dataset testset;
if (argc > 2) testset.readFileList(argv[2]);
if (argc > 2) {
testset.readFileList(argv[2]);
file_lists.push_back(argv[2]);
}

print("got", trainingset.size(), "files,", testset.size(), "tests");

string load_name = getsenv("load", "");
@@ -108,13 +128,16 @@ int main1(int argc, char **argv) {
clstm.load(load_name);
} else {
Codec codec;
trainingset.getCodec(codec);
//trainingset.getCodec(codec);
trainingset.getCodec(codec, file_lists); // use all ground truth files
print("got", codec.size(), "classes");

clstm.target_height = int(getrenv("target_height", 48));
clstm.target_height = int(getrenv("target_height", 45));
clstm.dewarp = getsenv("dewarp", "none");
clstm.createBidi(codec.codec, getienv("nhidden", 100));
clstm.setLearningRate(getdenv("lrate", 1e-4), getdenv("momentum", 0.9));
}
file_lists.clear(); // no longer needed once the codec has been built
network_info(clstm.net);

double test_error = 9999.0;
@@ -135,12 +158,16 @@
Trigger report_trigger(getienv("report_every", 100), ntrain, start);
Trigger display_trigger(getienv("display_every", 0), ntrain, start);

// running totals for the training error: edit distance and ground-truth length
double train_errors = 0.0;
double train_count = 0.0;
for (int trial = start; trial < ntrain; trial++) {
int sample = lrand48() % trainingset.size();
Tensor2 raw;
wstring gt;
trainingset.readSample(raw, gt, sample);
wstring pred = clstm.train(raw(), gt);
train_count += gt.size();
train_errors += levenshtein(pred, gt);

if (report_trigger(trial)) {
print(trial);
@@ -168,6 +195,15 @@ int main1(int argc, char **argv) {
double count = tse.second;
test_error = errors / count;
print("ERROR", trial, test_error, " ", errors, count);
double train_error;
if (train_count > 0)
train_error = train_errors / train_count; // edit distance per ground-truth character
else
train_error = 9999.0;
print("Train ERROR: ", train_error);
train_count = 0.0;
train_errors = 0.0;

if (test_error < best_error) {
best_error = test_error;
string fname = save_name + ".clstm";
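The running training error reported alongside the test error uses the same definition: total Levenshtein distance divided by total ground-truth length over the reporting interval. A self-contained sketch of that calculation; edit_distance() here is a stand-in for the project's own levenshtein() helper:

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in edit distance; the project uses its own levenshtein() helper.
static size_t edit_distance(const std::wstring &a, const std::wstring &b) {
  std::vector<size_t> prev(b.size() + 1), cur(b.size() + 1);
  for (size_t j = 0; j <= b.size(); j++) prev[j] = j;
  for (size_t i = 1; i <= a.size(); i++) {
    cur[0] = i;
    for (size_t j = 1; j <= b.size(); j++)
      cur[j] = std::min({prev[j] + 1, cur[j - 1] + 1,
                         prev[j - 1] + (a[i - 1] == b[j - 1] ? 0 : 1)});
    std::swap(prev, cur);
  }
  return prev[b.size()];
}

int main() {
  // Example interval: two (prediction, ground truth) pairs.
  std::vector<std::pair<std::wstring, std::wstring>> samples = {
      {L"he11o", L"hello"}, {L"world", L"world"}};
  double errors = 0.0, count = 0.0;
  for (auto &s : samples) {
    errors += edit_distance(s.first, s.second);
    count += s.second.size();
  }
  double error_rate = count > 0 ? errors / count : 9999.0;
  std::cout << "train error " << error_rate << "\n";  // 2 edits / 10 chars = 0.2
  return 0;
}
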
2 changes: 1 addition & 1 deletion extras.h
@@ -31,7 +31,7 @@ using std::min;
// text line normalization

struct INormalizer {
int target_height = 48;
int target_height; // no default (was 48); must be set explicitly
float smooth2d = 1.0;
float smooth1d = 0.3;
float range = 4.0;