1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
| diff --git a/scene/__init__.py b/scene/__init__.py
--- a/scene/__init__.py
+++ b/scene/__init__.py
@@ -37,9 +37,6 @@ class Scene:
self.loaded_iter = load_iteration
print("Loading trained model at iteration {}".format(self.loaded_iter))
- self.train_cameras = {}
- self.test_cameras = {}
-
if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
@@ -67,12 +64,8 @@ class Scene:
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
self.cameras_extent = scene_info.nerf_normalization["radius"]
-
- for resolution_scale in resolution_scales:
- print("Loading Training Cameras")
- self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
- print("Loading Test Cameras")
- self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
+ self.scene_info = scene_info  # kept instead of preloaded camera lists; getTrainCameras/getTestCameras build cameras from it on demand
+ self.args = args  # passed to cameraList_from_camInfos each time cameras are rebuilt
if self.loaded_iter:
self.gaussians.load_ply(os.path.join(self.model_path,
@@ -87,7 +80,7 @@ class Scene:
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
def getTrainCameras(self, scale=1.0):
- return self.train_cameras[scale]
+ return cameraList_from_camInfos(self.scene_info.train_cameras, scale, self.args)  # rebuilt from scene_info on every call (no cache): keep the returned list instead of calling this repeatedly
def getTestCameras(self, scale=1.0):
- return self.test_cameras[scale]
\ No newline at end of file
+ return cameraList_from_camInfos(self.scene_info.test_cameras, scale, self.args)  # rebuilt from scene_info on every call (no cache): keep the returned list instead of calling this repeatedly
\ No newline at end of file
diff --git a/scene/cameras.py b/scene/cameras.py
--- a/scene/cameras.py
+++ b/scene/cameras.py
@@ -17,6 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
+ raw_image=None,  # optional channel-first (C, H, W) tensor; when given it replaces image and gt_alpha_mask
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
):
super(Camera, self).__init__()
@@ -29,6 +30,14 @@ class Camera(nn.Module):
self.FoVy = FoVy
self.image_name = image_name
+ if raw_image is not None:
+ image = raw_image[:3, ...]  # channels 0-2 (dim 0) become the image
+ gt_alpha_mask = None
+
+ if raw_image.shape[0] == 4:  # channels live on dim 0 (C, H, W), so test shape[0]; a 4th channel is the alpha mask
+ gt_alpha_mask = raw_image[3:4, ...]
+ del raw_image  # drops only this local name; a caller-held reference is unaffected
+
try:
self.data_device = torch.device(data_device)
except Exception as e:
@@ -39,9 +48,11 @@ class Camera(nn.Module):
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
+ del image  # original_image is a separate clamped tensor; drop the local reference to the input
if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
+ del gt_alpha_mask
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -44,103 +43,82 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)
- viewpoint_stack = None
+ viewpoint_stack = scene.getTrainCameras()
ema_loss_for_log = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
|