Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update the work to new version #118

Open
jiugexuan opened this issue Jul 3, 2024 · 1 comment
Open

update the work to new version #118

jiugexuan opened this issue Jul 3, 2024 · 1 comment

Comments

@jiugexuan
Copy link

I want to update all dependencies to their latest versions, so I need to remove the dependency on [Detectron2]. To do that, I wrote a new script to visualize the results:

import mmcv
from mmdet.apis import init_detector, inference_detector, show_result_pyplot

# Use your own config file and trained model checkpoint.
config_file, checkpoint_file = (
    'configs/psgtr/psgtr_r50_psg_inference.py',
    'work_dirs/psgtr_r50_e60/epoch_60.pth',
)

# Build the detector from the config and load the checkpoint onto GPU 0.
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# The single test image. `img` stays a *path* here; the visualization code
# below reads it again with mmcv.imread.
img = "./data/coco/val2017/000000568439.jpg"

# Run inference on that one image.
result = inference_detector(model, img)

import networkx as nx
from pyvis.network import Network
import mmcv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import cv2
from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET

# Class names indexed by predicted class label: 133 categories followed by
# 'background' at index 133 (134 entries in total). Index 133 is also the VOID
# id (num_classes = 133) that is filtered out of the panoptic ids below, and the
# labels computed as `id % INSTANCE_OFFSET` are used to index into this list.
CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 
           'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 
           'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 
           'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 
           'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 
           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 
           'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 
           'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 
           'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 
           'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 
           'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 
           'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 
           'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 
           'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 
           'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 
           'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 
           'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 
           'rock-merged', 'wall-other-merged', 'rug-merged', 'background']

# Relation (predicate) names, 56 entries. The edge-building loop below indexes
# this list with the third field of each entry in result.rels.
# NOTE(review): confirm whether the model's predicate ids are 0-based indices
# into this list, or 1-based with 0 reserved for "no relation". If they are
# 1-based, every lookup is off by one.
PREDICATES = [
    'over', 'in front of', 'beside', 'on', 'in', 'attached to', 'hanging from', 'on back of', 
    'falling off', 'going down', 'painted on', 'walking on', 'running on', 'crossing', 'standing on', 
    'lying on', 'sitting on', 'flying over', 'jumping over', 'jumping from', 'wearing', 'holding', 
    'carrying', 'looking at', 'guiding', 'kissing', 'eating', 'drinking', 'feeding', 'biting', 
    'catching', 'picking', 'playing with', 'chasing', 'climbing', 'cleaning', 'playing', 'touching', 
    'pushing', 'pulling', 'opening', 'cooking', 'talking to', 'throwing', 'slicing', 'driving', 
    'riding', 'parked on', 'driving on', 'about to hit', 'kicking', 'swinging', 'entering', 'exiting', 
    'enclosing', 'leaning on'
]

# Load the image from the path held in `img` and rebind `img` to the decoded
# pixel array. mmcv.imread uses BGR channel order by default.
img_path = img  # replace with your own image path
img = mmcv.imread(img_path)
# shape[:2] is (H, W) for both colour (H, W, C) and grayscale (H, W) arrays.
img_h, img_w = img.shape[:2]

# Per-pixel panoptic segment-id map predicted by the model.
pan_results = result.pan_results

# Unique segment ids in descending order. This order defines the segment index
# used for the masks, labels and relation-graph nodes below.
ids = np.unique(pan_results)[::-1]
num_classes = 133
# Drop the VOID segment, whose id equals num_classes ('background' in CLASSES).
legal_indices = (ids != num_classes)
ids = ids[legal_indices]

# Class label of every segment, recovered as id % INSTANCE_OFFSET. This is
# vectorized rather than a Python loop, and no longer shadows the builtin `id`.
labels = (ids % INSTANCE_OFFSET).astype(np.int64)

# Give every segment a display name. A class that occurs once keeps its bare
# name (e.g. 'person'). A class that occurs several times is numbered in
# segment order ('person_1', 'person_2', ...).
from collections import Counter

# Total occurrences per label, computed once. Calling list.count() inside the
# loop made this O(n^2).
label_totals = Counter(labels.tolist())
label_counter = {}  # occurrences of each label seen so far
unique_labels = []
for label in labels:
    label_counter[label] = label_counter.get(label, 0) + 1
    name = CLASSES[label]
    if label_totals[label] == 1:
        unique_labels.append(name)
    else:
        unique_labels.append(f'{name}_{label_counter[label]}')

# Boolean mask per segment, stacked as (num_segments, H, W): segms[i] is True
# exactly where pan_results == ids[i].
segms = pan_results[None] == ids[:, None, None]

# Draw the image, then the segmentation fills, outlines and labels on top of it.
plt.figure(figsize=(15, 15))
# mmcv.imread returns BGR but matplotlib expects RGB. Without this conversion
# the photo is shown with red and blue swapped.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
ax = plt.gca()

# One RGBA layer holding every segment's fill, painted pixel-by-pixel from the
# mask. Filling contour polygons instead would also fill the inner (hole)
# contours that RETR_TREE returns, colouring pixels that are not in the segment.
# NOTE(review): assumes pan_results has the same H, W as the image — confirm.
overlay = np.zeros(segms.shape[1:] + (4,), dtype=np.float32)

