-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathANNEmbedding.py
More file actions
119 lines (90 loc) · 4.11 KB
/
ANNEmbedding.py
File metadata and controls
119 lines (90 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import numpy as np
from numpy.linalg import norm
from TrainModel import w2v_my_model, w2v_hazm_model
from VectorSpace import vectors_without_stop_words, calculate_final_weight
from Preprocess import preprocess_query
from Score import sort_doc_score_dict
DOCS_EMBEDDING_FILE_MY_MODEL = "./docs_embedding.json"
DOCS_EMBEDDING_FILE_HAZM_MODEL = "./docs_embedding_hazm_model.json"
def final_search_result_in_embedding(query_vector, top_k, w2v_model_type):
    """
    Return the best-matching documents for a query embedding.

    :param query_vector: embedding vector of the query
    :param top_k: number of top-ranked documents to return
    :param w2v_model_type: which model's doc embeddings to rank against
        ("hazm model" or "my model")
    :return: list of (doc_id, score) pairs, best match first
    """
    ranked = calculate_document_score_with_query_in_embedding(query_vector, w2v_model_type)
    # The dict is already sorted best-first; keep only the first top_k pairs.
    return list(ranked.items())[:top_k]
def calculate_document_score_with_query_in_embedding(query_vector, w2v_model_type):
    """
    Score every document embedding against the query embedding.

    :param query_vector: embedding vector of the query
    :param w2v_model_type: "hazm model" or "my model" — selects which
        precomputed document-embedding list to score against
    :return: dict mapping doc_id -> similarity score, sorted best first
    :raises Exception: if w2v_model_type is not a recognized model name
    """
    if w2v_model_type == "hazm model":
        docs_embedding = docs_embedding_with_hazm_model
    elif w2v_model_type == "my model":
        # Bug fix: the original assigned the undefined name `doc` here.
        docs_embedding = docs_embedding_with_my_model
    else:
        # Bug fix: without this branch, docs_embedding was unbound for an
        # invalid model type; raise like create_query_vector_embedding does.
        raise Exception("Please enter valid W2V model! (hazm model or my model)")
    doc_id_score_dict = {}
    # enumerate gives the doc id directly; the original's list.index() was
    # O(n) per document and wrong when two documents share an embedding.
    for doc_id, doc_vector in enumerate(docs_embedding):
        doc_id_score_dict[doc_id] = calculate_similarity(query_vector, doc_vector)
    sorted_document_id_score = sort_doc_score_dict(doc_id_score_dict)
    return sorted_document_id_score
def create_docs_embedding(w2v_model):
    """
    Build one 300-dim embedding per document as the weighted average of the
    word2vec vectors of its (non-stop-word) tokens.

    :param w2v_model: word2vec model exposing ``wv`` indexed by token
    :return: list of 300-dim numpy vectors, one per document, in the
        iteration order of vectors_without_stop_words
    """
    docs_embedding = []
    for doc, doc_info in vectors_without_stop_words.items():
        doc_vector = np.zeros(300)
        weights_sum = 0
        for token, weight in doc_info.items():
            # Tokens missing from the model vocabulary are skipped.
            if token in w2v_model.wv:
                doc_vector += w2v_model.wv[token] * weight
                weights_sum += weight
        # Bug fix: guard against ZeroDivisionError when no token of the
        # document is in-vocabulary; such a document keeps the zero vector.
        if weights_sum:
            doc_vector /= weights_sum
        docs_embedding.append(doc_vector)
    return docs_embedding
def save_docs_embedding(docs_embedding, file_name):
    """
    Persist document embeddings to a JSON file (pretty-printed).

    :param docs_embedding: iterable of per-document vectors (numpy arrays
        or plain sequences of numbers)
    :param file_name: path of the JSON file to (over)write
    """
    # Convert each vector to a plain list of floats so json can serialize
    # it regardless of the element type (np.float64, int, ...).
    # The per-document debug print from the original is removed.
    docs_embedding_list = [[float(value) for value in doc] for doc in docs_embedding]
    with open(file_name, 'w', encoding='utf-8') as fp:
        json.dump(docs_embedding_list, fp, sort_keys=True, indent=4, ensure_ascii=False)
def load_docs_embedding(file_name):
    """
    Read previously saved document embeddings back from a JSON file.

    :param file_name: path of the JSON file written by save_docs_embedding
    :return: list of per-document vectors as plain Python lists of floats
    """
    with open(file_name, 'r', encoding='utf-8') as fp:
        return json.load(fp)
def calculate_similarity(doc1, doc2):
    """
    Cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    :param doc1: first vector
    :param doc2: second vector
    :return: similarity in [0, 1]; 1.0 means same direction, 0.0 opposite
    """
    cosine = np.dot(doc1, doc2) / (norm(doc1) * norm(doc2))
    # Shift-and-scale so that downstream scores are always non-negative.
    return (cosine + 1) / 2
def create_query_vector_embedding(query_string, w2v_model_type):
    """
    Embed a query string as the weighted average of its tokens' word2vec
    vectors, using the same tf-based weighting as the document embeddings.

    :param query_string: raw query text
    :param w2v_model_type: "hazm model" or "my model"
    :return: 300-dim numpy vector (all zeros if no token is in-vocabulary)
    :raises Exception: if w2v_model_type is not a recognized model name
    """
    if w2v_model_type == "hazm model":
        w2v_model = w2v_hazm_model
    elif w2v_model_type == "my model":
        w2v_model = w2v_my_model
    else:
        raise Exception("Please enter valid W2V model! (hazm model or my model)")
    query_tokens_dict = preprocess_query(query_string, "positional", True, True)
    # term frequency - raw (tf-raw): number of positions per term
    query_term_freq_dict = {}
    for term, positional_index in query_tokens_dict.items():
        query_term_freq_dict[term] = len(positional_index)
    # final term weight (same weighting scheme as the document vectors)
    query_final_term_frequency = calculate_final_weight("query", query_term_freq_dict)
    query_vector_embedding = np.zeros(300)
    weight_sum = 0
    for token, weight in query_final_term_frequency.items():
        # Bug fix: skip out-of-vocabulary tokens instead of raising
        # KeyError, mirroring create_docs_embedding.
        if token in w2v_model.wv:
            query_vector_embedding += w2v_model.wv[token] * weight
            weight_sum += weight
    # Bug fix: avoid ZeroDivisionError when every token is out-of-vocabulary.
    if weight_sum:
        query_vector_embedding /= weight_sum
    return query_vector_embedding
# One-time build & save (uncomment to regenerate the embedding files):
# docs_embedding_with_my_model = create_docs_embedding(w2v_my_model)
# save_docs_embedding(docs_embedding_with_my_model, DOCS_EMBEDDING_FILE_MY_MODEL)
# Bug fix: the original referenced the undefined name DOCS_EMBEDDING_FILE,
# which raised NameError at import time.
docs_embedding_with_my_model = load_docs_embedding(DOCS_EMBEDDING_FILE_MY_MODEL)
# docs_embedding_with_hazm_model = create_docs_embedding(w2v_hazm_model)
# save_docs_embedding(docs_embedding_with_hazm_model, DOCS_EMBEDDING_FILE_HAZM_MODEL)
docs_embedding_with_hazm_model = load_docs_embedding(DOCS_EMBEDDING_FILE_HAZM_MODEL)
if __name__ == "__main__":
    # Demo query. Bug fix: the original passed "model", which
    # create_query_vector_embedding rejects — valid values are
    # "my model" or "hazm model".
    model_type = "my model"
    query_vector = create_query_vector_embedding("دانشگاه صنعتی امیرکبیر", model_type)
    print(query_vector)
    # Bug fix: the model type must also be forwarded to the ranking call;
    # the original omitted the third argument (TypeError).
    print(final_search_result_in_embedding(query_vector, 10, model_type))