Skip to content

Commit ed76a20

Browse files
Merge pull request #52 from Living-with-machines/feature/51-pypi
Feature/51 pypi
2 parents a589cc9 + 2d3b930 commit ed76a20

16 files changed

+2795
-317
lines changed

.gitignore

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
__pycache__
22
.ipynb_checkpoints
3-
log.txt
4-
models
5-
pred_results.txt
63
.DS_Store
74
default.profraw
8-
log_test001.png
5+
DeezyMatch.egg-info
6+
build
7+
dist

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DeezyMatch/CONTRIBUTORS.txt

DeezyMatch/CONTRIBUTORS.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Coll Ardanuy, Mariona
2+
Hosseini, Kasra
3+
Nanni, Federico

DeezyMatch.py renamed to DeezyMatch/DeezyMatch.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: UTF-8 -*-
33

44
"""
5-
DeezyMatch main code: select the relevant module (train, finetune, inference, combine_vecs, candidate_finder)
5+
DeezyMatch main code: select the relevant module (train, finetune, inference, combine_vecs, candidate_ranker)
66
based on the inputs.
77
"""
88

@@ -12,18 +12,18 @@
1212
import shutil
1313
import sys
1414

15-
from candidateFinder import candidate_finder
16-
from candidateFinder import main as candidate_finder_main
17-
from combineVecs import combine_vecs
18-
from combineVecs import main as combine_vecs_main
19-
from data_processing import csv_split_tokenize
20-
from rnn_networks import gru_lstm_network, fine_tuning
21-
from rnn_networks import inference as rnn_inference
22-
from utils import deezy_mode_detector
23-
from utils import read_inputs_command, read_inference_command, read_input_file
24-
from utils import cprint, bc, log_message
15+
from .candidateRanker import candidate_ranker
16+
from .candidateRanker import main as candidate_ranker_main
17+
from .combineVecs import combine_vecs
18+
from .combineVecs import main as combine_vecs_main
19+
from .data_processing import csv_split_tokenize
20+
from .rnn_networks import gru_lstm_network, fine_tuning
21+
from .rnn_networks import inference as rnn_inference
22+
from .utils import deezy_mode_detector
23+
from .utils import read_inputs_command, read_inference_command, read_input_file
24+
from .utils import cprint, bc, log_message
2525
# --- set seed for reproducibility
26-
from utils import set_seed_everywhere
26+
from .utils import set_seed_everywhere
2727
set_seed_everywhere(1364)
2828

2929
# ------------------- train --------------------
@@ -282,8 +282,8 @@ def main():
282282
elif dm_mode in ["combine_vecs"]:
283283
combine_vecs_main()
284284

285-
elif dm_mode in ["candidate_finder"]:
286-
candidate_finder_main()
285+
elif dm_mode in ["candidate_ranker"]:
286+
candidate_ranker_main()
287287

288288
if __name__ == '__main__':
289289
main()

DeezyMatch/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from DeezyMatch.DeezyMatch import train
2+
from DeezyMatch.DeezyMatch import finetune
3+
from DeezyMatch.DeezyMatch import inference
4+
from DeezyMatch.DeezyMatch import combine_vecs
5+
from DeezyMatch.DeezyMatch import candidate_ranker
Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@
1919
import torch
2020
from torch.utils.data import DataLoader
2121

22-
from data_processing import test_tokenize
23-
from rnn_networks import test_model
24-
from utils import read_input_file
25-
from utils import read_command_candidate_finder
22+
from .data_processing import test_tokenize
23+
from .rnn_networks import test_model
24+
from .utils import read_input_file
25+
from .utils import read_command_candidate_ranker
2626
# --- set seed for reproducibility
27-
from utils import set_seed_everywhere
27+
from .utils import set_seed_everywhere
2828
set_seed_everywhere(1364)
2929

3030
# skip future warnings for now XXX
3131
import warnings
3232
warnings.simplefilter(action='ignore', category=FutureWarning)
3333

