infer.py
import os
import tyro
import glob
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
from core.mvgamba_models import Gamba
import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
import pdb
from core.options import AllConfigs, Options
from mvdream.pipeline_mvdream import MVDreamPipeline
import json
import random
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
import cv2
opt = tyro.cli(AllConfigs)
model = Gamba(opt)
# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
        model.load_state_dict(ckpt, strict=False)
    else:
        ckpt = torch.load(opt.resume, map_location='cpu')
        print(f"resume from {opt.resume}, loading ...")
        model.load_state_dict(ckpt['model'], strict=True)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print('[WARN] model randomly initialized, are you sure?')
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
rays_embeddings = model.prepare_default_rays(device)
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1
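
# NOTE: this appears to be a standard perspective projection (focal length
# 1/tan(fovy/2), square aspect) stored transposed, since the rendering code
# below multiplies row vectors: cam_view_proj = cam_view @ proj_matrix.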
# load ImageDream (multi-view diffusion) pipeline
pipe = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers", # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    local_files_only=False,
)
pipe = pipe.to(device)
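
# ImageDream is an image-conditioned multi-view diffusion model: given a single
# (background-free) RGB image it synthesizes 4 views of the object at 90-degree
# azimuth steps, which Gamba then lifts to 3D Gaussians below.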
# load rembg
bg_remover = rembg.new_session()
# process function
def process(opt: Options, path, test_order=1):
    name = os.path.splitext(os.path.basename(path))[0]
    print(f'[INFO] Processing {path} --> {name}')
    os.makedirs(opt.workspace, exist_ok=True)

    input_image = kiui.read_image(path, mode='uint8')

    # bg removal
    carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
    mask = carved_image[..., -1] > 0

    # recenter
    image = recenter(carved_image, mask, border_ratio=0.2)

    # generate mv
    image = image.astype(np.float32) / 255.0

    # rgba to rgb white bg
    if image.shape[-1] == 4:
        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=50, elevation=0)
    if test_order == 1:
        mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
    else:
        mv_image = np.stack([mv_image[0], mv_image[1], mv_image[2], mv_image[3]], axis=0)

    # save a 2x2 grid of the generated views for inspection
    inshow_image = (mv_image * 255).astype(np.uint8)
    grid = np.vstack((
        np.hstack((inshow_image[0], inshow_image[1])),
        np.hstack((inshow_image[2], inshow_image[3])),
    ))
    imageio.imwrite(os.path.join(opt.workspace, name + '_' + str(test_order) + '.png'), grid)
    # generate gaussians
    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) # not needed for DINO features, but needed for the patch embedding
    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
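    # The 9 channels per view are presumably the 3 RGB channels plus the
    # 6-channel ray (Plucker) embedding from prepare_default_rays; unsqueeze(0)
    # adds the batch dimension expected by forward_gaussians.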
    with torch.no_grad():
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image, cam_poses=None) # camera poses are not needed here

        # save gaussians
        model.gs_render.save_ply(gaussians, os.path.join(opt.workspace, name + '_' + str(test_order) + '.ply'))

        # render 360 video
        images = []
        elevation = 0
        if opt.fancy_video:
            azimuth = np.arange(0, 720, 4, dtype=np.int32)
            for azi in tqdm.tqdm(azimuth):
                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
                cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
                # cameras needed by gaussian rasterizer
                cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
                cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
                cam_pos = - cam_poses[:, :3, 3] # [V, 3]
                # ramp the gaussian scale from 0 to 1 over the first full orbit
                scale = min(azi / 360, 1)
                image = model.gs_render.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
                images.append((image.squeeze(1).permute(0, 2, 3, 1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
        else:
            azimuth = np.arange(0, 360, 2, dtype=np.int32)
            for azi in tqdm.tqdm(azimuth):
                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
                cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
                # cameras needed by gaussian rasterizer
                cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
                cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
                cam_pos = - cam_poses[:, :3, 3] # [V, 3]
                image = model.gs_render.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
                images.append((image.squeeze(1).permute(0, 2, 3, 1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

        images = np.concatenate(images, axis=0)
        imageio.mimwrite(os.path.join(opt.workspace, name + '_' + str(test_order) + '.mp4'), images, fps=30)
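
# Each call to process() writes three artifacts to opt.workspace:
#   <name>_<test_order>.png  -- 2x2 grid of the generated multi-view images
#   <name>_<test_order>.ply  -- the predicted 3D Gaussians
#   <name>_<test_order>.mp4  -- a 360-degree turntable rendering
# The loop below runs it twice per input, once for each view ordering
# (test_order 1 and 0).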
assert opt.test_path is not None
if os.path.isdir(opt.test_path):
    file_paths = glob.glob(os.path.join(opt.test_path, "*"))
else:
    file_paths = [opt.test_path]

for path in file_paths:
    process(opt, path, test_order=1)
    process(opt, path, test_order=0)
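
# Example invocation (illustrative; the config name and paths are placeholders,
# not fixed by this script):
#   python infer.py <config> --resume <checkpoint.safetensors> \
#       --workspace <output_dir> --test_path <image_or_folder>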