import os
import logging
import pickle

# Run prediction on CPU; CUDA_VISIBLE_DEVICES must be set before
# TensorFlow is imported for it to take effect
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from nltk import word_tokenize
from dotenv import load_dotenv

load_dotenv()

# Flask creates new threads, each of which would get its own TensorFlow
# session; keep a reference to the default graph so predictions run
# against the graph the model was loaded into
graph = tf.get_default_graph()
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s :: %(levelname)s :: %(message)s')
MODEL_DIR = os.environ.get('MODEL_DIR')
logging.info('Loading model and tokenizer from: {}'.format(MODEL_DIR))
model = load_model(os.path.join(MODEL_DIR, 'model.h5'))
tokenizer = pickle.load(open(os.path.join(MODEL_DIR, 'tokenizer.pkl'), "rb"))
labels2idx = pickle.load(open(os.path.join(MODEL_DIR, 'labels2idx.pkl'), "rb"))
words2idx = tokenizer.word_index
idx2words = {words2idx[k]: k for k in words2idx}
idx2labels = {labels2idx[k]: k for k in labels2idx}
logging.info('Model and tokenizer loaded correctly')
def prediction(query):
    """
    Extract the entities from a query.
    :param query: User sentence
    :return: Sentence with the entities annotated
    """
    global graph
    with graph.as_default():
        # Tokenize the sentence
        tok_ls = word_tokenize(query.lower())
        # Convert tokens to indices, falling back to the UNK index
        tok_idx = []
        for el in tok_ls:
            if el in words2idx:
                tok_idx.append(words2idx[el])
            else:
                tok_idx.append(words2idx['UNK'])
        # Reshape to a batch of one, matching the training input shape
        reshape_tok_ls = np.array(tok_idx)[np.newaxis, :]
        # Prediction
        pred = model.predict(reshape_tok_ls)
        # Take the most probable label for each token
        pred_max = np.argmax(pred, -1)[0]
        # Decode the predicted label indices
        pred_decode = []
        for el in pred_max:
            pred_decode.append(idx2labels.get(el))
        logging.info('Prediction decode: {}'.format(pred_decode))

        # Keep only the tokens that received an entity label
        labels_decode = []
        tokens_decode = []
        for el1, el2 in zip(pred_decode, tok_ls):
            if el1 != 'O':
                labels_decode.append(el1)
                tokens_decode.append(el2)

        # Pad with two 'O' entries so the look-ahead below stays in bounds
        ext = ['O', 'O']
        labels_decode = labels_decode + ext
        tokens_decode = tokens_decode + ext
        for i in range(len(labels_decode)):
            if labels_decode[i] != 'O':
                item = labels_decode[i]
                next_item = labels_decode[i + 1]
                next_next_item = labels_decode[i + 2]
                if item[:2] != 'I-' and next_item[:2] != 'I-':
                    logging.info('{} : {}'.format(labels_decode[i], tokens_decode[i]))
                if item[:2] != 'I-' and next_item[:2] == 'I-' and next_next_item[:2] != 'I-':
                    logging.info('B+I {}: {} {}'.format(labels_decode[i], tokens_decode[i], tokens_decode[i + 1]))
                if item[:2] == 'I-' and next_item[:2] == 'I-' and next_next_item[:2] != 'I-':
                    logging.info('B+I {}: {} {} {}'.format(labels_decode[i - 1], tokens_decode[i - 1],
                                                           tokens_decode[i], tokens_decode[i + 1]))
        # Combine tokens and labels into the annotated result sentence
        result_sent = []
        tokens = tok_ls + ext
        labels = pred_decode + ext
        for items in range(len(labels)):
            if labels[items] != 'O':
                before_item = labels[items - 1]
                item = labels[items]
                next_item = labels[items + 1]
                next_next_item = labels[items + 2]
                # Single-token entity
                if item[:2] != 'I-' and next_item[:2] != 'I-':
                    a = ' '.join(['[', tokens[items], item, ']'])
                    result_sent.append(a)
                # Three-token entity, emitted at the middle 'I-' tag
                if item[:2] == 'I-' and next_item[:2] == 'I-' and next_next_item[:2] != 'I-':
                    a = ' '.join(['[', tokens[items - 1], tokens[items], tokens[items + 1], before_item, ']'])
                    result_sent.append(a)
                # Two-token entity ('B-' followed by a single 'I-')
                if item[:2] != 'I-' and next_item[:2] == 'I-' and next_next_item[:2] != 'I-':
                    a = ' '.join(['[', tokens[items], tokens[items + 1], item, ']'])
                    result_sent.append(a)
            else:
                result_sent.append(tokens[items])
        # Drop the two padding tokens appended above before joining
        return ' '.join(result_sent[:-2])
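
# Minimal usage sketch: run this file directly to try prediction() on one
# query. The sample sentence is an illustrative placeholder, not from the
# original module; real output depends on the trained model, tokenizer, and
# label set found under MODEL_DIR.
if __name__ == '__main__':
    sample_query = 'book a flight from london to paris tomorrow'  # hypothetical example
    logging.info('Annotated: {}'.format(prediction(sample_query)))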