34-
# ------------------- candidate_finder --------------------
35-
def candidate_finder(input_file_path="default", scenario=None, ranking_metric="faiss", selection_threshold=0.8,
34+
# ------------------- candidate_ranker --------------------
35+
def candidate_ranker(input_file_path="default", scenario=None, ranking_metric="faiss", selection_threshold=0.8,
3636
num_candidates=10, search_size=4, output_filename=None,
3737
pretrained_model_path=None, pretrained_vocab_path=None, number_test_rows=-1):
3838

@@ -191,10 +191,19 @@ def candidate_finder(input_file_path="default", scenario=None, ranking_metric="f
191191
sys.exit(f"[ERROR] ranking_metric: {ranking_metric} is not implemented. See the documentation.")
192192

193193
num_found_candidates += len(query_candidate_filtered_pd)
194-
print("ID: %s/%s -- Number of found candidates so far: %s, search span: 0, %s" % (iq, len(vecs_query), num_found_candidates, id_1_neigh))
194+
print("ID: %s/%s -- Number of found candidates so far: %s, searched: %s" % (iq+1, len(vecs_query), num_found_candidates, id_1_neigh))
195195

196196
if num_found_candidates > 0:
197197
collect_neigh_pd = collect_neigh_pd.append(query_candidate_filtered_pd)
198+
199+
if ranking_metric.lower() in ["faiss"]:
200+
# 1.01 is multiplied to avoid issues with float numbers and rounding erros
201+
if query_candidate_pd["faiss_dist"].max() > (selection_threshold*1.01):
202+
break
203+
elif ranking_metric.lower() in ["cosine"]:
204+
# 0.99 is multiplied to avoid issues with float numbers and rounding errors
205+
if query_candidate_pd["cosine_sim"].min() < (selection_threshold*0.99):
206+
break
198207

199208
# Go to the next zone
200209
if (num_found_candidates < num_candidates):
@@ -223,7 +232,7 @@ def candidate_finder(input_file_path="default", scenario=None, ranking_metric="f
223232
mydict_candid_id[row["s2"]] = row["s2_orig_ids"]
224233
one_row = {
225234
"id": orig_id_queries,
226-
"toponym": all_queries[0],
235+
"query": all_queries[0],
227236
"pred_score": [mydict_dl_match],
228237
"faiss_distance": [mydict_faiss_dist],
229238
"cosine_sim": [mydict_cosine_sim],
@@ -243,10 +252,10 @@ def main():
243252
# --- read args from the command line
244253
output_filename, selection_threshold, ranking_metric, search_size, num_candidates, \
245254
par_dir, input_file_path, number_test_rows, pretrained_model_path, pretrained_vocab_path = \
246-
read_command_candidate_finder()
255+
read_command_candidate_ranker()
247256

248257
# ---
249-
candidate_finder(input_file_path=input_file_path,
258+
candidate_ranker(input_file_path=input_file_path,
250259
scenario=par_dir,
251260
ranking_metric=ranking_metric,
252261
selection_threshold=selection_threshold,

combineVecs.py renamed to DeezyMatch/combineVecs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222

2323
start_time = time.time()
2424

25-
from utils import read_input_file
26-
from utils import sort_key
27-
from utils import read_command_combinevecs
25+
from .utils import read_input_file
26+
from .utils import sort_key
27+
from .utils import read_command_combinevecs
2828
# --- set seed for reproducibility
29-
from utils import set_seed_everywhere
29+
from .utils import set_seed_everywhere
3030
set_seed_everywhere(1364)
3131

3232
# ------------------- combine_vecs --------------------
Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import pickle
1212
from torch.utils.data import Dataset
1313

14-
from utils import cprint, bc
15-
from utils import string_split
16-
from utils import normalizeString
14+
from .utils import cprint, bc
15+
from .utils import string_split
16+
from .utils import normalizeString
1717
# --- set seed for reproducibility
18-
from utils import set_seed_everywhere
18+
from .utils import set_seed_everywhere
1919
set_seed_everywhere(1364)
2020

2121

@@ -70,7 +70,10 @@ def csv_split_tokenize(dataset_path, pretrained_vocab_path=None, n_train_example
7070
n_total = len(rows_one_label)
7171

7272
if n_train_examples:
73-
# number of positive examples
73+
# We have two sets of labels: True and False
74+
# Here, we divide the number of requested rows by two
75+
# This way 50% of the requested rows will be True and 50% will be False
76+
# Compare this with n_train = int(train_prop * n_total)
7477
n_pos = int(int(n_train_examples)/2)
7578
n_train = n_pos
7679
else:
Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
import numpy as np
3737
import sys
3838

39-
from data_processing import test_tokenize
40-
from utils import cprint, bc, log_message
41-
from utils import print_stats
42-
from utils import torch_summarize
43-
from utils import create_parent_dir
44-
from utils import eval_map
39+
from .data_processing import test_tokenize
40+
from .utils import cprint, bc, log_message
41+
from .utils import print_stats
42+
from .utils import torch_summarize
43+
from .utils import create_parent_dir
44+
from .utils import eval_map
4545
# --- set seed for reproducibility
46-
from utils import set_seed_everywhere
46+
from .utils import set_seed_everywhere
4747
set_seed_everywhere(1364)
4848

4949
# skip future warnings for now XXX
@@ -400,7 +400,7 @@ def test_model(model, test_dl, eval_mode='test', valid_desc=None,
400400
if eval_mode == 'valid':
401401
eval_desc = valid_desc
402402
elif eval_mode == 'test':
403-
eval_desc = "test"
403+
eval_desc = 'Epoch: 0/0; Test'
404404

405405
t_test.set_description(eval_mode)
406406

@@ -423,7 +423,9 @@ def test_model(model, test_dl, eval_mode='test', valid_desc=None,
423423
len2 = len2.numpy()
424424

425425
with torch.no_grad():
426-
pred = model(x1, len1, x2, len2, pooling_mode=pooling_mode, device=device, output_state_vectors=output_state_vectors, evaluation=evaluation)
426+
pred = model(x1, len1, x2, len2, pooling_mode=pooling_mode,
427+
device=device, output_state_vectors=output_state_vectors,
428+
evaluation=evaluation)
427429
if output_state_vectors:
428430
all_preds = []
429431
continue
@@ -845,7 +847,8 @@ def inference(model_path, dataset_path, train_vocab_path, input_file_path,
845847
output_preds=dl_inputs['inference']['output_preds'],
846848
output_preds_file=output_preds_file,
847849
csv_sep=dl_inputs['preprocessing']['csv_sep'],
848-
map_flag=dl_inputs['inference']['eval_map_metric']
850+
map_flag=dl_inputs['inference']['eval_map_metric'],
851+
model_path=os.path.dirname(os.path.abspath(model_path))
849852
)
850853

851854
print("--- %s seconds ---" % (time.time() - start_time))

utils.py renamed to DeezyMatch/utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ def deezy_mode_detector():
9898

9999
parser = ArgumentParser()
100100
parser.add_argument("--deezy_mode",
101-
help="DeezyMatch mode (options: train, finetune, inference, combine_vecs, candidate_finder)",
101+
help="DeezyMatch mode (options: train, finetune, inference, combine_vecs, candidate_ranker)",
102102
default="train",
103103
)
104104
dm_mode, unknown = parser.parse_known_args()
105105
dm_mode = dm_mode.deezy_mode.lower()
106-
if dm_mode not in ["train", "finetune", "inference", "combine_vecs", "candidate_finder"]:
107-
parser.exit(f"ERROR: implemeted modes are: train, finetune, inference, combine_vecs, candidate_finder (input: {dm_mode})")
106+
if dm_mode not in ["train", "finetune", "inference", "combine_vecs", "candidate_ranker"]:
107+
parser.exit(f"ERROR: implemeted modes are: train, finetune, inference, combine_vecs, candidate_ranker (input: {dm_mode})")
108108

109109
return dm_mode
110110

@@ -201,7 +201,7 @@ def read_inputs_command():
201201
parser.exit(f"ERROR: model {fine_tuning_model_path} not found!")
202202

203203
if os.path.exists(vocab_path) is False:
204-
parser.exit(f"ERROR: vocab {vocab} not found!")
204+
parser.exit(f"ERROR: vocab {vocab_path} not found!")
205205

206206
else:
207207
fine_tuning_model_name = os.path.split(fine_tuning_model)[-1]
@@ -296,13 +296,13 @@ def read_command_combinevecs():
296296
input_file_path = args.input_file_path
297297
return qc_mode, cq_sc, rnn_pass, combined_sc, input_file_path
298298

299-
# ------------------- read_command_candidate_finder --------------------
300-
def read_command_candidate_finder():
299+
# ------------------- read_command_candidate_ranker --------------------
300+
def read_command_candidate_ranker():
301301
parser = ArgumentParser()
302302

303303
parser.add_argument("--deezy_mode",
304304
help="DeezyMatch mode",
305-
default="candidate_finder"
305+
default="candidate_ranker"
306306
)
307307

308308
parser.add_argument("-t", "--threshold",
@@ -558,7 +558,9 @@ def log_plotter(path2log, dataset="DEFAULT"):
558558
train_arr = []
559559
valid_arr = []
560560
time_arr = []
561-
for one_line in log[3:]:
561+
for one_line in log[2:]:
562+
if one_line.lower().strip().startswith("python"):
563+
continue
562564
line_split = one_line.split()
563565
datetime_str = line_split[0]
564566
epoch = int(line_split[3].split("/")[0])
@@ -655,6 +657,12 @@ def log_plotter(path2log, dataset="DEFAULT"):
655657
plt.subplot(3, 2, 5)
656658
plt.title(f"Dataset: {dataset}\nTotal time: {total_time}s, Ave. Time / epoch: {total_time/(len(time_arr)-1):.3f}s", size=16)
657659
plt.plot(train_arr[1:, 0], diff_time, c="k", lw=2)
660+
661+
# If min_valid_arg is 0 (the first model has the lowest valid loss)
662+
# Increment min_valid_arg for Time as we use cumsum (lose one point in the plot)
663+
if min_valid_arg == 0:
664+
min_valid_arg += 1
665+
658666
if plot_valid:
659667
plt.axvline(valid_arr[min_valid_arg, 0], 0, 1, ls="--", c="k")
660668
plt.text(valid_arr[min_valid_arg, 0]*1.05, min(diff_time)*0.98,

0 commit comments

Comments
 (0)