for i, segm in enumerate(segms):
    # Random light colour: every channel lies in [0.3, 1.0).
    color_mask = np.random.rand(3) * 0.7 + 0.3
    overlay[segm, :3] = color_mask
    overlay[segm, 3] = 0.5

    # Outline only. [-2] selects the contour list under both the OpenCV 3
    # (three return values) and OpenCV 4 (two return values) signatures.
    mask = segm.astype(np.uint8)
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for contour in contours:
        ax.add_patch(Polygon(contour.reshape(-1, 2), closed=True, fill=False,
                             edgecolor=color_mask, linewidth=0.5))

    # Place the unique label near the mask's mean pixel, shifted 5 px up-left.
    ys, xs = np.nonzero(segm)
    if ys.size > 0:
        center_y, center_x = int(ys.mean()), int(xs.mean())
        ax.text(center_x - 5, center_y - 5, unique_labels[i], color=color_mask,
                fontsize=8, ha='center', va='center',
                bbox=dict(facecolor='black', alpha=0.7, edgecolor='none'))

ax.imshow(overlay)

# Show the result.
plt.axis('off')
plt.show()

# Directed relation graph: one node per segment, named by its segment index
# as a string.
G = nx.DiGraph()

# One random light colour per class name (every channel in [0.3, 1.0)),
# drawn in CLASSES order.
node_colors = {cls: np.random.rand(3) * 0.7 + 0.3 for cls in CLASSES}

# Every node stores its display label. Its colour is looked up by the class
# name, i.e. the label text before the first '_'.
G.add_nodes_from(
    (str(idx), {'label': name, 'color': node_colors[name.split('_')[0]], 'size': 20})
    for idx, name in enumerate(unique_labels)
)

# Add one directed, labelled edge per predicted relation (subject -> object).
# NOTE(review): this assumes two conventions that are not visible here.
#   (a) The subject/object indices in result.rels refer to the same segment
#       order as `unique_labels` (descending pan_results ids).
#   (b) The predicate id is a 0-based index into PREDICATES.
# Confirm both against the head that produced `result`. If predicate ids are
# 1-based, every edge label is shifted by one and the last predicate is
# reported as out of range below.
import warnings

rels = result.rels
# tab20 is a 20-colour qualitative colormap, so some predicates share a colour.
colors = plt.cm.tab20(np.linspace(0, 1, len(PREDICATES)))
num_nodes = len(unique_labels)
for rel in rels:
    # Cast up front: for a float array, str(3.0) == '3.0' would not match node '3'.
    subj_idx, obj_idx, rel_label = (int(v) for v in rel)
    # G.add_edge silently creates missing endpoints as nodes without attributes.
    # Such nodes have no label or colour and break the rendering step, so
    # relations that point outside the known segments are skipped instead.
    if not (0 <= subj_idx < num_nodes and 0 <= obj_idx < num_nodes):
        warnings.warn(f'Skipping relation {(subj_idx, obj_idx, rel_label)}: '
                      f'segment index out of range (only {num_nodes} segments).')
        continue
    if not 0 <= rel_label < len(PREDICATES):
        warnings.warn(f'Skipping relation {(subj_idx, obj_idx, rel_label)}: '
                      f'predicate id out of range (only {len(PREDICATES)} predicates).')
        continue
    G.add_edge(str(subj_idx), str(obj_idx),
               label=PREDICATES[rel_label], color=colors[rel_label])

# Render the relation graph as an interactive HTML page with PyVis.
# No NetworkX layout is computed: PyVis positions nodes itself (vis.js physics),
# so a spring_layout result would only take effect if it were copied into each
# node's x/y attributes.


def _to_rgba(color):
    """Format the first three [0, 1] float channels of *color* as an opaque CSS rgba() string."""
    r, g, b = (int(c * 255) for c in color[:3])
    return 'rgba({}, {}, {}, 1)'.format(r, g, b)


net = Network(notebook=True, width="1500px", height="1500px", directed=True, cdn_resources='remote')

# Import nodes and edges, including their attributes, from the NetworkX graph.
net.from_nx(G)

# Show the label on hover, and colour each node by its class name (the label
# text before the first '_'). Nodes whose label is not a known class keep
# PyVis's default colour instead of raising KeyError.
for node in net.nodes:
    node['title'] = node['label']
    color = node_colors.get(str(node['label']).split('_')[0])
    if color is not None:
        node['color'] = _to_rgba(color)

# Show the predicate on hover, and convert the matplotlib colour to CSS.
for edge in net.edges:
    edge['title'] = edge['label']
    edge['color'] = _to_rgba(edge['color'])

# Write and display the PyVis graph.
net.show('relationship_graph.html')

But it has some issues. Could anyone help me?
output

@ZHUXUHAN
Copy link

ZHUXUHAN commented Oct 5, 2024

Hi @jiugexuan, did you solve this issue? Based on my observation, there don’t seem to be any issues with this visualization. What specific issues do you have in mind?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants