-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
206 lines (153 loc) · 5.68 KB
/
train.py
File metadata and controls
206 lines (153 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# -*- coding: utf-8 -*-
import os
import random
import logging
import numpy as np
from dotenv import load_dotenv
import pickle
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Dense, Dropout, Bidirectional, LSTM
from utils.convert2format import convert
# Load .env before TensorFlow enumerates devices so settings such as
# CUDA_VISIBLE_DEVICES from the .env file take effect.
load_dotenv()
from tensorflow.python.client import device_lib

random.seed(42)
# numpy keeps its own RNG (it initialises the embedding matrix later on);
# seed it too so runs are reproducible.
np.random.seed(42)
logging.basicConfig(level=logging.INFO,
                    format='TRAIN - %(asctime)s :: %(levelname)s :: %(message)s')
# Report available devices through the configured logger rather than print().
logging.info('Local devices: {}'.format(device_lib.list_local_devices()))


def _require_env(name):
    """Return environment variable *name*; raise RuntimeError if it is unset.

    Failing here, at start-up, is much cheaper than discovering a missing
    MODEL_DIR after training has finished.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError('Missing required environment variable: {}'.format(name))
    return value


DATA_PATH = _require_env('DATA_PATH')
CONVERT_PATH = _require_env('CONVERT_PATH')
DATASET_FILE = _require_env('DATASET_FILE')
QUERY_FILE = _require_env('QUERY_FILE')
LABEL_FILE = _require_env('LABEL_FILE')
GLOVE_DIR = _require_env('GLOVE_DIR')
EMBEDDING_DIM = int(_require_env('EMBEDDING_DIM'))
MAX_SEQ_LEN = int(_require_env('MAX_SEQ_LEN'))
MODEL_DIR = _require_env('MODEL_DIR')

convert_path = os.path.join(DATA_PATH, CONVERT_PATH)
# makedirs(exist_ok=True) also creates missing parents and avoids the
# exists()/mkdir() check-then-act race.
os.makedirs(convert_path, exist_ok=True)
dir_dataset = os.path.join(DATA_PATH, DATASET_FILE)
dir2save_query = os.path.join(convert_path, QUERY_FILE)
dir2save_label = os.path.join(convert_path, LABEL_FILE)
logging.info('Data set dir: {}'.format(dir_dataset))
logging.info('Query dir: {}'.format(dir2save_query))
logging.info('Label dir: {}'.format(dir2save_label))
logging.info('Glove dir: {} / Embedding Dim: {}'.format(GLOVE_DIR, EMBEDDING_DIM))
# Convert format: build the pickled query/label files from the raw dataset
# unless BOTH already exist (previously a missing label file was not noticed).
if not (os.path.exists(dir2save_query) and os.path.exists(dir2save_label)):
    convert(dir_dataset, dir2save_query, dir2save_label)
logging.info('Loading query and labels files')
# Load files.
# NOTE(review): pickle.load can execute arbitrary code from a crafted file.
# These paths presumably hold convert()'s own output; never point them at
# untrusted data.
with open(dir2save_query, "rb") as fp:
    sentences = pickle.load(fp)
with open(dir2save_label, "rb") as fp:
    labels = pickle.load(fp)
# Alias (not a copy) of the raw labels; not referenced later in this script.
labels_original = labels
if len(sentences) != len(labels):
    raise ValueError('Found {} sentences but {} label sequences'.format(
        len(sentences), len(labels)))
if not sentences:
    raise ValueError('No sentences loaded from {}'.format(dir2save_query))
logging.info('Data loaded')
# Log one example: index 10 when the dataset is large enough, else the last one.
sample_idx = min(10, len(sentences) - 1)
logging.info('Sentence: {}'.format(sentences[sample_idx]))
logging.info('Label: {}'.format(labels[sample_idx]))
# Join sentences: each sentence is an iterable of token strings; the Keras
# Tokenizer works on whole strings, so tokens are joined with a single space.
sentences = [' '.join(sent) for sent in sentences]
# Set of all entities (tags). Sorted so the tag -> index mapping is identical
# on every run; iteration order of a set of strings changes with hash
# randomisation, and random.seed() does not control it.
# Assumes tags are mutually comparable (e.g. all str) so sorted() works.
entities = [y for x in labels for y in x]
tags = sorted(set(entities))
# Create dictionary for labels (tag -> contiguous index starting at 0).
labels2idx = {tag: tag_idx for tag_idx, tag in enumerate(tags)}
logging.info('\tLabels2index: {}'.format(labels2idx))
# Convert list of labels into index_labels. Every tag is a key of labels2idx
# by construction, so direct indexing cannot fail.
logging.info('Convert list of labels into index_labels:')
labels_idx = [[labels2idx[tag] for tag in label] for label in labels]
logging.info('\tlabels_idx: {}'.format(labels_idx[0]))
logging.info('\tlabels: {}'.format(labels[0]))
# Tokenizer.
# filters='' keeps punctuation: the default Keras filter set would delete
# tokens such as "." or "," and split tokens such as "U.S." in two, so the
# word sequences would no longer line up one-to-one with the per-token labels.
# Splitting only on ' ' matches the ' '.join used to build the sentences.
tokenizer = Tokenizer(num_words=20000, split=' ', oov_token='UNK', filters='')
tokenizer.fit_on_texts(sentences)
# Words whose index is >= num_words are emitted as the 'UNK' index.
sequences = tokenizer.texts_to_sequences(sentences)
# words2idx is the tokenizer's own word_index (same dict object, not a copy);
# it already contains the 'UNK' oov_token entry.
logging.info('Creating words2index')
words2idx = tokenizer.word_index
n_classes = len(labels2idx)
n_vocab = len(words2idx)
logging.info('Number of labels: {}'.format(n_classes))
logging.info('Number of words: {}'.format(n_vocab))
# Load and prepare embedding
logging.info('Loading Glove...')
# word -> float32 vector of length EMBEDDING_DIM
embeddings_index = {}
# GloVe words in file order
words_glove = []
n_skipped = 0
# Context manager guarantees the file is closed even if parsing fails.
with open(GLOVE_DIR, encoding='utf-8') as f:
    for line in f:
        values = line.split()
        # Skip blank lines, header lines, tokens that contain whitespace and
        # vectors of another size: any of them would otherwise crash here or
        # break the embedding-matrix assignment later on.
        if len(values) != EMBEDDING_DIM + 1:
            n_skipped += 1
            continue
        word = values[0]
        words_glove.append(word)
        embeddings_index[word] = np.asarray(values[1:], dtype='float32')
if n_skipped:
    logging.warning('Skipped {} GloVe lines without a {}-dim vector'.format(
        n_skipped, EMBEDDING_DIM))
if not embeddings_index:
    raise ValueError('No {}-dimensional vectors found in {}'.format(
        EMBEDDING_DIM, GLOVE_DIR))

# Extend the vocabulary with GloVe words unseen in training. words2idx is the
# tokenizer's word_index, so the pickled tokenizer sees these words too.
# Keys are lower-cased to match the Tokenizer's (default) lower-casing. The
# membership test uses the same key that is inserted, so existing words keep
# their index. New words get fresh, contiguous indices after the current
# maximum; this keeps max index == len(words2idx), which the embedding-matrix
# size relies on.
next_idx = max(words2idx.values(), default=0) + 1
for tok in words_glove:
    key = tok.lower()
    if key not in words2idx:
        words2idx[key] = next_idx
        next_idx += 1
logging.info('len: {}'.format(len(words2idx)))
logging.info('Creating Embedding Matrix...')
# Row k holds the vector for word index k. Tokenizer indices start at 1, so
# row 0 (the padding index) is never assigned. Rows start as uniform [0, 1)
# random values, NOT zeros, so words without a GloVe vector keep a random
# initialisation.
embedding_matrix = np.random.random((len(words2idx) + 1, EMBEDDING_DIM))
# Despite its name, `empty` collects the indices of rows that DID receive a
# pretrained GloVe vector.
empty = []
for word, word_idx in words2idx.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[word_idx] = embedding_vector
        empty.append(word_idx)
logging.info('Embedding Matrix created')
logging.info('GloVe coverage: {} of {} words'.format(len(empty), len(words2idx)))
# Token and tag sequences must have equal lengths, or they drift apart once
# padded. Warn instead of failing so the cause is visible in the log.
n_mismatched = sum(len(seq) != len(lab) for seq, lab in zip(sequences, labels_idx))
if n_mismatched:
    logging.warning('{} sentences have a different number of tokens than labels; '
                    'words and tags will be misaligned'.format(n_mismatched))
# Pad Sequences. Both pad_sequences calls use the same (default) padding and
# truncation settings, so token and tag positions stay aligned.
data_train = pad_sequences(sequences, maxlen=MAX_SEQ_LEN)
# One-hot encode labels: row j of the identity matrix is the one-hot vector for
# class j, so fancy-indexing with a sentence's label indices gives one row per
# token. The identity matrix is built once instead of once per sentence.
one_hot = np.eye(n_classes)
labels_train = [one_hot[items] for items in labels_idx]
# Apply pad sequences to each labels (padded rows are all-zero vectors).
labels_train = pad_sequences(labels_train, maxlen=MAX_SEQ_LEN)
x_train = data_train
y_train = labels_train
logging.info('Shape of x_train: {}'.format(x_train.shape))
logging.info('Shape of y_train: {}'.format(y_train.shape))
# Define our model
def model(lstm_units=300, dropout_rate=0.25, trainable_embeddings=True):
    """Build the bidirectional-LSTM sequence tagger.

    Reads the module-level ``words2idx``, ``EMBEDDING_DIM``,
    ``embedding_matrix`` and ``n_classes``. The defaults reproduce the
    original hard-coded architecture.

    Args:
        lstm_units: Hidden units of the LSTM in each direction.
        dropout_rate: Dropout rate applied to the embedding output.
        trainable_embeddings: Whether the embedding layer, initialised from
            ``embedding_matrix``, is updated during training.

    Returns:
        An uncompiled ``Sequential`` model that outputs a softmax over
        ``n_classes`` at every timestep.
    """
    net = Sequential()
    # mask_zero=True: index 0 is the padding value, so padded timesteps are
    # masked for the layers downstream.
    net.add(Embedding(len(words2idx) + 1, EMBEDDING_DIM,
                      weights=[embedding_matrix],
                      mask_zero=True,
                      trainable=trainable_embeddings))
    net.add(Dropout(dropout_rate))
    # return_sequences=True keeps one output per token (sequence labelling).
    net.add(Bidirectional(LSTM(lstm_units, return_sequences=True)))
    # Dense on a 3-D input is applied independently at every timestep.
    net.add(Dense(n_classes, activation='softmax'))
    return net
logging.info('Compiling model')
# Rebinds the name `model` from the builder function to the built model; the
# code below (and the save step) uses the built model.
model = model()
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['acc'])
model.summary()
logging.info('Training model')
# `nb_epoch` is the Keras 1 spelling; tf.keras' Model.fit only accepts `epochs`
# and raises on the old keyword. validation_split holds out the LAST 10% of
# the samples (taken before shuffling).
hist = model.fit(x_train, y_train,
                 validation_split=0.1,
                 epochs=20,
                 batch_size=64)
# Save model plus the preprocessing artifacts needed to reproduce the inputs.
logging.info('Save model in: {}'.format(MODEL_DIR))
# makedirs also creates missing parent directories; a plain mkdir would crash
# here, after training has already completed.
os.makedirs(MODEL_DIR, exist_ok=True)
model.save(os.path.join(MODEL_DIR, 'model.h5'))
# The tokenizer carries word_index, including the GloVe words added to it.
with open(os.path.join(MODEL_DIR, 'tokenizer.pkl'), 'wb') as f:
    pickle.dump(tokenizer, f, protocol=pickle.HIGHEST_PROTOCOL)
with open(os.path.join(MODEL_DIR, 'labels2idx.pkl'), 'wb') as f:
    pickle.dump(labels2idx, f, protocol=pickle.HIGHEST_PROTOCOL)
logging.info('Model saved')