From 8a8ef38040623e7355ff85bd794cda9e2af26f5c Mon Sep 17 00:00:00 2001 From: shaoshitong <1090784053@qq.com> Date: Wed, 10 Aug 2022 22:00:33 +0800 Subject: [PATCH 1/2] NeRF's first submission Complete addition of NeRF profile with evaluator Fix a series of bugs and issues to make NeRF workable --- libai/data/datasets/__init__.py | 1 + libai/data/datasets/nerf_dataset.py | 614 +++++++++++++++++++ projects/NeRF/__init__.py | 0 projects/NeRF/configs/config_nerf.py | 180 ++++++ projects/NeRF/configs/models/config_model.py | 24 + projects/NeRF/evaluation/nerf_evaluator.py | 109 ++++ projects/NeRF/modeling/NeRF.py | 135 ++++ projects/NeRF/modeling/System.py | 431 +++++++++++++ projects/NeRF/optimizers/Radam.py | 113 ++++ projects/NeRF/optimizers/Ranger.py | 173 ++++++ projects/NeRF/optimizers/__init__.py | 2 + projects/__init__.py | 0 12 files changed, 1782 insertions(+) create mode 100644 libai/data/datasets/nerf_dataset.py create mode 100644 projects/NeRF/__init__.py create mode 100644 projects/NeRF/configs/config_nerf.py create mode 100644 projects/NeRF/configs/models/config_model.py create mode 100644 projects/NeRF/evaluation/nerf_evaluator.py create mode 100644 projects/NeRF/modeling/NeRF.py create mode 100644 projects/NeRF/modeling/System.py create mode 100644 projects/NeRF/optimizers/Radam.py create mode 100644 projects/NeRF/optimizers/Ranger.py create mode 100644 projects/NeRF/optimizers/__init__.py create mode 100644 projects/__init__.py diff --git a/libai/data/datasets/__init__.py b/libai/data/datasets/__init__.py index 5b8c75507..088304a1b 100644 --- a/libai/data/datasets/__init__.py +++ b/libai/data/datasets/__init__.py @@ -20,3 +20,4 @@ from .roberta_dataset import RobertaDataset from .gpt_dataset import GPT2Dataset from .t5_dataset import T5Dataset +from .nerf_dataset import BlenderDataset, LLFFDataset diff --git a/libai/data/datasets/nerf_dataset.py b/libai/data/datasets/nerf_dataset.py new file mode 100644 index 000000000..d5e081986 --- /dev/null +++ b/libai/data/datasets/nerf_dataset.py @@ -0,0 +1,614 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
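+#
+# Dataset helpers for NeRF: BlenderDataset loads the synthetic scenes described by
+# transforms_{split}.json, while LLFFDataset loads real forward-facing scenes from the
+# poses_bounds.npy produced by COLMAP/LLFF. Both return per-ray samples wrapped into
+# libai Instance/DistTensorData objects.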
+ +import glob +import json +import os +import re +import sys +from collections import OrderedDict + +import numpy as np +import oneflow as flow +from flowvision import datasets +from flowvision import transforms as T +from oneflow.utils.data import Dataset +from PIL import Image + +from libai.config import LazyCall, instantiate +from libai.data.structures import DistTensorData, Instance + + +# TODO: Somw tools about the fpm storage, which is not necessarily used +def read_pfm(filename): + file = open(filename, "rb") + + header = file.readline().decode("utf-8").rstrip() + if header == "PF": + color = True + elif header == "Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8")) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + file.close() + return data, scale + + +def save_pfm(filename, image, scale=1): + file = open(filename, "wb") + image = np.flipud(image) + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8")) + file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8")) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write(("%f\n" % scale).encode("utf-8")) + + image.tofile(file) + file.close() + + +# TODO: Preparatory conversion tools for 3D rendering +from kornia import create_meshgrid + + +def get_rays(directions, c2w): + """ + Get ray origin and normalized directions in world coordinate for all pixels in one image. + Inputs: + directions: (H, W, 3) precomputed ray directions in camera coordinate + c2w: (3, 4) transformation matrix from camera coordinate to world coordinate + Outputs: + rays_o: (H*W, 3), the origin of the rays in world coordinate + rays_d: (H*W, 3), the normalized direction of the rays in world coordinate + """ + # Rotate ray directions from camera coordinate to the world coordinate + rays_d = directions @ c2w[:, :3].T # (H, W, 3) + rays_d = rays_d / flow.norm(rays_d, dim=-1, keepdim=True) + # The origin of all rays is the camera origin in world coordinate + rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) + + rays_d = rays_d.view(-1, 3) + rays_o = rays_o.view(-1, 3) + + return rays_o, rays_d + + +def get_ray_directions(H, W, focal): + """ + Get ray directions for all pixels in camera coordinate. 
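+    For pixel (i, j) the returned direction is ((i - W/2)/focal, -(j - H/2)/focal, -1),
+    i.e. the camera looks along -z with y pointing up.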
+ Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + i, j = grid.unbind(-1) + i = flow.tensor(i.numpy()) + j = flow.tensor(j.numpy()) + directions = flow.stack( + [(i - W / 2) / focal, -(j - H / 2) / focal, -flow.ones_like(i)], -1 + ) # compute about tanx (H, W, 3) + + return directions + + +def get_ndc_rays(H, W, focal, near, rays_o, rays_d): + """ + Transform rays from world coordinate to NDC. + NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. + For detailed derivation, please see: + http://www.songho.ca/opengl/gl_projectionmatrix.html + https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf + https://pengfeixc.com/blogs/computer-graphics/3D-matrix-transformation-part-three + + In practice, use NDC "if and only if" the scene is unbounded (has a large depth). + See https://github.com/bmild/nerf/issues/18 + + Inputs: + H, W, focal: image height, width and focal length + near: (N_rays) or float, the depths of the near plane + rays_o: (N_rays, 3), the origin of the rays in world coordinate + rays_d: (N_rays, 3), the direction of the rays in world coordinate + + Outputs: + rays_o: (N_rays, 3), the origin of the rays in NDC + rays_d: (N_rays, 3), the direction of the rays in NDC + """ + # Shift ray origins to near plane + t = -(near + rays_o[..., 2]) / rays_d[..., 2] + rays_o = rays_o + t[..., None] * rays_d + + # Store some intermediate homogeneous results + ox_oz = rays_o[..., 0] / rays_o[..., 2] + oy_oz = rays_o[..., 1] / rays_o[..., 2] + + # Projection + o0 = -1.0 / (W / (2.0 * focal)) * ox_oz + o1 = -1.0 / (H / (2.0 * focal)) * oy_oz + o2 = 1.0 + 2.0 * near / rays_o[..., 2] + + d0 = -1.0 / (W / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2] - ox_oz) + d1 = -1.0 / (H / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2] - oy_oz) + d2 = 1 - o2 + rays_o = flow.stack([o0, o1, o2], -1) # (B, 3) + rays_d = flow.stack([d0, d1, d2], -1) # (B, 3) + + return rays_o, rays_d + + +def normalize(v): + """Normalize a vector.""" + return v / np.linalg.norm(v) + + +def average_poses(poses): + """ + Calculate the average pose, which is then used to center all poses + using @center_poses. Its computation is as follows: + 1. Compute the center: the average of pose centers. + 2. Compute the z axis: the normalized average z axis. + 3. Compute axis y': the average y axis. + 4. Compute x' = y' cross product z, then normalize it as the x axis. + 5. Compute the y axis: z cross product x. + + Note that at step 3, we cannot directly use y' as y axis since it's + not necessarily orthogonal to z axis. We need to pass from x to y. + + Inputs: + poses: (N_images, 3, 4) + + Outputs: + pose_avg: (3, 4) the average pose + """ + # 1. Compute the center + center = poses[..., 3].mean(0) # (3) + + # 2. Compute the z axis + z = normalize(poses[..., 2].mean(0)) # (3) + + # 3. Compute axis y' (no need to normalize as it's not the final output) + y_ = poses[..., 1].mean(0) # (3) + + # 4. Compute the x axis + x = normalize(np.cross(y_, z)) # (3) + + # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) + y = np.cross(z, x) # (3) + + pose_avg = np.stack([x, y, z, center], 1) # (3, 4) + + return pose_avg + + +def center_poses(poses): + """ + Center the poses so that we can use NDC. 
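+    That is, every pose is re-expressed in the coordinate frame of the average pose
+    computed by @average_poses, so the cameras end up roughly centered around the origin.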
+ See https://github.com/bmild/nerf/issues/34 + + Inputs: + poses: (N_images, 3, 4) + + Outputs: + poses_centered: (N_images, 3, 4) the centered poses + pose_avg: (3, 4) the average pose + """ + + pose_avg = average_poses(poses) # (3, 4) + pose_avg_homo = np.eye(4) + pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation + # by simply adding 0, 0, 0, 1 as the last row + last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) + poses_homo = np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate + + poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) + poses_centered = poses_centered[:, :3] # (N_images, 3, 4) + + return poses_centered, np.linalg.inv(pose_avg_homo) + + +def create_spiral_poses(radii, focus_depth, n_poses=120): + """ + Computes poses that follow a spiral path for rendering purpose. + See https://github.com/Fyusion/LLFF/issues/19 + In particular, the path looks like: + https://tinyurl.com/ybgtfns3 + + Inputs: + radii: (3) radii of the spiral for each axis + focus_depth: float, the depth that the spiral poses look at + n_poses: int, number of poses to create along the path + + Outputs: + poses_spiral: (n_poses, 3, 4) the poses in the spiral path + """ + + poses_spiral = [] + for t in np.linspace(0, 4 * np.pi, n_poses + 1)[:-1]: # rotate 4pi (2 rounds) + # the parametric function of the spiral (see the interactive web) + center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5 * t)]) * radii + + # the viewing z axis is the vector pointing from the @focus_depth plane + # to @center + z = normalize(center - np.array([0, 0, -focus_depth])) + + # compute other axes as in @average_poses + y_ = np.array([0, 1, 0]) # (3) + x = normalize(np.cross(y_, z)) # (3) + y = np.cross(z, x) # (3) + + poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4) + + return np.stack(poses_spiral, 0) # (n_poses, 3, 4) + + +def create_spheric_poses(radius, n_poses=120): + """ + Create circular poses around z axis. + Inputs: + radius: the (negative) height and the radius of the circle. 
+ + Outputs: + spheric_poses: (n_poses, 3, 4) the poses in the circular path + """ + + def spheric_pose(theta, phi, radius): + trans_t = lambda t: np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, -0.9 * t], + [0, 0, 1, t], + [0, 0, 0, 1], + ] + ) + + rot_phi = lambda phi: np.array( + [ + [1, 0, 0, 0], + [0, np.cos(phi), -np.sin(phi), 0], + [0, np.sin(phi), np.cos(phi), 0], + [0, 0, 0, 1], + ] + ) + + rot_theta = lambda th: np.array( + [ + [np.cos(th), 0, -np.sin(th), 0], + [0, 1, 0, 0], + [np.sin(th), 0, np.cos(th), 0], + [0, 0, 0, 1], + ] + ) + + c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) + c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w + return c2w[:3] + + spheric_poses = [] + for th in np.linspace(0, 2 * np.pi, n_poses + 1)[:-1]: + spheric_poses += [spheric_pose(th, -np.pi / 5, radius)] # 36 degree view downwards + return np.stack(spheric_poses, 0) + + +# TODO: Blender and LLFF Datasets + + +def trun_dict_to_instance(dict): + return Instance(**{key: DistTensorData(flow.tensor(value)) for key, value in dict.items()}) + + +class NerfBaseDataset(Dataset): + def __init__(self, root_dir, split, img_wh): + super(NerfBaseDataset, self).__init__() + self.root_dir = root_dir + self.split = split + self.img_wh = img_wh + self.transform = T.Compose([T.ToTensor()]) + + def load_meta(self): + pass + + +class BlenderDataset(NerfBaseDataset): + def __init__(self, root_dir, split="train", img_wh=(800, 800), **kwargs): + """ + Args: + root_dir: str, + split: str, + img_wh: tuple, + """ + super(BlenderDataset, self).__init__(root_dir, split, img_wh) + self.white_back = True + self.load_meta() + + def load_meta(self): + with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), "r") as f: + self.meta = json.load(f) + w, h = self.img_wh + camera_angle_x = float(self.meta["camera_angle_x"]) + self.focal = 0.5 * w / np.tan(0.5 * camera_angle_x) + self.near = 2.0 + self.far = 6.0 + self.bounds = np.array([self.near, self.far]) + self.directions = get_ray_directions(h, w, self.focal) # (h, w, 3) + + if self.split == "train": # create buffer of all rays and rgb data + self.image_paths = [] + self.poses = [] + self.all_rays = [] + self.all_rgbs = [] + for frame in self.meta["frames"]: + pose = np.array(frame["transform_matrix"])[:3, :4] + self.poses += [pose] + c2w = flow.Tensor(pose) + image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") + self.image_paths += [image_path] + img = Image.open(image_path) + img = img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (4, h, w) + img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA + img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB + self.all_rgbs += [img] + rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) + self.all_rays += [ + flow.cat( + [ + rays_o, + rays_d, + self.near * flow.ones_like(rays_o[:, :1]), + self.far * flow.ones_like(rays_o[:, :1]), + ], + 1, + ) + ] # (h*w, 8) + + self.all_rays = flow.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) + self.all_rgbs = flow.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) + + def __len__(self): + if self.split == "train": + return len(self.all_rays) + if self.split == "val": + return 8 # only validate 8 images (to support <=8 gpus) + return len(self.meta["frames"]) + + def __getitem__(self, idx): + if self.split == "train": # use data in the buffers + sample = OrderedDict(rays=self.all_rays[idx], rgbs=self.all_rgbs[idx]) + + else: # create data for each image separately + frame = 
self.meta["frames"][idx] + c2w = flow.Tensor(frame["transform_matrix"])[:3, :4] + + img = Image.open(os.path.join(self.root_dir, f"{frame['file_path']}.png")) + img = img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (4, H, W) + valid_mask = (img[-1] > 0).flatten() # (H*W) valid color area + img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA + img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB + + rays_o, rays_d = get_rays(self.directions, c2w) + + rays = flow.concat( + [ + rays_o, + rays_d, + self.near * flow.ones_like(rays_o[:, :1]), + self.far * flow.ones_like(rays_o[:, :1]), + ], + 1, + ) # (H*W, 8) + sample = OrderedDict(rays=rays, rgbs=img, c2w=c2w, valid_mask=valid_mask) + return trun_dict_to_instance(sample) + + +class LLFFDataset(NerfBaseDataset): + def __init__(self, root_dir, split="train", img_wh=(504, 378), spheric_poses=False, val_num=1): + """ + Args: + root_dir: str, + split: str, + img_wh: tuple, + spheric_poses: bool, whether the images are taken in a spheric inward-facing manner + default: False (forward-facing) + val_num: int, number of val images (used for multigpu training, validate same image + for all gpus) + """ + super(LLFFDataset, self).__init__(root_dir, split, img_wh) + self.spheric_poses = spheric_poses + self.val_num = max(1, val_num) # at least 1 + self.load_meta() + self.white_back = False + + def load_meta(self): + poses_bounds = np.load(os.path.join(self.root_dir, "poses_bounds.npy")) # (N_images, 17) + self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, "images/*"))) + if self.split in ["train", "val"]: + assert len(poses_bounds) == len( + self.image_paths + ), "Mismatch between number of images and number of poses! Please rerun COLMAP!" + poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) + self.bounds = poses_bounds[:, -2:] # (N_images, 2) + H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images + H, W, self.focal = H.item(), W.item(), self.focal.item() + assert ( + H * self.img_wh[0] == W * self.img_wh[1] + ), f"You must set @img_wh to have the same aspect ratio as ({W}, {H}) !" 
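+        # rescale the focal length from the original resolution (W, H) stored in
+        # poses_bounds.npy to the training resolution img_wh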
+ self.focal *= self.img_wh[0] / W + poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) + # (N_images, 3, 4) exclude H, W, focal + self.poses, self.pose_avg = center_poses(poses) + distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) + val_idx = np.argmin(distances_from_center) # choose val image as the closest to + near_original = self.bounds.min() + scale_factor = near_original * 0.75 # 0.75 is the default parameter + self.bounds /= scale_factor + self.poses[..., 3] /= scale_factor + self.directions = get_ray_directions( + self.img_wh[1], self.img_wh[0], self.focal + ) # (H, W, 3) + + if self.split == "train": # create buffer of all rays and rgb data + # use first N_images-1 to train, the LAST is val + self.all_rays = [] + self.all_rgbs = [] + for i, image_path in enumerate(self.image_paths): + if i == val_idx: # exclude the val image + continue + c2w = flow.Tensor(self.poses[i]) + img = Image.open(image_path).convert("RGB") + assert ( + img.size[1] * self.img_wh[0] == img.size[0] * self.img_wh[1] + ), f"""{image_path} has different aspect ratio than img_wh, + please check your data!""" + img = img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (3, h, w) + img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB + self.all_rgbs += [img] + + rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) + if not self.spheric_poses: + near, far = 0, 1 + rays_o, rays_d = get_ndc_rays( + self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d + ) + else: + near = self.bounds.min() + far = min(8 * near, self.bounds.max()) # focus on central object only + + self.all_rays += [ + flow.concat( + [ + rays_o, + rays_d, + near * flow.ones_like(rays_o[:, :1]), + far * flow.ones_like(rays_o[:, :1]), + ], + 1, + ) + ] # (h*w, 8) + + self.all_rays = flow.cat(self.all_rays, 0) # ((N_images-1)*h*w, 8) + self.all_rgbs = flow.cat(self.all_rgbs, 0) # ((N_images-1)*h*w, 3) + + elif self.split == "val": + self.c2w_val = self.poses[val_idx] + self.image_path_val = self.image_paths[val_idx] + + else: # for testing, create a parametric rendering path + if self.split.endswith("train"): # test on training set + self.poses_test = self.poses + elif not self.spheric_poses: + focus_depth = 3.5 # hardcoded, this is numerically close to the formula + # given in the original repo. 
Mathematically if near=1 + # and far=infinity, then this number will converge to 4 + radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0) + self.poses_test = create_spiral_poses(radii, focus_depth) + else: + radius = 1.1 * self.bounds.min() + self.poses_test = create_spheric_poses(radius) + + def __len__(self): + if self.split == "train": + return len(self.all_rays) + if self.split == "val": + return self.val_num + return len(self.poses_test) + + def __getitem__(self, idx): + if self.split == "train": # use data in the buffers + sample = OrderedDict(rays=self.all_rays[idx], rgbs=self.all_rgbs[idx]) + + else: + if self.split == "val": + c2w = flow.Tensor(self.c2w_val) + else: + c2w = flow.Tensor(self.poses_test[idx]) + + rays_o, rays_d = get_rays(self.directions, c2w) + if not self.spheric_poses: + near, far = 0, 1 + rays_o, rays_d = get_ndc_rays( + self.img_wh[1], self.img_wh[0], self.focal, 1.0, rays_o, rays_d + ) + else: + near = self.bounds.min() + far = min(8 * near, self.bounds.max()) + + rays = flow.cat( + [ + rays_o, + rays_d, + near * flow.ones_like(rays_o[:, :1]), + far * flow.ones_like(rays_o[:, :1]), + ], + 1, + ) # (h*w, 8) + + sample = OrderedDict(rays=rays, rgbs=None, c2w=c2w) + if self.split == "val": + img = Image.open(self.image_path_val).convert("RGB") + img = img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (3, h, w) + img = img.view(3, -1).permute(1, 0) # (h*w, 3) + sample["rgbs"] = img + + return trun_dict_to_instance(sample) + + +# +# if __name__=="__main__": +# # dataset=BlenderDataset(root_dir="/home/Bigdata/Nerf/nerf_synthetic/chair",split='train') +# # for img in dataset: +# # print(img['rgbs'].shape) +# +# dataset=LLFFDataset(root_dir="/home/Bigdata/Nerf/nerf_llff_data/fern",split="val") +# dataset=instantiate(dataset) +# for img in dataset: +# img=img.get_fields() +# print(img) diff --git a/projects/NeRF/__init__.py b/projects/NeRF/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/projects/NeRF/configs/config_nerf.py b/projects/NeRF/configs/config_nerf.py new file mode 100644 index 000000000..30a7ad367 --- /dev/null +++ b/projects/NeRF/configs/config_nerf.py @@ -0,0 +1,180 @@ +from omegaconf import OmegaConf + +import oneflow as flow +import oneflow.nn as nn + +from libai.data.datasets import BlenderDataset, LLFFDataset +from libai.data.build import build_image_train_loader, build_image_test_loader +from libai.config import LazyCall, get_config +from libai.optim import get_default_optimizer_params +from libai.scheduler.lr_scheduler import WarmupCosineAnnealingLR, WarmupMultiStepLR + +from projects.NeRF.optimizers import Ranger, RAdam +from projects.NeRF.configs.models.config_model import model +from projects.NeRF.evaluation.nerf_evaluator import NerfEvaluator + +# TODO: Create a unified interface to Nerf datasets + + + +def get_nerf_dataset(dataset_type="Blender"): + """ + Args: + dataset_type: Blender or LLFF + """ + assert dataset_type in ["Blender", "LLFF"], "The Nerf dataset must be one of Blender and LLFF" + if dataset_type == "Blender": + return BlenderDataset + else: + return LLFFDataset + + +graph = get_config("common/models/graph.py").graph +graph.enabled = False +train = get_config("common/train.py").train + +# Refine train cfg for Nerf System + +train.train_micro_batch_size = 1024 +train.test_micro_batch_size = 1 +train.dataset_type = "Blender" +train.blender_dataset_path = "/path/to/llff" +train.llff_dataset_path = "/path/to/llff" +train.train_epoch = 16 if train.dataset_type == "Blender" else 30 
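+# Warmup is disabled here (warmup_ratio = 0); it is also forced to 0 further below
+# whenever RAdam/Ranger is selected, since their rectified updates are meant to
+# replace an explicit warmup phase.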
+train.warmup_ratio = 0 / train.train_epoch +train.evaluation.eval_period = 1000 +train.log_period = 50 +train.optim_type = "adam" +train.lr_scheduler_type = "steplr" + + +# Redefining model config + +model.cfg.dataset_type = train.dataset_type +model.cfg.loss_func = nn.MSELoss() + +# Redefining evaluator + +train.evaluation = dict( + enabled=True, + # evaluator for calculating top-k acc + evaluator=LazyCall(NerfEvaluator)( + img_wh=(400, 400) if train.dataset_type == "Blender" else (504, 378) + ), + eval_period=train.evaluation.eval_period, + eval_iter=1e5, # running steps for validation/test + # Metrics to be used for best model checkpoint. + eval_metric="Psnr", + eval_mode="max", +) + +# Refine optimizer cfg for Nerf System +# NOTE: In theory, both datasets used by Nerf +# are optimized using the Adam optimizer, but +# since the borrowed code base also implemen- +# ts three other optimizer configurations, l- +# ibai also implements the corresponding opt- +# imizer. +if train.optim_type == "adam": + optimizer = flow.optim.Adam + lr = 5e-4 +elif train.optim_type == "sgd": + optimizer = flow.optim.SGD + lr = 5e-1 +elif train.optim_type == "radam": + optimizer = RAdam + lr = 5e-4 +elif train.optim_type == "ranger": + optimizer = Ranger + lr = 5e-4 +else: + raise NotImplementedError("Nerf does not support this type of optimizer!") + +optim = LazyCall(optimizer)( + params=LazyCall(get_default_optimizer_params)( + # params.model is meant to be set to the model object, + # before instantiating the optimizer. + clip_grad_max_norm=None, + clip_grad_norm_type=None, + weight_decay_norm=None, + weight_decay_bias=None, + ), + lr=lr, + weight_decay=0, +) + +if train.optim_type == "sgd": + optim.momentum = 0.9 + +if train.lr_scheduler_type == "steplr": + scheduler = WarmupMultiStepLR +elif train.optim_type == "cosine": + scheduler = WarmupCosineAnnealingLR +else: + raise NotImplementedError("Nerf does not support this type of scheduler!") + +train.scheduler = LazyCall(scheduler)( + # In DefaultTrainer we will automatically set `max_iter` + # and `warmup_iter` by the given train cfg. 
+ warmup_factor=0.001, + warmup_method="linear", +) + +if train.lr_scheduler_type == "steplr": + if train.dataset_type == "Blender": + milestones = [2/16, 4/16, 8/16] + else: + milestones = [10/30, 20/30] + train.scheduler.milestones = milestones + train.scheduler.gamma = 0.5 +elif train.optim_type == "cosine": + train.scheduler.eta_min = 1e-8 + +train.warmup_ratio = ( + train.warmup_ratio + if train.warmup_ratio > 0 and train.optim_type not in ["radam", "ranger"] + else 0.0 +) + +# Set fp16 ON +train.amp.enabled = True + +dataset = LazyCall(get_nerf_dataset)(dataset_type=train.dataset_type) + +dataloader = OmegaConf.create() + +dataloader.train = LazyCall(build_image_train_loader)( + dataset=[ + LazyCall(dataset)( + split="train", + img_wh=(400, 400) if dataset.dataset_type == "Blender" else (504, 378), + root_dir=train.blender_dataset_path if dataset.dataset_type == "Blender" else train.llff_dataset_path, + spheric_poses=None if dataset.dataset_type == "Blender" else False, + val_num=None if dataset.dataset_type == "Blender" else 1, # Number of your GPUs + ) + ], + num_workers=4, + train_batch_size=train.train_micro_batch_size, + test_batch_size=1, +) + +dataloader.test = [ + LazyCall(build_image_test_loader)( + dataset=LazyCall(dataset)( + split="val", + img_wh=(400, 400) if dataset.dataset_type == "Blender" else (504, 378), + root_dir=train.blender_dataset_path if dataset.dataset_type == "Blender" else train.llff_dataset_path, + spheric_poses=None if dataset.dataset_type == "Blender" else False, + val_num=None if dataset.dataset_type == "Blender" else 1, # Number of your GPUs + ), + num_workers=4, + test_batch_size=1, + ) +] + +# Distributed Settings +depth = None +train.dist.pipeline_num_layers = depth +train.dist.data_parallel_size = 1 +train.dist.tensor_parallel_size = 1 +train.dist.pipeline_parallel_size = 1 diff --git a/projects/NeRF/configs/models/config_model.py b/projects/NeRF/configs/models/config_model.py new file mode 100644 index 000000000..ef201b520 --- /dev/null +++ b/projects/NeRF/configs/models/config_model.py @@ -0,0 +1,24 @@ +from omegaconf import DictConfig +from libai.config import LazyCall + +from projects.NeRF.modeling.System import NerfSystem + + +cfg = dict( + D=8, + W=256, + in_channels_xyz=63, + in_channels_dir=27, + skips=[4], + N_samples=64, + use_disp=False, + perturb=1.0, + noise_std=1.0, + N_importance=128, + chunk=32 * 1204, + dataset_type="blender", +) + +cfg = DictConfig(cfg) + +model = LazyCall(NerfSystem)(cfg=cfg) diff --git a/projects/NeRF/evaluation/nerf_evaluator.py b/projects/NeRF/evaluation/nerf_evaluator.py new file mode 100644 index 000000000..4bd809064 --- /dev/null +++ b/projects/NeRF/evaluation/nerf_evaluator.py @@ -0,0 +1,109 @@ +import math, os, cv2, copy +from datetime import datetime +import flowvision.transforms as T +import numpy as np +import oneflow as flow +from PIL import Image +from collections import OrderedDict + +from libai.utils import distributed as dist +from libai.evaluation.cls_evaluator import ClsEvaluator + + +class NerfEvaluator(ClsEvaluator): + def __init__(self, img_wh, image_save_path=None): + super().__init__(topk=(1, 5)) + self.img_wh = img_wh + self.image_save_path = ( + str(os.path.dirname(os.path.realpath(__file__))) + "/../images" + if image_save_path is None + else image_save_path + ) + if not os.path.exists(self.image_save_path): + os.makedirs(self.image_save_path) + self.toimage = T.ToPILImage() + + def current_time(self): + currentDateAndTime = datetime.now() + currentTime = 
currentDateAndTime.strftime("%H_%M_%S") + return currentTime + + def process(self, inputs, outputs): + losses, rgbs = ( + outputs["losses"], + outputs["rgbs"].squeeze(0), + ) + typ = list(outputs.keys())[1] + outputs.pop(typ) + outputs.pop("losses") + outputs.pop("rgbs") + results = {k: v.squeeze(0) for k, v in outputs.items()} + if len(self._predictions) == 0: + W, H = self.img_wh + print(results[f"rgb_{typ}"].shape) + img = results[f"rgb_{typ}"].view(H, W, 3).cpu() + img = img.permute(2, 0, 1) # (3, H, W) + img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) + depth = self.visualize_depth(results[f"depth_{typ}"].view(H, W)) # (3, H, W) + img = self.toimage(img) + img_gt = self.toimage(img_gt) + depth = self.toimage(depth) + img.save( + os.path.join(self.image_save_path, f"img_{self.current_time()}.png"), quality=100 + ) + img_gt.save( + os.path.join(self.image_save_path, f"img_gt_{self.current_time()}.png"), quality=100 + ) + depth.save( + os.path.join(self.image_save_path, f"depth_{self.current_time()}.png"), quality=100 + ) + psnr = self.psnr(results[f"rgb_{typ}"], rgbs) + self._predictions.append({"losses": losses.item(), "psnr": psnr.item()}) + + def evaluate(self): + if not dist.is_main_process(): + return {} + else: + predictions = self._predictions + + total_correct_num = OrderedDict() + total_correct_num["losses"] = 0 + total_correct_num["psnr"] = 0 + total_samples = 0 + for prediction in predictions: + losses = prediction["losses"] + psnr = prediction["psnr"] + total_correct_num["losses"] += losses + total_correct_num["psnr"] += psnr + total_samples += 1 + + self._results = OrderedDict() + for key, value in total_correct_num.items(): + self._results[key] = value / total_samples + + return copy.deepcopy(self._results) + + def visualize_depth(self, depth, cmap=cv2.COLORMAP_JET): + """ + depth: (H, W) + """ + x = depth.cpu().numpy() + x = np.nan_to_num(x) # change nan to 0 + mi = np.min(x) # get minimum depth + ma = np.max(x) + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x = (255 * x).astype(np.uint8) + x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) + x_ = T.ToTensor()(x_) # (3, H, W) + return x_ + + def mse(self, image_pred, image_gt, valid_mask=None, reduction="mean"): + value = (image_pred - image_gt) ** 2 + if valid_mask is not None: + value = value[valid_mask] + if reduction == "mean": + return flow.mean(value) + return value + + def psnr(self, image_pred, image_gt, valid_mask=None, reduction="mean"): + return -10 * flow.log(self.mse(image_pred, image_gt, valid_mask, reduction)) / math.log(10) diff --git a/projects/NeRF/modeling/NeRF.py b/projects/NeRF/modeling/NeRF.py new file mode 100644 index 000000000..2a86faf1f --- /dev/null +++ b/projects/NeRF/modeling/NeRF.py @@ -0,0 +1,135 @@ +import oneflow as flow +import oneflow.nn as nn + +from libai.utils import distributed as dist + + + +class Embedding(nn.Module): + def __init__(self, in_channels, N_freqs, logscale=True): + """ + Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 
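+        The output therefore has in_channels * (2 * N_freqs + 1) channels,
+        e.g. 63 for xyz (N_freqs=10) and 27 for viewing directions (N_freqs=4).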
+ in_channels: number of input channels (3 for both xyz and direction) + """ + super(Embedding, self).__init__() + self.N_freqs = N_freqs + self.in_channels = in_channels + self.funcs = [flow.sin, flow.cos] + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) + + if logscale: + freq_bands = 2 ** flow.linspace(0, N_freqs - 1, N_freqs) + else: + freq_bands = flow.linspace(1, 2 ** (N_freqs - 1), N_freqs).cuda() + self.register_buffer("freq_bands", freq_bands.to_global( + placement=dist.get_layer_placement(0), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + )) + + def forward(self, x): + """ + Embeds x to (x, sin(2^k x), cos(2^k x), ...) + Different from the paper, "x" is also in the output + See https://github.com/bmild/nerf/issues/12 + + Inputs: + x: (B, self.in_channels) + + Outputs: + out: (B, self.out_channels) + """ + out = [x] + for freq in self.freq_bands: + for func in self.funcs: + m = func(freq * x) + out += [m] + + return flow.cat(out, -1) + + +class NeRF(nn.Module): + def __init__(self, D=8, W=256, in_channels_xyz=63, in_channels_dir=27, skips=[4]): + """ + D: number of layers for density (sigma) encoder + W: number of hidden units in each layer + in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default) + in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) + skips: add skip connection in the Dth layer + """ + super(NeRF, self).__init__() + self.D = D + self.W = W + self.in_channels_xyz = in_channels_xyz + self.in_channels_dir = in_channels_dir + self.skips = skips + + # xyz encoding layers + for i in range(D): + if i == 0: + layer = nn.Linear(in_channels_xyz, W) + elif i in skips: + layer = nn.Linear(W + in_channels_xyz, W) + else: + layer = nn.Linear(W, W) + layer = nn.Sequential(layer, nn.ReLU(True)) + setattr(self, f"xyz_encoding_{i+1}", layer) + self.xyz_encoding_final = nn.Linear(W, W) + + # direction encoding layers + self.dir_encoding = nn.Sequential(nn.Linear(W + in_channels_dir, W // 2), nn.ReLU(True)) + + # output layers + self.sigma = nn.Linear(W, 1) + self.rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid()) + self.reset_parameters(self) + + def reset_parameters(self, modules): + for module in modules.modules(): + if isinstance(module, nn.Linear): + nn.init.constant_(module.weight.data, 0.1) + if module.bias is not None: + nn.init.constant_(module.weight.data, 0.01) + + def forward(self, x, sigma_only=False): + """ + Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). + For rendering this ray, please see rendering.py + + Inputs: + x: (B, self.in_channels_xyz(+self.in_channels_dir)) + the embedded vector of position and direction + sigma_only: whether to infer sigma only. 
If True, + x is of shape (B, self.in_channels_xyz) + + Outputs: + if sigma_ony: + sigma: (B, 1) sigma + else: + out: (B, 4), rgb and sigma + """ + if not sigma_only: + input_xyz, input_dir = flow.split( + x, [self.in_channels_xyz, self.in_channels_dir], dim=-1 + ) + else: + input_xyz = x + + xyz_ = input_xyz + for i in range(self.D): + if i in self.skips: + xyz_ = flow.cat([input_xyz, xyz_], -1) + xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) + + sigma = self.sigma(xyz_) + if sigma_only: + return sigma + + xyz_encoding_final = self.xyz_encoding_final(xyz_) + + dir_encoding_input = flow.cat([xyz_encoding_final, input_dir], -1) + dir_encoding = self.dir_encoding(dir_encoding_input) + rgb = self.rgb(dir_encoding) + + out = flow.cat([rgb, sigma], -1) + + return out diff --git a/projects/NeRF/modeling/System.py b/projects/NeRF/modeling/System.py new file mode 100644 index 000000000..be1da2608 --- /dev/null +++ b/projects/NeRF/modeling/System.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +from collections import defaultdict + +import oneflow as flow +import oneflow.nn as nn +import torch + +from libai.config.config import configurable +from libai.utils import distributed as dist +from libai.layers.linear import Linear + +from projects.NeRF.modeling.NeRF import Embedding, NeRF + + +class NerfSystem(nn.Module): + @configurable + def __init__( + self, + D=8, + W=256, + in_channels_xyz=63, + in_channels_dir=27, + skips=[4], + N_samples=64, + use_disp=False, + perturb=1.0, + noise_std=1.0, + N_importance=128, + chunk=32 * 1204, + dataset_type="blender", + loss_func=None + ): + super(NerfSystem, self).__init__() + self.N_samples = N_samples + self.use_disp = use_disp + self.perturb = perturb + self.noise_std = noise_std + self.N_importance = N_importance + self.chunk = chunk + self.white_back = True if dataset_type == "blender" else False + self.loss_func = nn.MSELoss() if loss_func == None else loss_func + self.embedding_xyz = Embedding(3, 10) # 10 is the default number + self.embedding_dir = Embedding(3, 4) # 4 is the default number + self.nerf_coarse = NeRF( + D=D, + W=W, + in_channels_xyz=in_channels_xyz, + in_channels_dir=in_channels_dir, + skips=skips, + ) + self.models = [self.nerf_coarse] + if N_importance > 0: + self.nerf_fine = NeRF( + D=D, + W=W, + in_channels_xyz=in_channels_xyz, + in_channels_dir=in_channels_dir, + skips=skips, + ) + self.models += [self.nerf_fine] + self.models = nn.ModuleList(self.models) + + @classmethod + def from_config(cls, cfg): + return { + "D": cfg.D, + "W": cfg.W, + "in_channels_xyz": cfg.in_channels_xyz, + "in_channels_dir": cfg.in_channels_dir, + "skips": cfg.skips, + "N_samples": cfg.N_samples, + "use_disp": cfg.use_disp, + "perturb": cfg.perturb, + "noise_std": cfg.noise_std, + "N_importance": cfg.N_importance, + "chunk": cfg.chunk, + "dataset_type": cfg.dataset_type, + "loss_func": cfg.loss_func, + } + + def sample_pdf(self, bins, weights, 
N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / flow.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = flow.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = flow.cat([flow.zeros_like(cdf[:, :1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = flow.linspace(0, 1, N_importance).to_global( + placement=bins.placement, sbp=bins.sbp + ) + u = u.expand(N_rays, N_importance) + else: + u = flow.rand(N_rays, N_importance).to_global( + placement=bins.placement, sbp=bins.sbp + ) + u = u.contiguous() + + inds = flow.searchsorted(cdf, u, right=True) + below = flow.clamp(inds - 1, 0, 1e6) + above = flow.clamp(inds, -1e6, N_samples_) + + inds_sampled = flow.stack([below, above], -1).view(N_rays, 2 * N_importance) + cdf_g = flow.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = flow.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[..., 1] - cdf_g[..., 0] + + # denom equals 0 means a bin has weight 0, in which case it will not be sampled + # anyway, therefore any value for it is fine (set to 1 here) + denom = flow.masked_fill(denom, denom < eps, 1) + samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0]) + return samples + + def inference( + self, + N_rays, + model, + embedding_xyz, + xyz_, + dir_, + dir_embedded, + z_vals, + noise_std=1, + chunk=1024 * 32, + white_back=False, + weights_only=False, + ): + """ + Helper function that performs model inference. 
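+        It evaluates the NeRF MLP on the embedded samples in chunks of size @chunk and
+        composites the predicted (rgb, sigma) along each ray with the volume rendering
+        equation.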
+ + Inputs: + N_rays: rays (N_rays, 3+3+2), ray origins, directions and near, far depth bounds + model: NeRF model (coarse or fine) + embedding_xyz: embedding module for xyz + xyz_: (N_rays, N_samples_, 3) sampled positions + N_samples_ is the number of sampled points in each ray; + = N_samples for coarse model + = N_samples+N_importance for fine model + dir_: (N_rays, 3) ray directions + dir_embedded: (N_rays, embed_dir_channels) embedded directions + z_vals: (N_rays, N_samples_) depths of the sampled positions + weights_only: do inference on sigma only or not + + Outputs: + if weights_only: + weights: (N_rays, N_samples_): weights of each sample + else: + rgb_final: (N_rays, 3) the final rgb image + depth_final: (N_rays) depth map + weights: (N_rays, N_samples_): weights of each sample + """ + N_samples_ = xyz_.shape[1] + # Embed directions + xyz_ = xyz_.view(-1, 3) # (N_rays*N_samples_, 3) + if not weights_only: + dir_embedded = flow.repeat_interleave(dir_embedded.to_local(), repeats=N_samples_, dim=0).to_global( + placement=xyz_.placement, sbp=xyz_.sbp + ) + # (N_rays*N_samples_, embed_dir_channels) + + # Perform model inference to get rgb and raw sigma + B = xyz_.shape[0] + out_chunks = [] + for i in range(0, B, chunk): + # Embed positions by chunk + xyz_embedded = embedding_xyz(xyz_[i: i + chunk]) + if not weights_only: + xyzdir_embedded = flow.cat([xyz_embedded, dir_embedded[i: i + chunk]], 1) + else: + xyzdir_embedded = xyz_embedded + xyzdir_embedded = xyzdir_embedded + out_chunk = model(xyzdir_embedded) # , sigma_only=weights_only) + out_chunks = out_chunks + [out_chunk] + + out = flow.cat(out_chunks, 0) + if weights_only: + sigmas = out.view(N_rays, N_samples_) + else: + rgbsigma = out.view(N_rays, N_samples_, 4) + rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3) + sigmas = rgbsigma[..., 3] # (N_rays, N_samples_) + + # Convert these values using volume rendering (Section 4) + deltas = z_vals[:, 1:].clone() - z_vals[:, :-1].clone() # (N_rays, N_samples_-1) + delta_inf = 1e10 * flow.ones_like(deltas[:, :1]).to_global(sbp=deltas.sbp, + placement=deltas.placement) # (N_rays, 1) the last delta is infinity + deltas = flow.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) + + # Multiply each distance by the norm of its corresponding direction ray + # to convert to real world distance (accounts for non-unit directions). + deltas = deltas * flow.norm(dir_.unsqueeze(1), dim=-1) + sigmas = sigmas # TODO: error in it + noise = flow.randn(sigmas.shape).to_global( + placement=sigmas.placement, sbp=sigmas.sbp + ) * noise_std + + # compute alpha by the formula (3) + alphas = 1 - flow.exp(-deltas * flow.relu(sigmas + noise)) # (N_rays, N_samples_) + ne_alphas = 1 - alphas + 1e-10 + alphas_shifted = flow.cat( + [flow.ones_like(alphas[:, :1]).to_global(sbp=alphas.sbp, placement=alphas.placement), ne_alphas], -1 + ) # [1, a1, a2, ...] 
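+        # strictly alphas_shifted = [1, 1-a1, 1-a2, ...]; its exclusive cumulative
+        # product is the transmittance T_i = prod_{j<i} (1 - a_j), so that
+        # weights_i = a_i * T_i as in formula (3) referenced above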
+ alphas = alphas + rgbs = rgbs + weights = alphas * cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_) + # weights = alphas * alphas_shifted[:, :-1] # (N_rays, N_samples_) + weights_sum = weights.sum(1) # (N_rays), the accumulated opacity along the rays + # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically + if weights_only: + return weights + + # compute final weighted outputs + rgb_final = flow.sum(weights.unsqueeze(-1) * rgbs, -2) # (N_rays, 3) + depth_final = flow.sum(weights * z_vals, -1) # (N_rays) + + if white_back: + rgb_final = rgb_final + 1 - weights_sum.unsqueeze(-1) + + return rgb_final, depth_final, weights + + def render_rays( + self, + models, + embeddings, + rays, + N_samples=64, + use_disp=False, + perturb=0.0, + N_importance=0.0, + test_time=False, + noise_std=1.0, + chunk=1024 * 32, + white_back=False, + ): + + # Extract models from lists + model_coarse = models[0] + embedding_xyz = embeddings[0] + embedding_dir = embeddings[1] + + # Decompose the inputs + N_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) + + # Embed direction + dir_embedded = embedding_dir(rays_d) # (N_rays, embed_dir_channels) + + # Sample depth points + z_steps = flow.linspace(0, 1, N_samples).to_global(sbp=rays.sbp, placement=rays.placement) # (N_samples) + if not use_disp: # use linear sampling in depth space + z_vals = near * (1 - z_steps) + far * z_steps + else: # use linear sampling in disparity space + z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) + + z_vals = z_vals.expand(N_rays, N_samples) + + if perturb > 0: # perturb sampling depths (z_vals) + z_vals_mid = 0.5 * ( + z_vals[:, :-1] + z_vals[:, 1:] + ) # (N_rays, N_samples-1) interval mid points + # get intervals between samples + upper = flow.cat([z_vals_mid, z_vals[:, -1:]], -1) + lower = flow.cat([z_vals[:, :1], z_vals_mid], -1) + + perturb_rand = perturb * flow.rand(z_vals.shape).to_global(sbp=rays.sbp, placement=rays.placement) + z_vals = lower + (upper - lower) * perturb_rand + + xyz_coarse_sampled = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze( + 2 + ) # (N_rays, N_samples, 3) + + if test_time: + weights_coarse = self.inference( + rays.shape[0], + model_coarse, + embedding_xyz, + xyz_coarse_sampled, + rays_d, + dir_embedded, + z_vals, + noise_std, + chunk, + white_back, + weights_only=True, + ) + result = {"opacity_coarse": weights_coarse.sum(1)} + else: + rgb_coarse, depth_coarse, weights_coarse = self.inference( + rays.shape[0], + model_coarse, + embedding_xyz, + xyz_coarse_sampled, + rays_d, + dir_embedded, + z_vals, + noise_std, + chunk, + white_back, + weights_only=False, + ) + result = { + "rgb_coarse": rgb_coarse, + "depth_coarse": depth_coarse, + "opacity_coarse": weights_coarse.sum(1), + } + + if N_importance > 0: # sample points for fine model + z_vals_mid = 0.5 * ( + z_vals[:, :-1] + z_vals[:, 1:] + ) # (N_rays, N_samples-1) interval mid points + z_vals_ = self.sample_pdf( + z_vals_mid, weights_coarse[:, 1:-1], N_importance, det=(perturb == 0) + ) + # detach so that grad doesn't propogate to weights_coarse from here + + z_vals, _ = flow.sort(flow.cat([z_vals, z_vals_], -1), -1) + + xyz_fine_sampled = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(2) + # (N_rays, N_samples+N_importance, 3) + + model_fine = models[1] + rgb_fine, depth_fine, weights_fine = self.inference( + rays.shape[0], + model_fine, + embedding_xyz, + xyz_fine_sampled, + rays_d, + dir_embedded, + z_vals, + 
noise_std, + chunk, + white_back, + weights_only=False, + ) + result["rgb_fine"] = rgb_fine + result["depth_fine"] = depth_fine + result["opacity_fine"] = weights_fine.sum(1) + + return result + + def forward_features(self, rays): + """Do batched inference on rays using chunk.""" + B = rays.shape[0] + results = defaultdict(list) + for i in range(0, B, self.chunk): + rendered_ray_chunks = self.render_rays( + self.models, + [self.embedding_xyz, self.embedding_dir], + rays[i: i + self.chunk], + self.N_samples, + self.use_disp, + self.perturb, + self.N_importance, + False, + self.noise_std, + self.chunk, # chunk size is effective in val mode + self.white_back, + ) + + for k, v in rendered_ray_chunks.items(): + results[k] += [v] + for k, v in results.items(): + results[k] = flow.cat(v, 0) + return results + + def forward(self, rays, rgbs, c2w=None, valid_mask=None): + if c2w == None: + results = self.forward_features(rays) + losses = self.loss_func(results["rgb_coarse"], rgbs) + if "rgb_fine" in results: + losses += self.loss_func(results["rgb_fine"], rgbs) + return {"losses": losses} + else: + rays = rays.squeeze() # (H*W, 3) + rgbs = rgbs.squeeze() # (H*W, 3) + results = self.forward_features(rays) + losses = self.loss_func(results["rgb_coarse"], rgbs) + if "rgb_fine" in results: + losses += self.loss_func(results["rgb_fine"], rgbs) + typ = "fine" if "rgb_fine" in results else "coarse" + re = collections.OrderedDict() + re["losses"] = losses + re[typ] = flow.Tensor([0.]).to_global(sbp=losses.sbp,placement=losses.placement) + for key,value in results.items(): + re[key] = value.unsqueeze(0) + re["rgbs"] = rgbs.unsqueeze(0) + return re + + +def cumprod(inputs, dim=0): + ndim = inputs.ndim + assert 0 <= dim < ndim or -ndim <= dim <= 0, f"{dim} must between [0,{ndim}) or [-{ndim},0]" + if dim < 0: + dim = ndim + dim + res = flow.index_select(inputs, dim, flow.LongTensor([0]).to_global(sbp=inputs.sbp,placement=inputs.placement)) + result = [res] + for i in range(1, inputs.shape[dim]): + res = res * flow.index_select(inputs, dim, flow.LongTensor([i]).to_global(sbp=inputs.sbp,placement=inputs.placement)) # + result.append(res) + result = flow.cat(result,dim=dim) + return result \ No newline at end of file diff --git a/projects/NeRF/optimizers/Radam.py b/projects/NeRF/optimizers/Radam.py new file mode 100644 index 000000000..d431e120d --- /dev/null +++ b/projects/NeRF/optimizers/Radam.py @@ -0,0 +1,113 @@ +import math + +import oneflow as flow +from oneflow.optim import Optimizer + + +class RAdam(Optimizer): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and ( + param["betas"][0] != betas[0] or param["betas"][1] != betas[1] + ): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)], + ) + super(RAdam, self).__init__(params, defaults) + + def 
__setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = flow.zeros_like(p_data_fp32) + state["exp_avg_sq"] = flow.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + p_data_fp32.add_(-step_size * group["lr"], exp_avg) + p.data.copy_(p_data_fp32) + + return loss diff --git a/projects/NeRF/optimizers/Ranger.py b/projects/NeRF/optimizers/Ranger.py new file mode 100644 index 000000000..ed5c21da8 --- /dev/null +++ b/projects/NeRF/optimizers/Ranger.py @@ -0,0 +1,173 @@ +import math + +import oneflow as flow +from oneflow.optim import Optimizer + + +class Ranger(Optimizer): + def __init__( + self, + params, + lr=1e-3, + alpha=0.5, + k=6, + N_sma_threshhold=5, + betas=(0.95, 0.999), + eps=1e-5, + weight_decay=0, + ): + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f"Invalid slow update rate: {alpha}") + if not 1 <= k: + raise ValueError(f"Invalid lookahead steps: {k}") + if not lr > 0: + raise ValueError(f"Invalid Learning Rate: {lr}") + if not eps > 0: + raise ValueError(f"Invalid eps: {eps}") + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. + + # prep defaults and init flow.optim base + defaults = dict( + lr=lr, + alpha=alpha, + k=k, + step_counter=0, + betas=betas, + N_sma_threshhold=N_sma_threshhold, + eps=eps, + weight_decay=weight_decay, + ) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # now we can get to work... 
+ # removed as we now use step from RAdam...no need for duplicate step counting + # for group in self.param_groups: + # group["step_counter"] = 0 + # print("group step counter init") + + # look ahead params + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # self.first_run_check=0 + + # lookahead weights + # 9/2/19 - lookahead param tensors have been moved to state storage. + # This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs. + + # self.slow_weights = [[p.clone().detach() for p in group['params']] + # for group in self.param_groups] + + # don't use grad for lookahead weights + # for w in it.chain(*self.slow_weights): + # w.requires_grad = False + + def __setstate__(self, state): + print("set state called") + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. + # Uncomment if you need to use the actual closure... + + # if closure is not None: + # loss = closure() + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("Ranger optimizer does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if ( + len(state) == 0 + ): # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state["step"] = 0 + state["exp_avg"] = flow.zeros_like(p_data_fp32) + state["exp_avg_sq"] = flow.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state["slow_buffer"] = p.data.clone() + + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + + buffered = self.radam_buffer[int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + else: + step_size = 1.0 / (1 - beta1 ** state["step"]) + buffered[2] = step_size + + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group["lr"], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... 
+                # we do it at the param level instead of group level
+                if state["step"] % group["k"] == 0:
+                    slow_p = state["slow_buffer"]  # get access to slow param tensor
+                    slow_p.add_(
+                        self.alpha, p.data - slow_p
+                    )  # (fast weights - slow weights) * alpha
+                    p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor
+
+        return loss
diff --git a/projects/NeRF/optimizers/__init__.py b/projects/NeRF/optimizers/__init__.py
new file mode 100644
index 000000000..c9a29ec89
--- /dev/null
+++ b/projects/NeRF/optimizers/__init__.py
@@ -0,0 +1,2 @@
+from .Radam import RAdam
+from .Ranger import Ranger
diff --git a/projects/__init__.py b/projects/__init__.py
new file mode 100644
index 000000000..e69de29bb

From 46242fd26b10cbf39a49031f4baee893ab5bb959 Mon Sep 17 00:00:00 2001
From: shaoshitong <1090784053@qq.com>
Date: Wed, 17 Aug 2022 12:57:45 +0800
Subject: [PATCH 2/2] Notes

---
 projects/NeRF/configs/config_nerf.py       |  2 +-
 projects/NeRF/evaluation/nerf_evaluator.py |  4 ++++
 projects/NeRF/modeling/System.py           | 11 +++++++++++
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/projects/NeRF/configs/config_nerf.py b/projects/NeRF/configs/config_nerf.py
index 30a7ad367..a179a1cc3 100644
--- a/projects/NeRF/configs/config_nerf.py
+++ b/projects/NeRF/configs/config_nerf.py
@@ -38,7 +38,7 @@ def get_nerf_dataset(dataset_type="Blender"):
 train.train_micro_batch_size = 1024
 train.test_micro_batch_size = 1
 train.dataset_type = "Blender"
-train.blender_dataset_path = "/path/to/llff"
+train.blender_dataset_path = "/home/Bigdata/Nerf/nerf_synthetic/chair"
 train.llff_dataset_path = "/path/to/llff"
 train.train_epoch = 16 if train.dataset_type == "Blender" else 30
 train.warmup_ratio = 0 / train.train_epoch
diff --git a/projects/NeRF/evaluation/nerf_evaluator.py b/projects/NeRF/evaluation/nerf_evaluator.py
index 4bd809064..a26df1670 100644
--- a/projects/NeRF/evaluation/nerf_evaluator.py
+++ b/projects/NeRF/evaluation/nerf_evaluator.py
@@ -33,6 +33,10 @@ def process(self, inputs, outputs):
             outputs["losses"],
             outputs["rgbs"].squeeze(0),
         )
+        """
+        Notes:
+        Same as above: an extra squeeze(0) is needed here to recover the original tensors.
+        """
         typ = list(outputs.keys())[1]
         outputs.pop(typ)
         outputs.pop("losses")
diff --git a/projects/NeRF/modeling/System.py b/projects/NeRF/modeling/System.py
index be1da2608..64abf5548 100644
--- a/projects/NeRF/modeling/System.py
+++ b/projects/NeRF/modeling/System.py
@@ -414,6 +414,17 @@ def forward(self, rays, rgbs, c2w=None, valid_mask=None):
         for key,value in results.items():
             re[key] = value.unsqueeze(0)
         re["rgbs"] = rgbs.unsqueeze(0)
+            """
+            Notes:
+            The output here must be a dict and cannot carry str, int or other plain
+            types; downstream code then computes metrics from these entries. Usually,
+            however, what I want to pass to the evaluator is not a metric but
+            intermediate values that need to be saved. Since the first batch_idx of
+            forward is not reachable here, the computation has to live in the evaluator.
+
+            If possible, add an extra dict here whose entries skip these checks and are
+            passed straight through to the evaluator, so that users can extend it.
+            """
             return re