Skip to content

Commit

Permalink
update docs and trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
ubuntu committed Apr 15, 2024
1 parent 4dab06d commit 3c62681
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 28 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,16 @@ from data4co.draw.tsp import draw_tsp_solution, draw_tsp_problem

# use TSPConcordeSolver to solve the problem
solver = TSPConcordeSolver(scale=1)
solver.from_tsp("docs/kroA150.tsp")
solver.from_tsp("examples/tsp/kroA150.tsp")
solver.solve(norm="EUC_2D")

# draw
draw_tsp_problem(
save_path="docs/kroA150_problem.png",
save_path="docs/assets/kroA150_problem.png",
points=solver.ori_points,
)
draw_tsp_solution(
save_path="docs/kroA150_solution.png",
save_path="docs/assets/kroA150_solution.png",
points=solver.ori_points,
tours=solver.tours
)
Expand All @@ -187,17 +187,17 @@ from data4co import draw_mis_problem, draw_mis_solution

# use KaMISSolver to solve the problem
mis_solver = KaMISSolver()
mis_solver.solve(src="docs/mis_example")
mis_solver.solve(src="examples/mis_example")

# draw
draw_mis_problem(
save_path="docs/mis_problem.png",
ckle_path="docs/mis_example/mis_example.gpickle"
save_path="docs/assets/mis_problem.png",
gpickle_path="examples/mis/mis_example.gpickle"
)
draw_mis_solution(
save_path="docs/mis_solution.png",
gpickle_path="docs/mis_example/mis_example.gpickle",
result_path="docs/mis_example/solve/mis_example_unweighted.result"
gpickle_path="examples/mis/mis_example.gpickle",
result_path="examples/mis/solve/mis_example_unweighted.result"
)
```

Expand Down
4 changes: 0 additions & 4 deletions docs/project_example/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import os
import torch
import pickle
import numpy as np
from sklearn.neighbors import KDTree
from torch_geometric.data import Data as GraphData


class TSPDataset(torch.utils.data.Dataset):
Expand Down
2 changes: 1 addition & 1 deletion docs/project_example/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.utils.data
from ml4co_kit.learning.env import BaseEnv
from data import TSPDataset
from ml4co_kit.learning.env import BaseEnv
from torch_geometric.data import DataLoader as GraphDataLoader


Expand Down
6 changes: 3 additions & 3 deletions docs/project_example/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch
import numpy as np
from typing import Any, Union
import torch.nn.functional as F
from ml4co_kit.learning.model import BaseEncoder
from ml4co_kit.evaluate import TSPEvaluator
from env import TSPEnv
import torch.nn.functional as F
from ml4co_kit.learning.search import SearchConfigurator
from env import TSPEnv
from search import tsp_greedy, tsp_2opt
from typing import Any, Union


class TSPGNN(BaseEncoder):
Expand Down
8 changes: 7 additions & 1 deletion docs/project_example/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import numpy as np


def tsp_greedy(adj_mat, np_points, parallel_sampling=1, device="cpu", **kwargs):
def tsp_greedy(
adj_mat: np.ndarray,
np_points: np.ndarray,
parallel_sampling: int=1,
device: str="cpu",
**kwargs
):
raise NotImplementedError


Expand Down
7 changes: 6 additions & 1 deletion docs/project_example/solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
import torch
import numpy as np
from tqdm import tqdm
from typing import Union, Any
Expand Down Expand Up @@ -84,6 +83,12 @@ def solve(
decoded_tours = list()
for idx in tqdm(range(self.points.shape[0]), desc="Decoding"):
adj_mat = np.expand_dims(heatmap[idx], axis=0)
tour = self.decoding_func(
adj_mat=adj_mat,
np_points=self.points[idx],
edge_index_np=None,
**decoding_kwargs
)
decoded_tours.append(tour[0])
decoded_tours = np.array(decoded_tours)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
16 changes: 7 additions & 9 deletions ml4co_kit/learning/trainer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import os
from wandb.util import generate_id
from typing import Optional
import torch
from torch import nn
from typing import Optional
from wandb.util import generate_id
from typing import Union, Optional
from pytorch_lightning import Trainer
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.strategies import Strategy, DDPStrategy
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers import WandbLogger


class Checkpoint(ModelCheckpoint):
Expand Down Expand Up @@ -46,7 +44,7 @@ def __init__(
resume_id: Optional[str] = None,
):
if not os.path.exists(save_dir):
os.mkdir(save_dir)
os.makedirs(save_dir)
if id is None and resume_id is None:
wandb_id = os.getenv("WANDB_RUN_ID") or generate_id()
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_tag(self):
author=AUTHOR,
url=URL,
packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
package_data={NAME: ["**"]},
package_data={NAME: ["**"], "docs": ["**"]},
install_requires=REQUIRED,
extras_require=EXTRAS,
include_package_data=True,
Expand Down

0 comments on commit 3c62681

Please sign in to comment.