Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
heatingma committed Oct 29, 2024
1 parent 566bf5b commit 119d04a
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 30 deletions.
2 changes: 1 addition & 1 deletion ml4co_kit/generator/atsp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,5 +359,5 @@ def generate_uniform(self) -> Union[np.ndarray, np.ndarray]:
dist = (dist[:, None, :] + dist[None, :, :].transpose(0, 2, 1)).min(axis=2)
if (dist == old_dist).all():
break
dists.append(dist / scaler)
dists.append(dist / scaler)
return np.array(dists), None
6 changes: 3 additions & 3 deletions ml4co_kit/generator/cvrp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def generate(self):
depots=batch_depots_coord,
points=batch_nodes_coord,
demands=batch_demands,
capacities=batch_capacities,
capacities=batch_capacities.reshape(-1),
num_threads=self.num_threads
)

# write to txt
with open(self.file_save_path, "a+") as f:
for idx, tour in enumerate(tours):
Expand All @@ -233,7 +233,7 @@ def generate(self):
f.write(" demands " + str(" ").join(str(demand) for demand in demands))
f.write(" capacity " + str(capicity))
f.write(str(" output "))
f.write(str(" ").join(str(node_idx) for node_idx in tour[0]))
f.write(str(" ").join(str(node_idx) for node_idx in tour))
f.write("\n")
f.close()

Expand Down
5 changes: 3 additions & 2 deletions ml4co_kit/generator/mis_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def check_solver(self):
if isinstance(self.solver, SOLVER_TYPE):
self.solver_type = self.solver
supported_solver_dict = {
SOLVER_TYPE.KAMIS: KaMISSolver,
SOLVER_TYPE.GUROBI: MISGurobiSolver
SOLVER_TYPE.GUROBI: MISGurobiSolver,
SOLVER_TYPE.KAMIS: KaMISSolver
}
supported_solver_type = supported_solver_dict.keys()
if self.solver not in supported_solver_type:
Expand All @@ -177,6 +177,7 @@ def check_solver(self):
# check solver
check_solver_dict = {
SOLVER_TYPE.GUROBI: self.check_free,
SOLVER_TYPE.KAMIS: self.check_free
}
check_func = check_solver_dict[self.solver_type]
check_func()
Expand Down
10 changes: 5 additions & 5 deletions ml4co_kit/generator/tsp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ def check_solver(self):
if isinstance(self.solver, SOLVER_TYPE):
self.solver_type = self.solver
supported_solver_dict = {
"LKH": TSPLKHSolver,
"Concorde": TSPConcordeSolver,
"Concorde-Large": TSPConcordeLargeSolver,
"GA-EAX": TSPGAEAXSolver,
"GA-EAX-Large": TSPGAEAXLargeSolver
SOLVER_TYPE.CONCORDE: TSPConcordeSolver,
SOLVER_TYPE.LKH: TSPLKHSolver,
SOLVER_TYPE.CONCORDE_LARGE: TSPConcordeLargeSolver,
SOLVER_TYPE.GA_EAX: TSPGAEAXSolver,
SOLVER_TYPE.GA_EAX_LARGE: TSPGAEAXLargeSolver
}
supported_solver_type = supported_solver_dict.keys()
if self.solver_type not in supported_solver_type:
Expand Down
38 changes: 19 additions & 19 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
##############################################

def _test_atsp_lkh_generator(
num_threads: int, nodes_num: int, data_type: str,
sat_vars_num: int = None, sat_clauses_nums: int = None
num_threads: int, nodes_num: int, data_type: str, sat_vars_num: int = None,
sat_clauses_nums: int = None, re_download: bool = False
):
"""
Test ATSPDataGenerator using ATSPLKHSolver
Expand All @@ -26,7 +26,7 @@ def _test_atsp_lkh_generator(
os.makedirs(save_path)

# create TSPDataGenerator using lkh solver
tsp_data_lkh = ATSPDataGenerator(
atsp_data_lkh = ATSPDataGenerator(
num_threads=num_threads,
nodes_num=nodes_num,
data_type=data_type,
Expand All @@ -38,9 +38,12 @@ def _test_atsp_lkh_generator(
sat_vars_nums=sat_vars_num,
sat_clauses_nums=sat_clauses_nums,
)


if re_download:
atsp_data_lkh.download_lkh()

# generate data
tsp_data_lkh.generate()
atsp_data_lkh.generate()

# remove the save path
shutil.rmtree(save_path)
Expand All @@ -50,6 +53,10 @@ def test_atsp():
"""
Test ATSPDataGenerator
"""
# uniform
_test_atsp_lkh_generator(
num_threads=4, nodes_num=50, data_type="uniform", re_download=True
)
# sat
_test_atsp_lkh_generator(
num_threads=4, nodes_num=55, data_type="sat", sat_clauses_nums=5, sat_vars_num=5
Expand All @@ -62,10 +69,6 @@ def test_atsp():
_test_atsp_lkh_generator(
num_threads=4, nodes_num=50, data_type="hcp"
)
# uniform
_test_atsp_lkh_generator(
num_threads=4, nodes_num=50, data_type="uniform"
)


##############################################
Expand Down Expand Up @@ -443,8 +446,7 @@ def test_mvc():
##############################################

def _test_tsp_lkh_generator(
num_threads: int, nodes_num: int, data_type: str,
regret: bool, re_download: bool=False
num_threads: int, nodes_num: int, data_type: str, regret: bool
):
"""
Test TSPDataGenerator using LKH Solver
Expand All @@ -465,8 +467,7 @@ def _test_tsp_lkh_generator(
save_path=save_path,
regret=regret,
)
if re_download:
tsp_data_lkh.download_lkh()

# generate data
tsp_data_lkh.generate()
# remove the save path
Expand Down Expand Up @@ -564,10 +565,9 @@ def test_tsp():
"""
Test TSPDataGenerator
"""
# re-download lkh
# threads
_test_tsp_lkh_generator(
num_threads=4, nodes_num=50, data_type="uniform",
regret=False, re_download=True
num_threads=4, nodes_num=50, data_type="uniform", regret=False
)
# regret & threads
_test_tsp_lkh_generator(
Expand Down Expand Up @@ -603,11 +603,11 @@ def test_tsp():
##############################################

if __name__ == "__main__":
test_tsp()
test_atsp()
test_cvrp()
test_mc()
test_mcl()
test_mis()
test_mvc()
test_cvrp()
test_atsp()
test_tsp()
shutil.rmtree("tmp")

0 comments on commit 119d04a

Please sign in to comment.