@@ -19,19 +19,19 @@
 import torch
 from torch.utils.data import DataLoader
 
-from data_processing import test_tokenize
-from rnn_networks import test_model
-from utils import read_input_file
-from utils import read_command_candidate_finder
+from .data_processing import test_tokenize
+from .rnn_networks import test_model
+from .utils import read_input_file
+from .utils import read_command_candidate_ranker
 # --- set seed for reproducibility
-from utils import set_seed_everywhere
+from .utils import set_seed_everywhere
 set_seed_everywhere(1364)
 
 # skip future warnings for now XXX
 import warnings
 warnings.simplefilter(action='ignore', category=FutureWarning)
 
-# ------------------- candidate_finder --------------------
-def candidate_finder(input_file_path="default", scenario=None, ranking_metric="faiss", selection_threshold=0.8,
+# ------------------- candidate_ranker --------------------
+def candidate_ranker(input_file_path="default", scenario=None, ranking_metric="faiss", selection_threshold=0.8,
                      num_candidates=10, search_size=4, output_filename=None,
                      pretrained_model_path=None, pretrained_vocab_path=None, number_test_rows=-1):
 
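The imports switch from absolute to relative, so the module now has to live inside a package rather than run as a standalone script. A sketch of the implied layout and invocation follows; the package name "mypackage" is a placeholder, not taken from this commit:

# Package layout implied by the relative imports above (names of the sibling
# modules come from the diff; "mypackage" is a placeholder):
#
#   mypackage/
#   ├── __init__.py
#   ├── candidate_ranker.py   # this module: "from .utils import ..."
#   ├── data_processing.py
#   ├── rnn_networks.py
#   └── utils.py
#
# With relative imports, running "python candidate_ranker.py" directly raises
# "ImportError: attempted relative import with no known parent package"; the
# module has to be imported through the package or run as
# "python -m mypackage.candidate_ranker".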
@@ -191,10 +191,19 @@ def candidate_finder(input_file_path="default", scenario=None, ranking_metric="f
             sys.exit(f"[ERROR] ranking_metric: {ranking_metric} is not implemented. See the documentation.")
 
         num_found_candidates += len(query_candidate_filtered_pd)
-        print("ID: %s/%s -- Number of found candidates so far: %s, search span: 0, %s" % (iq, len(vecs_query), num_found_candidates, id_1_neigh))
+        print("ID: %s/%s -- Number of found candidates so far: %s, searched: %s" % (iq + 1, len(vecs_query), num_found_candidates, id_1_neigh))
 
         if num_found_candidates > 0:
             collect_neigh_pd = collect_neigh_pd.append(query_candidate_filtered_pd)
+
+        if ranking_metric.lower() in ["faiss"]:
+            # 1.01 is multiplied to avoid issues with float numbers and rounding errors
+            if query_candidate_pd["faiss_dist"].max() > (selection_threshold * 1.01):
+                break
+        elif ranking_metric.lower() in ["cosine"]:
+            # 0.99 is multiplied to avoid issues with float numbers and rounding errors
+            if query_candidate_pd["cosine_sim"].min() < (selection_threshold * 0.99):
+                break
 
         # Go to the next zone
         if (num_found_candidates < num_candidates):
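The added block stops widening the search once the current batch of candidates already crosses the selection threshold: faiss distances grow as candidates get worse, so the loop breaks when the maximum distance exceeds the threshold, while cosine similarities shrink as candidates get worse, so it breaks when the minimum similarity drops below it. A minimal standalone sketch of the same check, using toy values and the column names from the diff:

import pandas as pd

# Toy batch of candidates for one query (hypothetical values; column names
# follow the diff: "faiss_dist" for distance, "cosine_sim" for similarity).
query_candidate_pd = pd.DataFrame({
    "faiss_dist": [0.21, 0.47, 0.83],
    "cosine_sim": [0.95, 0.88, 0.74],
})

selection_threshold = 0.8

def should_stop(batch, ranking_metric, selection_threshold):
    """Return True when the batch already reaches candidates that fail the
    threshold, so widening the search further cannot add useful candidates."""
    if ranking_metric.lower() == "faiss":
        # Distances grow as candidates get worse; the 1.01 slack mirrors the
        # diff's guard against floating-point rounding at the boundary.
        return batch["faiss_dist"].max() > selection_threshold * 1.01
    elif ranking_metric.lower() == "cosine":
        # Similarities shrink as candidates get worse; 0.99 plays the same role.
        return batch["cosine_sim"].min() < selection_threshold * 0.99
    return False

print(should_stop(query_candidate_pd, "faiss", selection_threshold))   # True (0.83 > 0.808)
print(should_stop(query_candidate_pd, "cosine", selection_threshold))  # True (0.74 < 0.792)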
@@ -223,7 +232,7 @@ def candidate_finder(input_file_path="default", scenario=None, ranking_metric="f
             mydict_candid_id[row["s2"]] = row["s2_orig_ids"]
         one_row = {
             "id": orig_id_queries,
-            "toponym": all_queries[0],
+            "query": all_queries[0],
             "pred_score": [mydict_dl_match],
             "faiss_distance": [mydict_faiss_dist],
             "cosine_sim": [mydict_cosine_sim],
@@ -243,10 +252,10 @@ def main():
     # --- read args from the command line
     output_filename, selection_threshold, ranking_metric, search_size, num_candidates, \
         par_dir, input_file_path, number_test_rows, pretrained_model_path, pretrained_vocab_path = \
-            read_command_candidate_finder()
+            read_command_candidate_ranker()
 
     # ---
-    candidate_finder(input_file_path=input_file_path,
+    candidate_ranker(input_file_path=input_file_path,
                      scenario=par_dir,
                      ranking_metric=ranking_metric,
                      selection_threshold=selection_threshold,
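For reference, a hedged example of calling the renamed function with the keyword arguments from its signature in the first hunk; the package name and all paths below are illustrative placeholders, not values taken from this commit:

from mypackage.candidate_ranker import candidate_ranker  # placeholder package name

# Keyword arguments follow the signature shown above; paths and filenames
# here are made-up placeholders.
candidate_ranker(
    input_file_path="./inputs/input_dfm.yaml",
    scenario="./candidate_scenario",
    ranking_metric="faiss",          # or "cosine"
    selection_threshold=0.8,
    num_candidates=10,
    search_size=4,
    output_filename="ranker_results",
    pretrained_model_path="./models/test_model/test_model.model",
    pretrained_vocab_path="./models/test_model/test_model.vocab",
    number_test_rows=20,
)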