Merge pull request #21 from rootvisionai/feature/faster_processing
Feature/faster processing
  • Loading branch information
rootvisionai authored Jul 23, 2023
2 parents 6b24fd8 + 8da80b1 commit 98f358a
Showing 8 changed files with 368 additions and 150 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -18,4 +18,3 @@
*.drawio
/dev_gitignored/
!frontend_python/make_request_local.py

22 changes: 17 additions & 5 deletions README.MD
@@ -20,14 +20,14 @@ pip install git+https://github.com/facebookresearch/segment-anything.git
or clone the repository locally and install with

```
git clone git@github.com:facebookresearch/segment-anything.git
git clone git@github.com:rootvisionai/segment-anything.git
cd segment-anything; pip install -e .
```

The following dependencies are necessary for FEWSAM:

```
pip install opencv-python PyYAML PySimpleGUI kmeans-pytorch
pip install opencv-python PyYAML PySimpleGUI
```

Now download the model checkpoints:
@@ -36,15 +36,27 @@ More accurate <<< [VIT-H](https://dl.fbaipublicfiles.com/segment_anything/sam_vi
| [VIT-L](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
| [VIT-B](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) >>> Faster

## RUN
## START SERVER

Before you start the application, create a folder for the support
images that the model will learn from, and another folder for the
query images that are to be labeled.
Put the relative paths to these folders in support_dir and query_dir in config.yml (see the sketch below).
Then, let the magic begin ...
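
For reference, a minimal sketch of how config.yml is consumed, assuming it contains at least the two keys this README names (the project may expect additional fields not shown here):

```python
# Sketch only: support_dir and query_dir are the keys mentioned above;
# the folder names in the comments are placeholders.
import yaml  # provided by the PyYAML dependency installed earlier

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

support_dir = cfg["support_dir"]  # e.g. "./support_images"
query_dir = cfg["query_dir"]      # e.g. "./query_images"
```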

#### To create the request JSON that will be sent to the server
```commandline
python interface.py
```
Then adjust make_request.py according to your images and paths.

Finally, run the server ...
```
python main.py
python backend/server.py
```

and make a request while server.py is running
```commandline
python make_request.py
```
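
As a rough illustration of the request step: the actual payload schema, host, port, and route are defined by make_request.py and backend/server.py in this repository; everything in the sketch below (file name, URL, endpoint) is a placeholder, not the project's real API.

```python
# Hypothetical sketch only: the real endpoint, port, and payload format
# come from make_request.py / backend/server.py, not from this snippet.
import json
import requests  # assumed to be installed; it is not in the dependency list above

with open("request.json") as f:   # placeholder name for the JSON created via interface.py
    payload = json.load(f)

response = requests.post("http://localhost:8000/predict", json=payload)  # placeholder URL
print(response.status_code)
```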

## DOCKERIZATION
8 changes: 4 additions & 4 deletions backend/annotations.py
@@ -107,16 +107,16 @@ def generate_polygons_from_mask(polygons, mask, label, polygon_resolution):

# Generate polygons from the contours
points_ = []

instances, num_instances = find_instances(mask)
for k in range(1, num_instances+1, 1):
instance = ((instances == k)*1).astype(np.uint8)

# Find the contours in the binary mask
contours, _ = cv2.findContours(instance, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
filtered_contours = [contour for contour in contours if cv2.contourArea(contour) >= 100]

for i, contour in enumerate(contours):
if int(len(contour)*polygon_resolution)>0:
for i, contour in enumerate(filtered_contours):
if int(len(contour)*polygon_resolution) > 2:
points = contour.squeeze()[np.arange(0,
len(contour),
int(len(contour)/int(len(contour)*polygon_resolution))
@@ -128,7 +128,7 @@ def generate_polygons_from_mask(polygons, mask, label, polygon_resolution):
"shape_type": "polygon",
"flags": {}
})
points_.append(points)
points_.append(np.array(points))

return polygons, points_
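
A standalone sketch of the point-subsampling arithmetic in this hunk (toy values, not the repo's data): the new area filter keeps only contours covering at least 100 pixels, and the stricter `> 2` guard ensures a sampled polygon ends up with at least three vertices.

```python
# Toy illustration of the subsampling used above: keep roughly
# len(contour) * polygon_resolution points, evenly spaced along the contour.
import numpy as np

contour = np.zeros((40, 1, 2), dtype=np.int32)   # shape of a cv2 contour (fake data)
polygon_resolution = 0.25

n_keep = int(len(contour) * polygon_resolution)  # 10 of 40 points
if n_keep > 2:                                   # mirrors the guard added in this commit
    step = int(len(contour) / n_keep)            # 40 // 10 = 4
    points = contour.squeeze()[np.arange(0, len(contour), step)]
    print(points.shape)                          # (10, 2)
```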

42 changes: 31 additions & 11 deletions backend/exact_solution.py
@@ -6,6 +6,8 @@
"""
import torch
import numpy as np
import torchvision.utils
import tqdm
from server_utils import flatten_feature_map, l2_norm


@@ -33,41 +35,59 @@ def __init__(self,
super(ExactSolution, self).__init__()

self.device = device
self.embedding_collection = embedding_collection
# self.loss_func = torch.nn.CrossEntropyLoss()

self.embedding_collection = l2_norm(embedding_collection).to(self.device).float()
self.threshold = threshold
self.num_classes = len(np.unique(labels_int))
# self.labels_int = torch.from_numpy(labels_int).to(self.device).long()
self.labels_bin = binarize_labels(labels_int)
self.linear = torch.nn.Linear(in_features=self.embedding_collection.shape[1],
out_features=self.labels_bin.shape[1],
bias=False)
self.solve_exact()
# self.opt = torch.optim.SGD(momentum=0.9, lr=0.1, params=self.linear.parameters())
# self.train()
# self.train_linear()
self.eval()

def solve_exact(self):
collection_inverse = torch.pinverse(l2_norm(self.embedding_collection)).float()
collection_inverse = torch.pinverse(self.embedding_collection)
self.W = torch.matmul(collection_inverse.to(self.device),
self.labels_bin.to(self.device))
with torch.no_grad():
self.linear.weight = torch.nn.Parameter(self.W.T)

def infer(self, query_features):

with torch.no_grad():

b, n, h, w = query_features.shape
query_features = flatten_feature_map(query_features)
query_features = l2_norm(query_features)[0]
predictions = self.forward(query_features.float())
predictions = predictions.reshape(b, h, w).squeeze(0)
out = self.forward(query_features.float())

# get indexes of maximums
predictions = out[:, 1]

predictions = predictions.reshape(h, w)
torchvision.utils.save_image(predictions.cpu().float(), "./intermediate_preds.png")

return predictions

def train_linear(self):
pbar = tqdm.tqdm(range(0, 200))
for epoch in pbar:
out = self.linear(self.embedding_collection)
loss = self.loss_func(out, self.labels_int)
pbar.set_description(f"EPOCH: {epoch} | LOSS: {loss.item()}")

self.opt.zero_grad()
loss.backward()
self.opt.step()

def forward(self, embedding):
out = self.linear(embedding)
out = torch.where(out > 1, 2-out, out)
out = torch.nn.functional.softmax(out, dim=-1)

# apply adaptive threshold
self.threshold = out[:, 1].max()-0.02 if self.threshold > out[:, 1].max() else self.threshold
out = torch.where(out >= self.threshold, 1, 0)

# get indexes of maximums
out = out.argmax(dim=-1)
return out
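
For readers skimming this diff: ExactSolution fits its linear head in closed form rather than by gradient descent — solve_exact() sets the layer weights from pinverse(embeddings) and the binarized labels, and infer()/forward() then read off and threshold the foreground column. A self-contained sketch of that closed-form step with toy shapes (none of the tensors below are the repo's data):

```python
import torch

# Toy shapes: n support pixels, d-dim embeddings, c classes (fake data)
n, d, c = 6, 4, 2
X = torch.nn.functional.normalize(torch.randn(n, d), dim=1)            # l2-normalized embeddings
Y = torch.nn.functional.one_hot(torch.randint(0, c, (n,)), c).float()  # binarized labels

# Closed-form least-squares fit: W minimizes ||X @ W - Y||_F
W = torch.pinverse(X) @ Y                                               # shape (d, c)

head = torch.nn.Linear(d, c, bias=False)
with torch.no_grad():
    head.weight = torch.nn.Parameter(W.T)                               # same load as solve_exact()

scores = head(X)   # column 1 is the foreground score that infer() reads off
```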