diff --git a/ml4co_kit/generator/atsp_data.py b/ml4co_kit/generator/atsp_data.py index 0b9ac40..3d37213 100644 --- a/ml4co_kit/generator/atsp_data.py +++ b/ml4co_kit/generator/atsp_data.py @@ -359,5 +359,5 @@ def generate_uniform(self) -> Union[np.ndarray, np.ndarray]: dist = (dist[:, None, :] + dist[None, :, :].transpose(0, 2, 1)).min(axis=2) if (dist == old_dist).all(): break - dists.append(dist / scaler) + dists.append(dist / scaler) return np.array(dists), None diff --git a/ml4co_kit/generator/cvrp_data.py b/ml4co_kit/generator/cvrp_data.py index 6726bbd..bb4773b 100644 --- a/ml4co_kit/generator/cvrp_data.py +++ b/ml4co_kit/generator/cvrp_data.py @@ -211,10 +211,10 @@ def generate(self): depots=batch_depots_coord, points=batch_nodes_coord, demands=batch_demands, - capacities=batch_capacities, + capacities=batch_capacities.reshape(-1), num_threads=self.num_threads ) - + # write to txt with open(self.file_save_path, "a+") as f: for idx, tour in enumerate(tours): @@ -233,7 +233,7 @@ def generate(self): f.write(" demands " + str(" ").join(str(demand) for demand in demands)) f.write(" capacity " + str(capicity)) f.write(str(" output ")) - f.write(str(" ").join(str(node_idx) for node_idx in tour[0])) + f.write(str(" ").join(str(node_idx) for node_idx in tour)) f.write("\n") f.close() diff --git a/ml4co_kit/generator/mis_data.py b/ml4co_kit/generator/mis_data.py index d12f660..25e0c09 100644 --- a/ml4co_kit/generator/mis_data.py +++ b/ml4co_kit/generator/mis_data.py @@ -159,8 +159,8 @@ def check_solver(self): if isinstance(self.solver, SOLVER_TYPE): self.solver_type = self.solver supported_solver_dict = { - SOLVER_TYPE.KAMIS: KaMISSolver, - SOLVER_TYPE.GUROBI: MISGurobiSolver + SOLVER_TYPE.GUROBI: MISGurobiSolver, + SOLVER_TYPE.KAMIS: KaMISSolver } supported_solver_type = supported_solver_dict.keys() if self.solver not in supported_solver_type: @@ -177,6 +177,7 @@ def check_solver(self): # check solver check_solver_dict = { SOLVER_TYPE.GUROBI: self.check_free, + SOLVER_TYPE.KAMIS: self.check_free } check_func = check_solver_dict[self.solver_type] check_func() diff --git a/ml4co_kit/generator/tsp_data.py b/ml4co_kit/generator/tsp_data.py index 9e6177c..6de2bdc 100644 --- a/ml4co_kit/generator/tsp_data.py +++ b/ml4co_kit/generator/tsp_data.py @@ -132,11 +132,11 @@ def check_solver(self): if isinstance(self.solver, SOLVER_TYPE): self.solver_type = self.solver supported_solver_dict = { - "LKH": TSPLKHSolver, - "Concorde": TSPConcordeSolver, - "Concorde-Large": TSPConcordeLargeSolver, - "GA-EAX": TSPGAEAXSolver, - "GA-EAX-Large": TSPGAEAXLargeSolver + SOLVER_TYPE.CONCORDE: TSPConcordeSolver, + SOLVER_TYPE.LKH: TSPLKHSolver, + SOLVER_TYPE.CONCORDE_LARGE: TSPConcordeLargeSolver, + SOLVER_TYPE.GA_EAX: TSPGAEAXSolver, + SOLVER_TYPE.GA_EAX_LARGE: TSPGAEAXLargeSolver } supported_solver_type = supported_solver_dict.keys() if self.solver_type not in supported_solver_type: diff --git a/tests/test_generator.py b/tests/test_generator.py index 8adf5f2..bf90c46 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -14,8 +14,8 @@ ############################################## def _test_atsp_lkh_generator( - num_threads: int, nodes_num: int, data_type: str, - sat_vars_num: int = None, sat_clauses_nums: int = None + num_threads: int, nodes_num: int, data_type: str, sat_vars_num: int = None, + sat_clauses_nums: int = None, re_download: bool = False ): """ Test ATSPDataGenerator using ATSPLKHSolver @@ -26,7 +26,7 @@ def _test_atsp_lkh_generator( os.makedirs(save_path) # create TSPDataGenerator using lkh solver - tsp_data_lkh = ATSPDataGenerator( + atsp_data_lkh = ATSPDataGenerator( num_threads=num_threads, nodes_num=nodes_num, data_type=data_type, @@ -38,9 +38,12 @@ def _test_atsp_lkh_generator( sat_vars_nums=sat_vars_num, sat_clauses_nums=sat_clauses_nums, ) - + + if re_download: + atsp_data_lkh.download_lkh() + # generate data - tsp_data_lkh.generate() + atsp_data_lkh.generate() # remove the save path shutil.rmtree(save_path) @@ -50,6 +53,10 @@ def test_atsp(): """ Test ATSPDataGenerator """ + # uniform + _test_atsp_lkh_generator( + num_threads=4, nodes_num=50, data_type="uniform", re_download=True + ) # sat _test_atsp_lkh_generator( num_threads=4, nodes_num=55, data_type="sat", sat_clauses_nums=5, sat_vars_num=5 @@ -62,10 +69,6 @@ def test_atsp(): _test_atsp_lkh_generator( num_threads=4, nodes_num=50, data_type="hcp" ) - # uniform - _test_atsp_lkh_generator( - num_threads=4, nodes_num=50, data_type="uniform" - ) ############################################## @@ -443,8 +446,7 @@ def test_mvc(): ############################################## def _test_tsp_lkh_generator( - num_threads: int, nodes_num: int, data_type: str, - regret: bool, re_download: bool=False + num_threads: int, nodes_num: int, data_type: str, regret: bool ): """ Test TSPDataGenerator using LKH Solver @@ -465,8 +467,7 @@ def _test_tsp_lkh_generator( save_path=save_path, regret=regret, ) - if re_download: - tsp_data_lkh.download_lkh() + # generate data tsp_data_lkh.generate() # remove the save path @@ -564,10 +565,9 @@ def test_tsp(): """ Test TSPDataGenerator """ - # re-download lkh + # threads _test_tsp_lkh_generator( - num_threads=4, nodes_num=50, data_type="uniform", - regret=False, re_download=True + num_threads=4, nodes_num=50, data_type="uniform", regret=False ) # regret & threads _test_tsp_lkh_generator( @@ -603,11 +603,11 @@ def test_tsp(): ############################################## if __name__ == "__main__": - test_tsp() + test_atsp() + test_cvrp() test_mc() test_mcl() test_mis() test_mvc() - test_cvrp() - test_atsp() + test_tsp() shutil.rmtree("tmp")