Skip to content

Commit

Permalink
Bugfix loading a model with char embeddings + update to keras 2.2.0 a…
Browse files Browse the repository at this point in the history
…nd TF 1.8.0
  • Loading branch information
nreimers committed Jun 27, 2018
1 parent 57f37c8 commit 8350368
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# BiLSTM-CNN-CRF Implementation for Sequence Tagging

This repository contains a BiLSTM-CRF implementation that is used for NLP Sequence Tagging (for example POS-tagging, Chunking, or Named Entity Recognition). The implementation is based on Keras 2.1.5 and can be run with Tensorflow 1.7.0 as backend. It was optimized for Python 3.5 / 3.6. It does **not work** with Python 2.7.
This repository contains a BiLSTM-CRF implementation that is used for NLP Sequence Tagging (for example POS-tagging, Chunking, or Named Entity Recognition). The implementation is based on Keras 2.2.0 and can be run with Tensorflow 1.8.0 as backend. It was optimized for Python 3.5 / 3.6. It does **not work** with Python 2.7.

The architecture is described in our papers:
- [Reporting Score Distributions Makes a Difference: Performance Study of LSTM-networks for Sequence Tagging](https://arxiv.org/abs/1707.09861)
Expand Down
24 changes: 13 additions & 11 deletions docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
absl-py==0.1.13
absl-py==0.2.2
astor==0.6.2
bleach==1.5.0
gast==0.2.0
grpcio==1.11.0
h5py==2.7.1
grpcio==1.12.1
h5py==2.8.0
html5lib==0.9999999
Keras==2.1.5
Keras==2.2.0
Keras-Applications==1.0.2
Keras-Preprocessing==1.0.1
Markdown==2.6.11
nltk==3.2.5
numpy==1.14.2
protobuf==3.5.2.post1
PyYAML==3.12
scipy==1.0.1
numpy==1.14.5
protobuf==3.6.0
PyYAML==4.1
scipy==1.1.0
six==1.11.0
tensorboard==1.7.0
tensorflow==1.7.0
tensorboard==1.8.0
tensorflow==1.8.0
termcolor==1.1.0
Werkzeug==0.14.1
Werkzeug==0.14.1
4 changes: 2 additions & 2 deletions docs/Pretrained_Models.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ Trained on [CoNLL 2003](http://www.cnts.ua.ac.be/conll2003/ner/) and [GermEval 2
| Language | Development (F1) | Test (F1) |
|----------|:-----------:|:----:|
|[English (CoNLL 2003)](https://public.ukp.informatik.tu-darmstadt.de/reimers/2017_SequenceTaggingModels/v2.1.5/EN_NER.h5) | 93.87% | 90.22% |
|[German (CoNLL 2003)](https://public.ukp.informatik.tu-darmstadt.de/reimers/2017_SequenceTaggingModels/v2.1.5/DE_NER_CoNLL.h5) | 80.12% | 77.52% |
|[German (GermEval 2014)](https://public.ukp.informatik.tu-darmstadt.de/reimers/2017_SequenceTaggingModels/v2.1.5/DE_NER_GermEval.h5) | 80.74% | 79.96% |
|[German (CoNLL 2003)](https://public.ukp.informatik.tu-darmstadt.de/reimers/2017_SequenceTaggingModels/v2.1.5/DE_NER_CoNLL.h5) | 81.15% | 77.70% |
|[German (GermEval 2014)](https://public.ukp.informatik.tu-darmstadt.de/reimers/2017_SequenceTaggingModels/v2.1.5/DE_NER_GermEval.h5) | 80.93% | 78.94% |


## Entities
Expand Down
59 changes: 33 additions & 26 deletions neuralnets/BiLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ def setDataset(self, datasets, data):

self.casing2Idx = self.mappings['casing']

if self.params['charEmbeddings'] not in [None, "None", "none", False, "False", "false"]:
logging.info("Pad words to uniform length for characters embeddings")
all_sentences = []
for dataset in self.data.values():
for data in [dataset['trainMatrix'], dataset['devMatrix'], dataset['testMatrix']]:
for sentence in data:
all_sentences.append(sentence)

self.padCharacters(all_sentences)
logging.info("Words padded to %d characters" % (self.maxCharLen))


def buildModel(self):
self.models = {}
Expand All @@ -109,43 +120,39 @@ def buildModel(self):

# :: Character Embeddings ::
if self.params['charEmbeddings'] not in [None, "None", "none", False, "False", "false"]:
logging.info("Pad words to uniform length for characters embeddings")
all_sentences = []
for dataset in self.data.values():
for data in [dataset['trainMatrix'], dataset['devMatrix'], dataset['testMatrix']]:
for sentence in data:
all_sentences.append(sentence)

self.padCharacters(all_sentences)
logging.info("Words padded to %d characters" % (self.maxCharLen))

charset = self.mappings['characters']
charEmbeddingsSize = self.params['charEmbeddingsSize']
maxCharLen = self.maxCharLen
charEmbeddings= []
charEmbeddings = []
for _ in charset:
limit = math.sqrt(3.0/charEmbeddingsSize)
vector = np.random.uniform(-limit, limit, charEmbeddingsSize)
limit = math.sqrt(3.0 / charEmbeddingsSize)
vector = np.random.uniform(-limit, limit, charEmbeddingsSize)
charEmbeddings.append(vector)
charEmbeddings[0] = np.zeros(charEmbeddingsSize) #Zero padding

charEmbeddings[0] = np.zeros(charEmbeddingsSize) # Zero padding
charEmbeddings = np.asarray(charEmbeddings)

chars_input = Input(shape=(None,maxCharLen), dtype='int32', name='char_input')
chars = TimeDistributed(Embedding(input_dim=charEmbeddings.shape[0], output_dim=charEmbeddings.shape[1], weights=[charEmbeddings], trainable=True, mask_zero=True), name='char_emd')(chars_input)

if self.params['charEmbeddings'].lower() == 'lstm': #Use LSTM for char embeddings from Lample et al., 2016

chars_input = Input(shape=(None, maxCharLen), dtype='int32', name='char_input')
mask_zero = (self.params['charEmbeddings'].lower()=='lstm') #Zero mask only works with LSTM
chars = TimeDistributed(
Embedding(input_dim=charEmbeddings.shape[0], output_dim=charEmbeddings.shape[1],
weights=[charEmbeddings],
trainable=True, mask_zero=mask_zero), name='char_emd')(chars_input)

if self.params['charEmbeddings'].lower()=='lstm': # Use LSTM for char embeddings from Lample et al., 2016
charLSTMSize = self.params['charLSTMSize']
chars = TimeDistributed(Bidirectional(LSTM(charLSTMSize, return_sequences=False)), name="char_lstm")(chars)
else: #Use CNNs for character embeddings from Ma and Hovy, 2016
chars = TimeDistributed(Bidirectional(LSTM(charLSTMSize, return_sequences=False)), name="char_lstm")(
chars)
else: # Use CNNs for character embeddings from Ma and Hovy, 2016
charFilterSize = self.params['charFilterSize']
charFilterLength = self.params['charFilterLength']
chars = TimeDistributed(Conv1D(charFilterSize, charFilterLength, padding='same'), name="char_cnn")(chars)
chars = TimeDistributed(Conv1D(charFilterSize, charFilterLength, padding='same'), name="char_cnn")(
chars)
chars = TimeDistributed(GlobalMaxPooling1D(), name="char_pooling")(chars)


self.params['featureNames'].append('characters')
mergeInputLayers.append(chars)
inputNodes.append(chars_input)
self.params['featureNames'].append('characters')

# :: Task Identifier ::
if self.params['useTaskIdentifier']:
Expand Down Expand Up @@ -242,7 +249,7 @@ def buildModel(self):
model = Model(inputs=inputNodes, outputs=[output])
model.compile(loss=lossFct, optimizer=opt)

model.summary(line_length=200)
model.summary(line_length=125)
#logging.info(model.get_config())
#logging.info("Optimizer: %s - %s" % (str(type(model.optimizer)), str(model.optimizer.get_config())))

Expand Down
22 changes: 12 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
absl-py==0.1.13
absl-py==0.2.2
astor==0.6.2
bleach==1.5.0
gast==0.2.0
grpcio==1.11.0
h5py==2.7.1
grpcio==1.12.1
h5py==2.8.0
html5lib==0.9999999
Keras==2.1.5
Keras==2.2.0
Keras-Applications==1.0.2
Keras-Preprocessing==1.0.1
Markdown==2.6.11
nltk==3.2.5
numpy==1.14.2
protobuf==3.5.2.post1
PyYAML==3.12
scipy==1.0.1
numpy==1.14.5
protobuf==3.6.0
PyYAML==4.1
scipy==1.1.0
six==1.11.0
tensorboard==1.7.0
tensorflow==1.7.0
tensorboard==1.8.0
tensorflow==1.8.0
termcolor==1.1.0
Werkzeug==0.14.1

0 comments on commit 8350368

Please sign in to comment.