DDP Example

https://zhuanlan.zhihu.com/p/602305591
https://zhuanlan.zhihu.com/p/178402798


On model saving and loading: there are really two cases, state dicts saved with the "module." prefix and without it (the Zhihu posts above save with "module.").

Explanation of the two cases (with vs. without the prefix):
https://blog.csdn.net/hustwayne/article/details/120324639

In this project the model is saved without the "module." prefix, so when loading pretrained weights some keys have to be remapped or removed; later this was switched to mmcv's load_checkpoint, which matches key-value pairs adaptively.
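
A minimal sketch of both loading paths, assuming a plain nn.Module `model` and an illustrative checkpoint path; mmcv's load_checkpoint is the real API and strips the "module." prefix itself:

from collections import OrderedDict
import torch

def strip_module_prefix(state_dict):
    # drop the leading "module." that DDP adds to every parameter key
    return OrderedDict((k[7:] if k.startswith("module.") else k, v)
                       for k, v in state_dict.items())

# manual path: strip the prefix, then load non-strictly
# state = torch.load("checkpts/model_30000.pt", map_location="cpu")
# model.load_state_dict(strip_module_prefix(state), strict=False)

# mmcv path: adaptive key matching, with mismatches logged
# from mmcv.runner import load_checkpoint
# load_checkpoint(model, "checkpts/model_30000.pt", map_location="cpu", strict=False)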


Old model: main.py DDP example

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""
import warnings
warnings.filterwarnings("error", "MAGMA*")
from fire import Fire
import argparse
import torch
import src
import os
"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
import numpy as np
from time import time
from torch import nn
from src.models_goe_1129_nornn_2d_2_ori import compile_model
# from src.models_goe_1129_nornn_2d_2_zj import compile_model
from tensorboardX import SummaryWriter
from src.data_tfmap_newcxy_nextmask2 import compile_data  # out-of-range points added for both the current and stitched frames
# from src.data_tfmap_newcxy_ori import compile_data  # without out-of-range points
#from src.data_tfmap import compile_data
from src.tools import SimpleLoss, RegLoss, SegLoss, BCEFocalLoss, get_batch_iou, get_val_info, denormalize_img
import sys
import cv2
from collections import OrderedDict

from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.rendering.neuconw_helper import NeuconWHelper
import open3d as o3d
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
os.environ['LOCAL_RANK'] = "0,1"
torch.set_num_threads(8)


# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# os.environ['RANK'] = "0"
# os.environ['WORLD_SIZE'] = "1"
# os.environ['MASTER_ADDR'] = "localhost"
# os.environ['MASTER_PORT'] = "12345"
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

"动静态分离里, 构造sample时rays要加一个type的维度"



def project_from_lidar_2_cam(img, points, rots, trans, intrins, post_rots, post_trans):
    color_arr = np.zeros((points.shape[0], 3))
    # ego_to_cam
    points -= trans

    points = torch.inverse(rots.view(1, 3, 3)).matmul(points.unsqueeze(-1)).squeeze(-1)
    depths = points[..., 2:]
    points = torch.cat((points[..., :2] / depths, torch.ones_like(depths)), -1)

    # cam_to_img
    points = intrins.view(1, 3, 3).matmul(points.unsqueeze(-1)).squeeze(-1)
    points = post_rots.view(1, 3, 3).matmul(points.unsqueeze(-1)).squeeze(-1)
    points = points + post_trans.view(1, 3)
    # points = points.view(B, N, Z, Y, X, 3)[..., :2]
    points = points.view(-1, 3).int().numpy()

    # imshow
    # pts = points[0,0,2,...].reshape(-1, 2).cpu().numpy()
    # image = np.zeros((128, 352, 3), dtype=np.uint8)
    # for i in range(pts.shape[0]):
    #     cv2.circle(image, (int(pts[i, 0]), int(pts[i, 1])), 1, (255, 255, 255), 2)
    # cv2.imshow("local_map", image)
    # cv2.waitKey(-1)

    # normalize_coord
    img = np.array(img)
    # for i in range(points.shape[0]):
    #     cv2.circle(img, (points[i,0], points[i,1]), 1, tuple(color_arr[i].tolist()), -1)
    return img

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)
    print(config)

    # args.local_rank = 2
    print("sssss",args.local_rank)
    # DDP backend initialization:
    #   a. use local_rank to select which GPU this process uses
    #   b. initialize DDP; the default nccl backend is fine for GPUs, but CPU
    #      training needs a different backend (e.g. gloo)
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    
    version = "0"
    #dataroot = "/defaultShare/aishare/share"
    dataroot = "/data/zjj/data/aishare/share"
    nepochs=10000
 
    final_dim=(128, 352)
    max_grad_norm=5.0
    #max_grad_norm=2.0
    pos_weight=2.13

    logdir=f'/mnt/sdb/xzq/occ_project/occ_nerf_st/log/{args.exp_name}'
   
    xbound=[0.0, 102., 0.85]
    ybound=[-10.0, 10.0, 0.5]
    zbound=[-2.0, 4.0, 1]
    dbound=[3.0, 103.0, 2.]

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-2.0, 4.0, 1]
    # dbound=[3.0, 103.0, 2.]

    bsz=4
    seq_len=5 #5
    nworkers=1 #2
    lr=1e-4
    # weight_decay=1e-7
    weight_decay = 0
    sample_num = 1024
    datatype = "single"    #multi   single
        
    torch.backends.cudnn.benchmark = True
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }
 
    ### bevgnd
    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }
    
    train_sampler, val_sampler, trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                                          grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
                                          parser_name='segmentationdata', datatype=datatype)
    print("train lengths: ", len(trainloader))
    # print("val lengths: ", len(valloader))
    # device = torch.device('cpu') if gpuid < 0 else torch.device(f'cuda:{gpuid}')
    writer = SummaryWriter(logdir=logdir)
    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, writer=writer)
    counter = 0
 
    if 0:
        print('==> loading existing model')
        model_info = torch.load('/data/zjj/project/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231113_nornn_120_21_6_b2_lall_sample1024_v1/checkpts/model_30000.pt')
        # model_info = torch.load('/zhangjingjuan/NeRF/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231114_nornn_v2/checkpts/model_50000.pt')
        #model_info = torch.load('/data/zjj/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231120_nornn_v1/checkpts/model_18000.pt')

        counter = 0

        new_state_dict = OrderedDict()
        for k, v in model_info.items():
            if 'semantic_net' in k:
                continue
            # if 'SEnet' in k or 'voxels' in k or 'bevencode.downchannel' in k or 'bevencode.up3' in k or 'bevencode.conv1_block' in k:
            #    continue
            # if 'voxels' in k:
            #     continue
            # if 'color_net' in k:
            #     continue
    
            
            if "neuconw_helper" in k:
                name = k[22:]
            elif "module." in k:
                name = k[7:]  # remove "module."
                #print(k)
            else:
                name = k
            
            '''
            if "module." in k:
                name = k[7:]  # remove "module."
            else:
                name = k
            '''
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict, strict=False)
        model.dx.data = torch.tensor([0.85, 0.5, 1.0]).to(device)
        # model.dx.data = torch.tensor([0.5, 0.5, 0.5]).to(device)
        # model.nx.data = torch.tensor([204, 40, 12]).to(device)
        # model.bx.data = torch.tensor([0.25, -9.75, -1.75]).to(device)
    # move the model to its GPU before wrapping it in DDP
    model.to(device)

    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, writer)
    # DDP wrapping
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                            output_device=args.local_rank,find_unused_parameters=True)
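        # NOTE: find_unused_parameters=True makes DDP traverse the autograd graph every
        # iteration to find parameters without gradients; it adds overhead and is only
        # needed when parts of the model are conditionally unused.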

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)


    loss_fn = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_ll = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_sl = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_zc = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_ar = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_rs = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_cl = SimpleLoss(pos_weight).cuda(args.local_rank)
    loss_fn_lf_pred = SimpleLoss(pos_weight).cuda(args.local_rank)
    loss_fn_lf_norm = RegLoss(0).cuda(args.local_rank)
    # loss_fn_patch = SimpleLoss(pos_weight).cuda(args.local_rank)
    
    val_step = 1000
    t1 = time()
    t2 = time()
    model.train()
    scaler = torch.cuda.amp.GradScaler()
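    # NOTE: this GradScaler is currently unused -- the scaler.scale()/scaler.step()
    # calls below are commented out and plain FP32 backward/step is used instead.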

    train_bev = False # False
    train_occ = True
    for epoch in range(nepochs):
        np.random.seed()
        train_sampler.set_epoch(epoch)
        start = time()
        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d) in enumerate(trainloader):
            t0 = time()
            t =  t0 - t1
            tt = t0 - t2
            t1 = time()

            # print("img_path = ", img_paths[-1][0])
            if 1:
                seg_preds1, seg_preds2, lf_preds, _, _ , loss_osr = model(imgs.to(device), rots.to(device), trans.to(device), intrins.to(device), dist_coeffss.to(device), post_rots.to(device), 
                        post_trans.to(device), cam_pos_embeddings.to(device), fork_scales_gt.to(device),fork_offsets_gt.to(device),fork_oris_gt.to(device), rays.to(device), theta_mat_2d.to(device), counter, 'train')

                if train_bev:
                    lf_pred = lf_preds[:, :, :1].contiguous()
                    lf_norm = lf_preds[:, :, 1:(1+4)].contiguous()
                    # lf_kappa = lf_preds[:, :, (1+4):(1+4+2)].contiguous()

                    lf_out = lf_pred.sigmoid()
                    out = seg_preds1.sigmoid()
                    out1 = seg_preds2.sigmoid()

                    binimgs = binimgs.to(device)
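                    # NOTE: mask_gt is not defined anywhere in this script; this train_bev
                    # branch is disabled (train_bev=False above) and needs a validity mask
                    # from the data loader before it can run.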
                    seg_preds_0 = seg_preds1[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs0 = binimgs[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_1 = seg_preds1[:, :, 1] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs1 = binimgs[:, :, 1] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_2 = seg_preds1[:, :, 2] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs2 = binimgs[:, :, 2] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_3 = seg_preds2[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs3 = binimgs[:, :, 3] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_4 = seg_preds1[:, :, 3] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs4 = binimgs[:, :, 4] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_5 = seg_preds1[:, :, 4] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs5 = binimgs[:, :, 5] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])

                    loss_ll = loss_fn_ll(seg_preds1[:, :, 0].contiguous(), binimgs[:, :, 0].contiguous()) + loss_fn_ll(
                        seg_preds_0.contiguous(), binimgs0.contiguous())
                    loss_sl = loss_fn_sl(seg_preds1[:, :, 1].contiguous(), binimgs[:, :, 1].contiguous()) + loss_fn_sl(
                        seg_preds_1.contiguous(), binimgs1.contiguous())
                    loss_zc = loss_fn_zc(seg_preds1[:, :, 2].contiguous(), binimgs[:, :, 2].contiguous()) + loss_fn_zc(
                        seg_preds_2.contiguous(), binimgs2.contiguous())
                    loss_ar = loss_fn_ar(seg_preds2[:, :, 0].contiguous(), binimgs[:, :, 3].contiguous()) + loss_fn_ar(
                        seg_preds_3.contiguous(), binimgs3.contiguous())
                    loss_rs = loss_fn_rs(seg_preds1[:, :, 3].contiguous(), binimgs[:, :, 4].contiguous()) + loss_fn_rs(
                        seg_preds_4.contiguous(), binimgs4.contiguous())
                    loss_cl = loss_fn_cl(seg_preds1[:, :, 4].contiguous(), binimgs[:, :, 5].contiguous()) + loss_fn_cl(
                        seg_preds_5.contiguous(), binimgs5.contiguous())
            
                    # lf_norm_gt0 = torch.unsqueeze(torch.sum(lf_norm_gt, 2), 2)
                    norm_mask = (lf_norm_gt > -500)
                    # norm_mask = ((lf_label_gt>-0.5)).repeat(1, 1, 4, 1, 1)

                    scale_lf = 5.
                    loss_lf = loss_fn_lf_pred(lf_pred, lf_label_gt.to(device)) + loss_fn_lf_norm(lf_norm[norm_mask], scale_lf*lf_norm_gt[norm_mask].to(device))
                    # loss_ilf = loss_fn_lf_pred(lf_ipred, lf_label_gt.to(device)) + loss_fn_lf_norm(scale_lf*lf_inorm[norm_mask], scale_lf*lf_norm_gt[norm_mask].to(device))
                    # loss_lf_crop = loss_fn_patch(lf_crop_preds, fork_patch_gt.to(device))
                    # print('lf_loss = ', loss_lf)
                    loss_gnd = loss_lf + loss_ll + loss_sl + loss_zc + loss_ar + loss_rs + loss_cl# + loss_ilf
                    # loss = loss_ll + loss_sl + loss_zc + loss_ar + loss_rs + loss_cl

                if train_occ:
                    # loss = loss_gnd + loss_osr
                    loss = loss_osr
                    #loss = loss_gnd
                    opt.zero_grad()
                    # scaler.scale(loss).backward()
                    loss.backward()
                    clip_debug = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    opt.step()
            # except:
                # continue
            # scaler.step(opt)
            # scaler.update()
            t2 = time()
            writer.add_scalar('train/clip_debug', clip_debug.item(), counter)
            if counter % 10 == 0 and args.local_rank==0:
                print(counter, loss.item(),  time() - start)

            if train_bev:
                if counter % 10 == 0 and args.local_rank==0:
                    # print(loss_lf.item(), loss_ll.item(), loss_sl.item(), loss_zc.item(), loss_ar.item(), loss_rs.item(), loss_cl.item())
                    # print(counter, loss.item(), loss_gnd.item(), loss_osr.item(), time() - start)
                    # print(counter, loss.item(), time() - start)
                    writer.add_scalar('train/loss', loss, counter)
                    writer.add_scalar('train/loss_ll', loss_ll, counter)
                    writer.add_scalar('train/loss_sl', loss_sl, counter)
                    writer.add_scalar('train/loss_zc', loss_zc, counter)
                    writer.add_scalar('train/loss_ar', loss_ar, counter)
                    writer.add_scalar('train/loss_rs', loss_rs, counter)
                    writer.add_scalar('train/loss_cl', loss_cl, counter)
                    writer.add_scalar('train/loss_lf', loss_lf, counter)
                    # writer.add_scalar('train/loss_lf_crop', loss_lf_crop, counter)
                    writer.add_scalar('train/loss_gnd', loss_gnd, counter)
                    writer.add_scalar('train/loss_osr', loss_osr, counter)
                    writer.add_scalar('train/clip_debug', clip_debug.item(), counter)

                if counter % 50 == 0 and args.local_rank==0:
                    _, _, iou_ll = get_batch_iou(seg_preds1[:, :, 0].contiguous(), binimgs[:, :, 0].contiguous())
                    _, _, iou_sl = get_batch_iou(seg_preds1[:, :, 1].contiguous(), binimgs[:, :, 1].contiguous())
                    _, _, iou_zc = get_batch_iou(seg_preds1[:, :, 2].contiguous(), binimgs[:, :, 2].contiguous())
                    _, _, iou_ar = get_batch_iou(seg_preds2[:, :, 0].contiguous(), binimgs[:, :, 3].contiguous())
                    _, _, iou_rs = get_batch_iou(seg_preds1[:, :, 3].contiguous(), binimgs[:, :, 4].contiguous())
                    _, _, iou_cl = get_batch_iou(seg_preds1[:, :, 4].contiguous(), binimgs[:, :, 5].contiguous())
                    writer.add_scalar('train/iou_ll', iou_ll, counter)
                    writer.add_scalar('train/iou_sl', iou_sl, counter)
                    writer.add_scalar('train/iou_zc', iou_zc, counter)
                    writer.add_scalar('train/iou_ar', iou_ar, counter)
                    writer.add_scalar('train/iou_rs', iou_rs, counter)
                    writer.add_scalar('train/iou_cl', iou_cl, counter)
                    writer.add_scalar('train/epoch', epoch, counter)
                    writer.add_scalar('train/step_time', t, counter)
                    writer.add_scalar('train/data_time', tt, counter)

                if counter % 200 == 0 and args.local_rank==0:
                    fH = final_dim[0]
                    fW = final_dim[1]
                    image0 =np.array(denormalize_img(imgs[0, 0]))
                    image1 =np.array(denormalize_img(imgs[0, 1]))
                    # image2 =np.array(denormalize_img(imgs[0, 2]))
                    # image3 =np.array(denormalize_img(imgs[0, 3]))
                    writer.add_image('train/image/00', image0, global_step=counter, dataformats='HWC')
                    writer.add_image('train/image/01', image1, global_step=counter, dataformats='HWC')
                    # writer.add_image('train/image/02', image2, global_step=counter, dataformats='HWC')
                    # writer.add_image('train/image/03', image3, global_step=counter, dataformats='HWC')
                    writer.add_image('train/binimg/0', (binimgs[0, 1, 0:1]+1.)/2.01, global_step=counter)

                    writer.add_image('train/binimg/1', (binimgs[0, 1, 1:2]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/2', (binimgs[0, 1, 2:3]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/3', (binimgs[0, 1, 3:4]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/4', (binimgs[0, 1, 4:5]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/5', (binimgs[0, 1, 5:6]+1.)/2.01, global_step=counter)
                    writer.add_image('train/out/0', out[0, 1, 0:1], global_step=counter)
                    writer.add_image('train/out/1', out[0, 1, 1:2], global_step=counter)
                    writer.add_image('train/out/2', out[0, 1, 2:3], global_step=counter)
                    writer.add_image('train/out/3', out1[0, 1, 0:1], global_step=counter)
                    writer.add_image('train/out/4', out[0, 1, 3:4], global_step=counter)
                    writer.add_image('train/out/5', out[0, 1, 4:5], global_step=counter)

                    writer.add_image('train/lf_label_gt/0', (lf_label_gt[0, 1]+1.)/2.01, global_step=counter)
                    writer.add_image('train/lf_out/0', lf_out[0, 1], global_step=counter)
                    # writer.add_image('train/fork_patch/0', (fork_patch_gt[0, 1, 0:1]+1.)/2.01, global_step=counter)
                    # writer.add_image('train/fork_patch/1', (fork_patch_gt[0, 1, 1:2]+1.)/2.01, global_step=counter)
                    # writer.add_image('train/lf_crop_out/0', lf_crop_out[0, 1, 0:1], global_step=counter)
                    # writer.add_image('train/lf_crop_out/1', lf_crop_out[0, 1, 1:2], global_step=counter)

                    seg_ll_data = binimgs[0, 1, 0].cpu().detach().numpy()
                    seg_cl_data = binimgs[0, 1, 5].cpu().detach().numpy()

                    lf_label_data_gt = lf_label_gt[0, 1, 0].numpy()
                    lf_norm_data_gt = lf_norm_gt[0, 1].numpy()

                    lf_norm_show = np.zeros((480, 160, 3), dtype=np.uint8)
                    ys, xs = np.where(seg_ll_data > 0.5)
                    lf_norm_show[ys, xs, :] = 255

                    ys, xs = np.where(lf_label_data_gt> -0.5)
                    lf_norm_show[ys, xs, :] = 128

                    labels = np.logical_or(seg_ll_data[ys, xs] > 0.5, seg_cl_data[ys, xs] > 0.5)
                    ys = ys[labels]
                    xs = xs[labels]
                    scale = 1.7

                    if ys.shape[0] > 0:
                        for mm in range(0, ys.shape[0], 10):
                            y = ys[mm]
                            x = xs[mm]
                            norm0 = lf_norm_data_gt[0:2, y, x]
                            if norm0[0] == -999.:
                                continue
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm0[0]*50)), y + int(round(scale * (norm0[1]+1)*50))), (0, 0, 255))
                            norm1 = lf_norm_data_gt[2:4, y, x]
                            if norm1[0] == -999.:
                                continue
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm1[0]*50)), y + int(round(scale * (norm1[1]+1)*50))), (255, 0, 0))
                    writer.add_image('train/lf_norm_gt/0',  lf_norm_show, global_step=counter, dataformats='HWC')

                    lf_norm_data = lf_norm[0, 1].detach().cpu().numpy()
                    ys, xs = np.where(np.logical_or(seg_ll_data > 0.5, seg_cl_data > 0.5))
                    lf_norm_show = np.zeros((480, 160, 3), dtype=np.uint8)
                    if ys.shape[0] > 0:
                        for mm in range(0, ys.shape[0], 10):
                            y = ys[mm]
                            x = xs[mm]
                            norm0 = lf_norm_data[0:2, y, x]/scale_lf
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm0[0]*50)), y+int(round(scale * (norm0[1]+1)*50))), (0, 0, 255))
                            norm1 = lf_norm_data[2:4, y, x]/scale_lf
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm1[0]*50)), y+int(round(scale * (norm1[1]+1)*50))), (255, 0, 0))
                    writer.add_image('train/lf_norm/0',  lf_norm_show, global_step=counter, dataformats='HWC')

            if counter % (1*val_step) == 0 and args.local_rank==0:
                model.eval()
                #mname = os.path.join(logdir, "model{}.pt".format(0))
                #mname = os.path.join(logdir, "model{}.pt".format(counter))#counter))
                #print('saving', mname)
                #torch.save(model.state_dict(), mname)

                checkpt_dir = f"{config.TRAINER.SAVE_DIR}/{args.exp_name}/checkpts/"
                os.makedirs(checkpt_dir, exist_ok=True)
                mname = os.path.join(checkpt_dir, f"model_{counter}.pt")
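                # NOTE: the model is DDP-wrapped here, so these state_dict keys carry the
                # "module." prefix; saving model.module.state_dict() instead would drop it
                # and avoid the prefix stripping at load time.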
                torch.save(model.state_dict(), mname)
                model.train()  # restore training mode (eval() above would otherwise persist)

            counter += 1


  
if __name__ == '__main__':
    main()

train.sh

PORT=${PORT:-29512}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# nproc_per_node must match the number of GPUs in CUDA_VISIBLE_DEVICES
# (a trailing comment after "\" would break the line continuation, so it lives here)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --master_addr=$MASTER_ADDR \
    --master_port=$PORT \
    --nproc_per_node=2 \
    main_multii_conv2d.py \
    --cfg_path /mnt/sdb/xzq/occ_project/occ_nerf_st/src/config/train_tongfan_ngp.yaml \
    --num_epochs 50 \
    --num_gpus 2 \
    --num_nodes 1 \
    --batch_size 2048 \
    --test_batch_size 512 \
    --num_workers 2 \
    --exp_name models_20231207_nornn_2d_2_ori_theatmatvalid__st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_bevgrid_conf_adjustnearfar2

Note:

  1. Single-node multi-GPU training apparently does not need an explicit master address/port (the launcher defaults suffice).
  2. Multi-node multi-GPU training does.
# single-node multi-GPU example
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
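
For multi-node multi-GPU training, every node must point at the same rendezvous address/port on node 0. A hedged two-node sketch (the address and script name are placeholders):

# run on node 0 (hosts the rendezvous)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.10 --master_port=29512 \
    --nproc_per_node=2 train.py

# run on node 1
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nnodes=2 --node_rank=1 \
    --master_addr=192.168.1.10 --master_port=29512 \
    --nproc_per_node=2 train.py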

Old model: original inference script - manual key removal

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
import torch
import numpy as np
from torch import nn
from collections import OrderedDict
from src.models_goe_1129_nornn_2d_2_ori import compile_model
# from src.models_goe_1129_nornn_2d_2_ori_flash import compile_model
from tensorboardX import SummaryWriter
# from src.data_tfmap_newcxy_ori import compile_data
from src.data_tfmap_newcxy_nextmask2 import compile_data
from src.tools import SimpleLoss, RegLoss, SegLoss, BCEFocalLoss, get_batch_iou, get_val_info, denormalize_img
import sys
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "12332"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import argparse
import open3d as o3d
import json
from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.utils.visualization import extract_mesh, extract_mesh2, extract_alpha
from src.rendering.neuconw_helper import NeuconWHelper

pi = 3.1415926

def convert_rollyawpitch_to_rot(roll, yaw, pitch):
    roll *= pi/180.
    yaw *= pi/180.
    pitch *= pi/180.
    Rr = np.array([[0.0, -1.0, 0.0],
                   [0.0, 0.0, -1.0],
                   [1.0, 0.0, 0.0]], dtype=np.float32)
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(roll), np.sin(roll)],
                   [0.0, -np.sin(roll), np.cos(roll)]], dtype=np.float32)
    Ry = np.array([[np.cos(pitch), 0.0, -np.sin(pitch)],
                   [0.0, 1.0, 0.0],
                   [np.sin(pitch), 0.0, np.cos(pitch)]], dtype=np.float32)
    Rz = np.array([[np.cos(yaw), np.sin(yaw), 0.0],
                   [-np.sin(yaw), np.cos(yaw), 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
    R = np.matrix(Rr) * np.matrix(Rx) * np.matrix(Ry) * np.matrix(Rz)
    return R

def get_view_control(vis, idx):
    view_control = vis.get_view_control()
    if idx == 0:
        ### cam view
        # view_control.set_front([-1, 0, 0])
        # view_control.set_lookat([8, 0, 2])
        # view_control.set_up([0, 0, 1])
        # view_control.set_zoom(0.025)
        # view_control.rotate(0, 2100 / 40)

        ### bev observe object depth
        view_control.set_front([-1, 0, 1])
        view_control.set_lookat([30, 0, 0])
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.3)
        view_control.rotate(0, 2100 / 20)

    elif idx == 1:
        view_control.set_front([-1, 0, 0])
        view_control.set_lookat([8, 0, 0])
        # view_control.set_lookat([8, 0, 2])  ### look down
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.025)
        view_control.rotate(0, 2100 / 40)
    return view_control

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)

    args.local_rank = 1  # NOTE: with CUDA_VISIBLE_DEVICES="3" only cuda:0 is visible; rank 1 assumes more devices
    print("local_rank:", args.local_rank)
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    

    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_ori_st_v0_1bag_bsz4_rays600_data_tfmap_newcxy_ori_theta_matiszero"
    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231201_nornn_2d_2_ori_st_v0_1bag_bsz4_rays800_data_tfmap_newcxy_ori_theta_iszero_z6" # 单包, retrain 2d
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/models_20231204_nornn_2d_2_ori_st_v0_10bag_bsz4_rays1024_data_tfmap_newcxy_ori_theta_iszero_z6_adjustnearfar" # 10包, retrain 2d
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/models_20231204_nornn_2d_2_ori_flash_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_ori_theta_iszero_z6_adjustnearfar_2"
    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231205_nornn_2d_2_ori_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_theta_iszero_bevgrid_conf_adjustnearfar2"
    model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231207_nornn_2d_2_ori_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_bevgrid_conf_adjustnearfar2"
    
    model_name = "model_32000.pt"
    ckpt_path = model_path + "/checkpts/" + model_name
    to_result_path = "result/" + model_path.split('/')[-1] + '/' + model_name.split('.')[0]
    viz_train = False
    viz_gnd = False
    viz_osr = True

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-3.0, 5.0, 0.5]
    # dbound=[3.0, 103.0, 2.]

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-2.0, 4.0, 1]
    # dbound=[3.0, 103.0, 2.]
    xbound=[0.0, 102., 0.85]
    ybound=[-10.0, 10.0, 0.5]
    zbound=[-2.0, 4.0, 1]
    dbound=[3.0, 103.0, 2.]


    bsz=1
    seq_len=5
    nworkers=1
    sample_num = 3200
    datatype = "single"    #multi   single

    version = "0"
    dataroot = "/data/zjj/data/aishare/share"
    # dataroot = "/run/user/1000/gvfs/sftp:host=192.168.1.40%20-p%2022/mnt/inspurfs/share-directory/defaultShare/aishare/share"


    torch.backends.cudnn.benchmark = True
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }

    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }

    # data_aug_conf = {
    #             'resize_lim': [(0.125, 0.125), (0.25, 0.25)],
    #             'final_dim': (128, 352),
    #             'rot_lim': (0, 0),
    #             'rand_flip': False,
    #             'bot_pct_lim': [(0.0, 0.051), (0.2, 0.2)],
    #             'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
    #             'Ncams': 2,
    #     }
    
    train_sampler, val_sampler,trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
					  grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
					  parser_name='segmentation1data', datatype=datatype)
    loader = trainloader if viz_train else valloader

    writer = SummaryWriter(logdir=None)
    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, writer=writer,phase='validation')
    checkpoint = torch.load(ckpt_path)
    new_state_dict = OrderedDict()
    for k, v in checkpoint.items():

        if "neuconw_helper" in k:
            name = k[22:]  # remove "neuconw_helper.module."
            # name = k[15:]  # remove "neuconw_helper."
            print(k, name)
            continue
        elif "module." in k:
            name = k[7:]  # remove "module."
            print(k)
        else:
            name = k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict, True)
    model.to(device)
    num_gpus = torch.cuda.device_count()
    # if num_gpus > 1:
    #     model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
    #                                                          output_device=args.local_rank,find_unused_parameters=True)
    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, None)

    ww = 160
    hh = 480
    model.eval()
    fps = 30
    flourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    width = int(3715*300./1110)
    n_view = 2
    roi_num = 2
    osr_hh = int((width + ww * 6)/1853/2*1025)
    if viz_gnd:
        if viz_osr:
            out_shape = (width + ww * 6, hh + osr_hh)
        else:
            out_shape = (width + ww * 6, hh)
    else:
        if viz_osr:
            out_shape = (width + ww * 6, 1080)
        else:
            out_shape = (0, 0)

    colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    # vis = o3d.visualization.Visualizer()
    # vis.create_window(window_name='bev')
    cur_sce_name = None
    
    count = 0
    with torch.no_grad():
        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d, img_paths, sce_name) in enumerate(loader):
        # for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d,  sce_id_ind, idx, img_paths, sce_name) in enumerate(loader):
            if count==0:
                count += 1
                continue
            if sce_name[0] != cur_sce_name:
                sname = '_'.join(sce_name[0].split('/')[-6:-3])
                # output_path = model_path + "/result/" + model_name.split('.')[0] + "/" + sname + '_roi3'
                output_path = to_result_path + "/" + sname
                os.makedirs(output_path, exist_ok=True)
                to_video_path = output_path + "/demo_" + sname + "_train.mp4"
                print(to_video_path)
                to_occ_gt_dir = output_path + '/occ_gts/'
                to_mesh_dir = output_path + '/meshes/'
                to_occ_pred_dir = output_path + '/occ_preds/'
                to_img_dir = output_path + '/img_result/'
                # if cur_sce_name is not None:
                #     videoWriter.release()
                # videoWriter = cv2.VideoWriter(to_video_path, flourcc, fps, out_shape)
                os.makedirs(to_occ_gt_dir, exist_ok=True)
                os.makedirs(to_occ_pred_dir, exist_ok=True)
                os.makedirs(to_mesh_dir, exist_ok=True)
                os.makedirs(to_img_dir, exist_ok=True)
                cur_sce_name = sce_name[0]

            voxel_map_data = model(imgs.to(device), rots.to(device), trans.to(device), 
                                    intrins.to(device), dist_coeffss.to(device), post_rots.to(device), 
                                    post_trans.to(device), cam_pos_embeddings.to(device), fork_scales_gt.to(device),fork_offsets_gt.to(device),fork_oris_gt.to(device), 
                                    rays.to(device), theta_mat_2d.to(device), 0, 'validation')
            
            # voxel_map_data  =model(imgs.to(device),
            #                     rots.to(device),
            #                     trans.to(device),
            #                     intrins.to(device),
            #                     dist_coeffss.to(device),
            #                     post_rots.to(device),
            #                     post_trans.to(device),
            #                     cam_pos_embeddings.to(device),
            #                     fork_scales_gt.to(device),
            #                     fork_offsets_gt.to(device),
            #                     fork_oris_gt.to(device),
            #                     rays.to(device),
            #                     theta_mat_2d.to(device),
            #                     0,
            #                     'validation'
            #                     )

            output_img_merge = np.zeros((out_shape[1], out_shape[0], 3), dtype=np.uint8)
            if viz_gnd:
                print('viz_gnd')
                # norm_mask = (lf_norm_gt > -500)
                binimgs = binimgs.cpu().numpy()
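                # NOTE: lf_preds / seg_preds are not returned by the validation forward
                # pass above; this viz_gnd branch (disabled via viz_gnd=False) is kept for
                # reference and would need those outputs restored to run.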
                lf_pred = lf_preds[:, :, :1].contiguous()
                lf_norm = lf_preds[:, :, 1:(1+4)].contiguous()

                seg_out = seg_preds.sigmoid()
                seg_out = seg_out.cpu().numpy()

                lf_out = lf_pred.sigmoid().cpu().numpy()
                lf_norm = lf_norm.cpu().numpy()

                H, W = 944, 1824
                fH, fW = data_aug_conf['final_dim']
                crop0 = []
                crop1 = []
                for cam_idx in range(2):
                    resize = np.mean(data_aug_conf['resize_lim'][cam_idx])
                    resize_dims = (int(fW / resize), int(fH / resize))
                    newfW, newfH = resize_dims
                    # print(newfW, newfH)
                    crop_h = int((1 - np.mean(data_aug_conf['bot_pct_lim'][cam_idx])) * H) - newfH
                    crop_w = int(max(0, W - newfW) / 2)
                    if cam_idx == 0:
                        crop0 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)
                    else:
                        crop1 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)

                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                print('imgname = ', img_paths[-si][0])
                img_org = cv2.imread(img_paths[si][0])

                imgpath = img_paths[si][0][: img_paths[si][0].rfind('org/')-1]
                param_path = imgpath + '/gen/param_infos.json'
                param_infos = {}
                with open(param_path, 'r') as ff :
                    param_infos = json.load(ff)
                yaw = param_infos['yaw']
                pitch = param_infos['pitch']
                if pitch == 0.789806:
                    pitch = -pitch
                roll = param_infos['roll']
                tran = np.array(param_infos['xyz'])

                H, W = param_infos['imgH_ori'], param_infos['imgW_ori']
                ori_K       = np.array(param_infos['ori_K'],dtype=np.float64).reshape(3,3)
                dist_coeffs = np.array(param_infos['dist_coeffs']).astype(np.float64)

                # cam2car_matrix
                rot = convert_rollyawpitch_to_rot(roll, yaw, pitch).I
                cam2car = np.eye(4, dtype= np.float64)
                cam2car[:3, :3] = rot
                cam2car[:3, 3] = tran.T

                norm = lf_norm[0, 4]
                fork = lf_out[0, 4]
                img_res = np.ones((480, 160, 3), dtype=np.uint8)
                colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(0, 255, 255)]
                for class_id in range(6):
                    result = seg_out[0][si][class_id]
                    img_res[result > 0.4] = np.array(colors[class_id])  # both branches of the original if/else were identical

                    ys, xs = np.where(result > 0.4)
                    pt = np.array([ys*0.2125, 0.125*xs-10, np.zeros(ys.shape), np.ones(ys.shape)])
                    if pt.shape[1] == 0:
                        continue
                    car2cam = np.matrix(cam2car).I.dot(pt)[:3, :]

                    rvec, tvec = np.array([0,0,0], dtype=np.float32), np.array([0,0,0], dtype=np.float32)
                    cam2img, _ = cv2.projectPoints(np.array(car2cam.T), rvec, tvec, ori_K, dist_coeffs)

                    for ii in range(cam2img.shape[0]):
                        ptx = round(cam2img[ii,0,0])
                        pty = round(cam2img[ii,0,1])
                        cv2.circle(img_org, (ptx, pty), 3, colors[class_id], -1)


                    # gt = binimgs[0][si][class_id]
                    # img_res[gt< -0.5] = np.array((128,128,128))
                img_res = cv2.flip(cv2.flip(img_res, 0), 1)

                img_gt = np.ones((480, 160, 3), dtype=np.uint8)
                for class_id in range(6):
                    result = binimgs[0][si][class_id]
                    img_gt[result> 0.5] = np.array(colors[class_id])
                    img_gt[result< -0.5] = np.array((128,128,128))


                img_gt = cv2.flip(cv2.flip(img_gt, 0), 1)

                cv2.rectangle(img_org, (int(crop0[0]), int(crop0[1])), (int(crop0[2]), int(crop0[3])), (0,255,255), 2)
                cv2.rectangle(img_org, (int(crop1[0]), int(crop1[1])), (int(crop1[2]), int(crop1[3])), (0,255,0), 2)
                img_org = cv2.resize(img_org, (width, hh))
                img_org_show = np.zeros((hh, width+ww*6, 3), dtype=np.uint8)*255
                img_org_show[:, ww*6:] = img_org

                outs = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)

                ys, xs = np.where(lf_label_gt[0, si, 0] > -0.5)
                ys1, xs1 = np.where(lf_label_gt[0, si, 0] > 0.5)
                ys2, xs2 = np.where(lf_out[0, si, 0] > 0.5)


                gts[si][binimgs[0, si, 0] > 0.5] = np.array(colors[0])
                outs[si][seg_out[0, si, 0] > 0.5] = np.array(colors[0])

                gts[si][binimgs[0, si, 4] > 0.6] = np.array(colors[4])
                outs[si][seg_out[0, si, 4] > 0.6] = np.array(colors[4])

                gts[si][binimgs[0, si, 5] > 0.6] = np.array(colors[5])
                outs[si][seg_out[0, si, 5] > 0.6] = np.array(colors[5])

                valid_mask = np.sum(gts[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                gts1[si][ys1, xs1, :] = 255

                mask = torch.squeeze(lf_norm_gt[:,si,0])
                # gts2[si][mask < -500] = (128, 128, 128)
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        # for mm in range(0, 800, 100):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm_gt[0, si, 0:2, y, x].numpy()
                        if norm[0] == -999.:
                            continue
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm_gt[0, si, 2:4, y, x].numpy()
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)
                        # print (norm)
                        # cv2.circle(gts2[si], (x, y), 3, (0, 255, 255))


                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > 0.5, seg_out[0][si][5] > 0.5))
                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > -0.5, seg_out[0][si][5] > -0.5))
                valid_mask = np.sum(outs[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                outs1[si][ys2, xs2, :] = 255
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm[0, si, 0:2, y, x] / 5.
                        # print (norm)
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm[0, si, 2:4, y, x] / 5.
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)

                # gts2[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)
                # gts1[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)

                img_org_show[:, :ww] = img_res
                img_org_show[:, ww:ww*2] = img_gt
                img_org_show[:, ww*2:ww*3] = cv2.flip(cv2.flip(outs2[si], 0), 1)
                img_org_show[:, ww*3:ww*4] = cv2.flip(cv2.flip(gts2[si], 0), 1)
                img_org_show[:, ww*4:ww*5] = cv2.flip(cv2.flip(outs1[si], 0), 1)
                img_org_show[:, ww*5:ww*6] = cv2.flip(cv2.flip(gts1[si], 0), 1)

                cv2.putText(img_org_show, "NAME:" + imgname + 'seq_id: '+ str(si), (700+320, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                # print(idxs)

                output_img_merge[:img_org_show.shape[0], :] = img_org_show


            if viz_osr:
                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                # print('imgname = ', img_paths[-si][0])
                output_img = np.zeros((1025, 1853*2, 3), dtype=np.uint8)
                to_occ_gt_path = to_occ_gt_dir + imgname.replace('.jpg', '.ply')
                to_occ_pred_path = to_occ_pred_dir + imgname.replace('.jpg', '.ply')
                to_mesh_path = to_mesh_dir + imgname.replace('.jpg', '.ply')
                to_img_path = to_img_dir + imgname
                to_bin_path = to_img_dir + imgname.replace('.jpg', '.bin')
                idx = rays[0, si, :, 15] < 1
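                # ray layout, as inferred from the slicing here and below: [0:3] origin,
                # [3:6] direction, [8] semantic label, [9] depth, [15] validity flag,
                # [17] delta_t, [-3:] rgb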

                pts_gt = rays[0, si, idx, 0:3] + rays[0, si, idx, 3:6]*rays[0, si, idx, 9:10]  # gt_pts
                semantic_gt = rays[0, si, idx, 8].view(-1,1)

                # pts = rays_all[si][0, :, :3] + rays_all[si][0, :, 3:6] * rays_all[si][0, :, 9:10]
                # semantic_gt = rays_all[si][0, :, 9:10]
                # np.save(to_occ_gt_path, np.concatenate([pts, semantic_gt], axis=1))

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt.numpy())
                pcd_gt.paint_uniform_color([0, 1, 0])  # green
                o3d.io.write_point_cloud(to_occ_gt_path, pcd_gt)

                voxel_map = {
                    "origin": (model.bx - model.dx / 2).to(device),
                    "size": (model.dx * (model.nx - 1)).to(device),
                    "dx": model.dx.to(device),
                    # "origin": (model_bx - model_dx / 2).to(device),
                    # "size": (model_dx * (model_nx - 1)).to(device),
                    # "dx": model_dx.to(device),
                    "data": voxel_map_data[0][si:si + 1, ...],
                    "all_rays": rays[0, si:si + 1, :, :].view(-1, rays.shape[-1]).to(device),
                    "rots": rots[0, si * roi_num:si * roi_num + 1, ...],
                    "trans": trans[0, si * roi_num:si * roi_num + 1, ...],
                    "intrins": intrins[0, si * roi_num:si * roi_num + 1, ...],
                    "post_rots": post_rots[0, si * roi_num:si * roi_num + 1, ...],
                    "post_trans": post_trans[0, si * roi_num:si * roi_num + 1, ...],
                    # "valid_mask": valid_mask_coo[si:si + 1, ...]
                }
                if 1:
                    all_rays = rays[0,si,idx,:].view(-1,rays.shape[-1]).to(device)                     # select which frame's rays get rendered
                    sample = {
                        "rays": torch.cat(
                            (all_rays[:, :8], all_rays[:, 9:11],all_rays[:, 15:17]), dim=-1
                        ),
                        "ts": all_rays[:,17],       # delta_t
                        # "ts": torch.ones_like(all_rays[:, -1]).long()*0.,
                        "rgbs": all_rays[:, -3:],     # 索引错的,但是不影响--rgb loss没用上
                        "semantics": all_rays[:, 8],
                    }
                    # pts_generate, depth_loss = neuconw_helper.generate_depth(sample, voxel_map, 0, args.local_rank)  # 由渲染的depth得到预测点
                    # print(">>>>>>>>>>>>>>depth_loss:",depth_loss.mean())
                    # if depth_loss.mean() > 0.2 : print('--imgname--', imgname)
                    # # depth_loss_mean_list.append(depth_loss.mean().detach().cpu().numpy())
                    # # count_list.append(count)

                    # pts_pred = o3d.geometry.PointCloud()
                    # pts_pred.points = o3d.utility.Vector3dVector(np.array(pts_generate.detach().cpu().numpy()))
                    # pts_pred.paint_uniform_color([0, 0, 1])

                    # idx_high_loss = np.where(depth_loss.cpu().numpy()>1.25)  #>0.5
                    # idx_mid_loss = np.where((depth_loss.cpu().numpy()>0.2)*(depth_loss.cpu().numpy()<=1.25))  #0.2~0.5
                    # idx_low_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2
                    # # idx_lower_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2

                    # np.asarray(pts_pred.colors)[idx_high_loss, :] = [1, 0, 0]
                    # np.asarray(pts_pred.colors)[idx_mid_loss, :] = [1, 1, 0]
                    # np.asarray(pts_pred.colors)[idx_low_loss, :] = [0, 1, 0]

                    # # o3d.io.write_point_cloud(
                    # #     f"/home/algo/1/1/debug_pts_gen_car_" + imgname.split('.jpg')[0] + ".ply", pts_pred)
                    # o3d.io.write_point_cloud(os.path.join(to_occ_pred_dir + imgname.replace('.jpg', '_pred.ply')), pts_pred)

                if 1:
                    out_info = extract_alpha(
                        voxel_map, dim=512,  # np.int(np.round(self.scene_config["radius"]/(3**(1/3))/0.1))
                        # chunk=16384,
                        chunk=8192,
                        with_color=False,
                        embedding_a=neuconw_helper.embedding_a((torch.ones(1).cuda() * 1).long()),
                        renderer=neuconw_helper.renderer
                    )

                    # mesh, out_info = extract_mesh2(voxel_map, renderer=neuconw_helper.renderer)
                    np.save(to_occ_pred_path, out_info)

                    # mesh.export(to_mesh_path)
                    # mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(
                    # mesh.vertices.copy()),
                    # triangles=o3d.utility.Vector3iVector(
                    #     mesh.faces.copy()))
                    # mesh.compute_vertex_normals()

                    # for idx_v in range(n_view):
                    #     if idx_v == 0:
                    #         vis.add_geometry(mesh, True)
                    #         vis.add_geometry(pcd_gt, True)
                    #     else:
                    #         vis.add_geometry(mesh, True)

                    #     view_control = get_view_control(vis, idx_v)
                    #     vis.poll_events()
                    #     vis.update_renderer()
                    #     # vis.run()
                    #     mesh_capture_img = vis.capture_screen_float_buffer(True)
                    #     vis.clear_geometries()
                    #     mesh_capture_img = np.array(np.asarray(mesh_capture_img)[..., ::-1] * 255, dtype=np.uint8)
                    #     output_img[:, mesh_capture_img.shape[1] * idx_v:mesh_capture_img.shape[1] * (idx_v + 1),:] = mesh_capture_img
                    #     output_img_resize = cv2.resize(output_img, (out_shape[0], osr_hh))
                    #     output_img_merge[hh:, :] = output_img_resize

            cv2.imwrite(to_img_path, output_img_merge)
            # videoWriter.write(output_img_merge)
            # c = cv2.waitKey(1)%0x100
            # if c == 27:
            #     break
            print(1)
            count += 1


if __name__ == '__main__':
    main()


**Old model: loading weights with mmcv load_checkpoint**

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
from pathlib import Path 
from collections import OrderedDict
import numpy as np
import torch
# from src.models_goe_1129_nornn_2d_2 import compile_model
from src.models_goe_1129_nornn_v8 import compile_model
from src.data_tfmap_newcxy_ori import compile_data
# from src.data_tfmap_newcxy_nextmask2 import compile_data
import cv2

import open3d as o3d
import json
from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.utils.visualization import  extract_alpha
from src.rendering.neuconw_helper import NeuconWHelper

from mmcv.runner import load_checkpoint

"  推理关闭数据层train_sampler --  # train_sampler = val_sampler = None"


os.environ["CUDA_VISIBLE_DEVICES"] = "4"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "12331"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

pi = 3.1415926

def convert_rollyawpitch_to_rot(roll, yaw, pitch):
    roll *= pi/180.
    yaw *= pi/180.
    pitch *= pi/180.
    Rr = np.array([[0.0, -1.0, 0.0],
                   [0.0, 0.0, -1.0],
                   [1.0, 0.0, 0.0]], dtype=np.float32)
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(roll), np.sin(roll)],
                   [0.0, -np.sin(roll), np.cos(roll)]], dtype=np.float32)
    Ry = np.array([[np.cos(pitch), 0.0, -np.sin(pitch)],
                   [0.0, 1.0, 0.0],
                   [np.sin(pitch), 0.0, np.cos(pitch)]], dtype=np.float32)
    Rz = np.array([[np.cos(yaw), np.sin(yaw), 0.0],
                   [-np.sin(yaw), np.cos(yaw), 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
    R = np.matrix(Rr) * np.matrix(Rx) * np.matrix(Ry) * np.matrix(Rz)
    return R

def get_view_control(vis, idx):
    view_control = vis.get_view_control()
    if idx == 0:
        ### cam view
        # view_control.set_front([-1, 0, 0])
        # view_control.set_lookat([8, 0, 2])
        # view_control.set_up([0, 0, 1])
        # view_control.set_zoom(0.025)
        # view_control.rotate(0, 2100 / 40)

        ### bev observe object depth
        view_control.set_front([-1, 0, 1])
        view_control.set_lookat([30, 0, 0])
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.3)
        view_control.rotate(0, 2100 / 20)

    elif idx == 1:
        view_control.set_front([-1, 0, 0])
        view_control.set_lookat([8, 0, 0])
        # view_control.set_lookat([8, 0, 2])  ### look down
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.025)
        view_control.rotate(0, 2100 / 40)
    return view_control

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)

    args.local_rank = 1  # NOTE: with CUDA_VISIBLE_DEVICES="4" only cuda:0 is visible; rank 1 assumes more devices
    print("local_rank:", args.local_rank)
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')

    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_st_v0_1bag_bsz4_rays800_data_tfmap_newcxy_ori"
    model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_st_v0_10bag_bsz4_rays800"
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/nerf_1204_nornn_v8_st_pretrain_data_tfmap_newcxy_nextmask2_1bag_adjustnearfar_newcondition"  # adjust_nearfar1
    
    model_name = "model_20000.pt"
    ckpt_path = model_path + "/checkpts/" + model_name

    to_result_path = "result/" + model_path.split('/')[-1] + '/' + model_name.split('.')[0] + '_p2'

    viz_train = False
    viz_gnd = False
    viz_osr = True


    bsz=1
    seq_len=5
    nworkers=6
    sample_num = 512
    datatype = "single"    #multi   single

    version = "0"
    # dataroot = "/home/algo/dataSpace/NeRF/bev_ground/data/aishare/share"
    #dataroot='/defaultShare/user-data'
    dataroot = "/data/zjj/data/aishare/share"

    xbound=[0.0, 96., 0.5]
    ybound=[-12.0, 12.0, 0.5]
    zbound=[-3.0, 5.0, 0.5]
    dbound=[3.0, 103.0, 2.]
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }

    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                # 'bot_pct_lim': [(0.04, 0.35), (0.4, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }


    train_sampler, val_sampler,trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                      grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
                      parser_name='segmentation1data', datatype=datatype)
    loader = trainloader if viz_train else valloader

    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, phase='validation')
    checkpoint = load_checkpoint(model, ckpt_path, map_location='cpu')

# #------------------------------
#     checkpoint = torch.load(ckpt_path)
#     new_state_dict = OrderedDict()
#     for k, v in checkpoint.items():

#         if "neuconw_helper" in k:
#             # name = k[22:]  # remove "neuconw_helper.module."
#             name = k[15:]  # remove "neuconw_helper."
#             print(k, name)
#             continue
#         elif "module." in k:
#             name = k[7:]  # remove "module."
#             print(k)
#         else:
#             name = k
#         new_state_dict[name] = v

#     model.load_state_dict(new_state_dict, True)
# #------------------------------
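    # The commented-out block above is the manual equivalent of load_checkpoint:
    # strip the "module." prefix that DistributedDataParallel adds when the
    # wrapped model is saved, and drop the "neuconw_helper.*" keys. mmcv's
    # load_checkpoint handles the "module." prefix on its own, so the manual
    # remap is kept only for reference.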

    
    model.to(device)
    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, None)

    ww = 160
    hh = 480
    model.eval()
    fps = 30
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    width = int(3715*300./1110)
    n_view = 2
    roi_num = 2
    osr_hh = int((width + ww * 6)/1853/2*1025)
    if viz_gnd:
        if viz_osr:
            out_shape = (width + ww * 6, hh + osr_hh)
        else:
            out_shape = (width + ww * 6, hh)
    else:
        if viz_osr:
            out_shape = (width + ww * 6, 1080)
        else:
            out_shape = (0, 0)
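    # Canvas layout: six ww-wide BEV panels plus the resized camera image
    # (width x hh) form the top row; osr_hh is the height that keeps the aspect
    # ratio of the 2*1853 x 1025 Open3D capture (output_img below) when it is
    # scaled to the full canvas width for the bottom row (viz_gnd + viz_osr case).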

    colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    # vis = o3d.visualization.Visualizer()
    # vis.create_window(window_name='bev')
    cur_sce_name = None

    count = 0
    with torch.no_grad():

        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimg, lf_label, lf_norm, fork_scale, fork_offset, fork_ori, rays, pose_mats_2d, pose_mats_3d, img_paths, sce_name) in enumerate(loader):

            if sce_name[0] != cur_sce_name:
                sname = '_'.join(sce_name[0].split('/')[-6:-3])
                # output_path = model_path + "/result/" + model_name.split('.')[0] + "/" + sname + '_roi3'
                output_path = to_result_path + "/" + sname
                os.makedirs(output_path, exist_ok=True)
                to_video_path = output_path + "/demo_" + sname + "_train.mp4"
                print(to_video_path)
                to_occ_gt_dir = output_path + '/occ_gts/'
                to_mesh_dir = output_path + '/meshes/'
                to_occ_pred_dir = output_path + '/occ_preds/'
                to_img_dir = output_path + '/img_result/'
                # if cur_sce_name is not None:
                #     videoWriter.release()
                # videoWriter = cv2.VideoWriter(to_video_path, fourcc, fps, out_shape)
                os.makedirs(to_occ_gt_dir, exist_ok=True)
                os.makedirs(to_occ_pred_dir, exist_ok=True)
                os.makedirs(to_mesh_dir, exist_ok=True)
                os.makedirs(to_img_dir, exist_ok=True)
                cur_sce_name = sce_name[0]

            voxel_map_data = model(imgs.to(device),
                                rots.to(device),
                                trans.to(device),
                                intrins.to(device),
                                dist_coeffss.to(device),
                                post_rots.to(device),
                                post_trans.to(device),
                                cam_pos_embeddings.to(device),
                                fork_scale.to(device),
                                fork_offset.to(device),
                                fork_ori.to(device),
                                rays,
                                pose_mats_2d.to(device),
                                0,
                                'validation'
                                )
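            # voxel_map_data is assumed to be the network's per-frame voxel
            # volume; frame si of it is wrapped into `voxel_map` below and
            # queried through the renderer / extract_alpha.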

            output_img_merge = np.zeros((out_shape[1], out_shape[0], 3), dtype=np.uint8)
            if viz_gnd:
                print('viz_gnd')
                # norm_mask = (lf_norm_gt > -500)
                # NOTE: lf_preds / seg_preds / lf_label_gt / lf_norm_gt are assumed to come
                # from a ground-branch forward pass that is not shown here; this whole
                # branch is disabled above (viz_gnd = False).
                binimgs = binimg.cpu().numpy()
                lf_pred = lf_preds[:, :, :1].contiguous()
                lf_norm = lf_preds[:, :, 1:(1 + 4)].contiguous()

                seg_out = seg_preds.sigmoid()
                seg_out = seg_out.cpu().numpy()

                lf_out = lf_pred.sigmoid().cpu().numpy()
                lf_norm = lf_norm.cpu().numpy()

                H, W = 944, 1824
                fH, fW = data_aug_conf['final_dim']
                crop0 = []
                crop1 = []
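                # Re-derive the resize + bottom-crop each camera underwent during
                # data augmentation (using the mean of each range), so the effective
                # input regions can be drawn on the original image further below.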
                for cam_idx in range(2):
                    resize = np.mean(data_aug_conf['resize_lim'][cam_idx])
                    resize_dims = (int(fW / resize), int(fH / resize))
                    newfW, newfH = resize_dims
                    # print(newfW, newfH)
                    crop_h = int((1 - np.mean(data_aug_conf['bot_pct_lim'][cam_idx])) * H) - newfH
                    crop_w = int(max(0, W - newfW) / 2)
                    if cam_idx == 0:
                        crop0 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)
                    else:
                        crop1 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)

                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                print('imgname = ', img_paths[si][0])
                img_org = cv2.imread(img_paths[si][0])

                imgpath = img_paths[si][0][: img_paths[si][0].rfind('org/')-1]
                param_path = imgpath + '/gen/param_infos.json'
                param_infos = {}
                with open(param_path, 'r') as ff :
                    param_infos = json.load(ff)
                yaw = param_infos['yaw']
                pitch = param_infos['pitch']
                if pitch == 0.789806:
                    pitch = -pitch
                roll = param_infos['roll']
                tran = np.array(param_infos['xyz'])

                H, W = param_infos['imgH_ori'], param_infos['imgW_ori']
                ori_K       = np.array(param_infos['ori_K'],dtype=np.float64).reshape(3,3)
                dist_coeffs = np.array(param_infos['dist_coeffs']).astype(np.float64)

                # cam2car_matrix
                rot = convert_rollyawpitch_to_rot(roll, yaw, pitch).I
                cam2car = np.eye(4, dtype= np.float64)
                cam2car[:3, :3] = rot
                cam2car[:3, 3] = tran.T
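                # cam2car: 4x4 camera-to-vehicle extrinsic from roll/yaw/pitch +
                # translation. Below, BEV cells are lifted to vehicle coordinates
                # (x = 0.2125 * row, y = 0.125 * col - 10, z = 0), mapped into the
                # camera frame with its inverse, and projected onto the original
                # image via cv2.projectPoints with the original intrinsics and
                # distortion coefficients.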

                norm = lf_norm[0, 4]
                fork = lf_out[0, 4]
                img_res = np.ones((480, 160, 3), dtype=np.uint8)
                colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(0, 255, 255)]
                for class_id in range(6):
                    result = seg_out[0][si][class_id]
                    img_res[result > 0.4] = np.array(colors[class_id])

                    ys, xs = np.where(result > 0.4)
                    pt = np.array([ys*0.2125, 0.125*xs-10, np.zeros(ys.shape), np.ones(ys.shape)])
                    if pt.shape[1] == 0:
                        continue
                    car2cam = np.matrix(cam2car).I.dot(pt)[:3, :]

                    rvec, tvec = np.array([0,0,0], dtype=np.float32), np.array([0,0,0], dtype=np.float32)
                    cam2img, _ = cv2.projectPoints(np.array(car2cam.T), rvec, tvec, ori_K, dist_coeffs)

                    for ii in range(cam2img.shape[0]):
                        ptx = round(cam2img[ii,0,0])
                        pty = round(cam2img[ii,0,1])
                        cv2.circle(img_org, (ptx, pty), 3, colors[class_id], -1)


                    # gt = binimgs[0][si][class_id]
                    # img_res[gt< -0.5] = np.array((128,128,128))
                img_res = cv2.flip(cv2.flip(img_res, 0), 1)

                img_gt = np.ones((480, 160, 3), dtype=np.uint8)
                for class_id in range(6):
                    result = binimgs[0][si][class_id]
                    img_gt[result> 0.5] = np.array(colors[class_id])
                    img_gt[result< -0.5] = np.array((128,128,128))


                img_gt = cv2.flip(cv2.flip(img_gt, 0), 1)

                cv2.rectangle(img_org, (int(crop0[0]), int(crop0[1])), (int(crop0[2]), int(crop0[3])), (0,255,255), 2)
                cv2.rectangle(img_org, (int(crop1[0]), int(crop1[1])), (int(crop1[2]), int(crop1[3])), (0,255,0), 2)
                img_org = cv2.resize(img_org, (width, hh))
                img_org_show = np.zeros((hh, width + ww * 6, 3), dtype=np.uint8)
                img_org_show[:, ww*6:] = img_org

                outs = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)

                ys, xs = np.where(lf_label_gt[0, si, 0] > -0.5)
                ys1, xs1 = np.where(lf_label_gt[0, si, 0] > 0.5)
                ys2, xs2 = np.where(lf_out[0, si, 0] > 0.5)


                gts[si][binimgs[0, si, 0] > 0.5] = np.array(colors[0])
                outs[si][seg_out[0, si, 0] > 0.5] = np.array(colors[0])

                gts[si][binimgs[0, si, 4] > 0.6] = np.array(colors[4])
                outs[si][seg_out[0, si, 4] > 0.6] = np.array(colors[4])

                gts[si][binimgs[0, si, 5] > 0.6] = np.array(colors[5])
                outs[si][seg_out[0, si, 5] > 0.6] = np.array(colors[5])

                valid_mask = np.sum(gts[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                gts1[si][ys1, xs1, :] = 255

                mask = torch.squeeze(lf_norm_gt[:,si,0])
                # gts2[si][mask < -500] = (128, 128, 128)
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        # for mm in range(0, 800, 100):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm_gt[0, si, 0:2, y, x].numpy()
                        if norm[0] == -999.:
                            continue
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm_gt[0, si, 2:4, y, x].numpy()
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)
                        # print (norm)
                        # cv2.circle(gts2[si], (x, y), 3, (0, 255, 255))


                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > 0.5, seg_out[0][si][5] > 0.5))
                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > -0.5, seg_out[0][si][5] > -0.5))
                valid_mask = np.sum(outs[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                outs1[si][ys2, xs2, :] = 255
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm[0, si, 0:2, y, x] / 5.
                        # print (norm)
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm[0, si, 2:4, y, x] / 5.
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)

                # gts2[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)
                # gts1[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)

                img_org_show[:, :ww] = img_res
                img_org_show[:, ww:ww*2] = img_gt
                img_org_show[:, ww*2:ww*3] = cv2.flip(cv2.flip(outs2[si], 0), 1)
                img_org_show[:, ww*3:ww*4] = cv2.flip(cv2.flip(gts2[si], 0), 1)
                img_org_show[:, ww*4:ww*5] = cv2.flip(cv2.flip(outs1[si], 0), 1)
                img_org_show[:, ww*5:ww*6] = cv2.flip(cv2.flip(gts1[si], 0), 1)

                cv2.putText(img_org_show, "NAME: " + imgname + "  seq_id: " + str(si), (700 + 320, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                # print(idxs)

                output_img_merge[:img_org_show.shape[0], :] = img_org_show


            if viz_osr:
                # si = seq_len - 1
                si = 0
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                # print('imgname = ', img_paths[-si][0])
                output_img = np.zeros((1025, 1853*2, 3), dtype=np.uint8)
                to_occ_gt_path = to_occ_gt_dir + imgname.replace('.jpg', '.ply')
                to_occ_pred_path = to_occ_pred_dir + imgname.replace('.jpg', '.ply')
                to_mesh_path = to_mesh_dir + imgname.replace('.jpg', '.ply')
                to_img_path = to_img_dir + imgname
                to_bin_path = to_img_dir + imgname.replace('.jpg', '.bin')
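                # Per-ray layout, as inferred from the indexing used here and in
                # `sample` below: [0:3] origin, [3:6] direction, [8] semantic label,
                # [9] depth, [15] ray-type flag (< 1 keeps the ray), [17] delta_t,
                # [-3:] rgb. pts_gt back-projects each kept ray to its GT depth.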
                idx = rays[0, si, :, 15] < 1

                pts_gt = rays[0, si, idx, 0:3] + rays[0, si, idx, 3:6]*rays[0, si, idx, 9:10]  # gt_pts
                semantic_gt = rays[0, si, idx, 8].view(-1,1)

                # pts = rays_all[si][0, :, :3] + rays_all[si][0, :, 3:6] * rays_all[si][0, :, 9:10]
                # semantic_gt = rays_all[si][0, :, 9:10]
                # np.save(to_occ_gt_path, np.concatenate([pts, semantic_gt], axis=1))

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt.numpy())
                pcd_gt.paint_uniform_color([0, 1, 0])  # green
                o3d.io.write_point_cloud(to_occ_gt_path, pcd_gt)

                voxel_map = {
                    "origin": (model.bx - model.dx / 2).to(device),
                    "size": (model.dx * (model.nx - 1)).to(device),
                    "dx": model.dx.to(device),
                    # "origin": (model_bx - model_dx / 2).to(device),
                    # "size": (model_dx * (model_nx - 1)).to(device),
                    # "dx": model_dx.to(device),
                    "data": voxel_map_data[0][si:si + 1, ...],
                    "all_rays": rays[0, si:si + 1, :, :].view(-1, rays.shape[-1]).to(device),
                    "rots": rots[0, si * roi_num:si * roi_num + 1, ...],
                    "trans": trans[0, si * roi_num:si * roi_num + 1, ...],
                    "intrins": intrins[0, si * roi_num:si * roi_num + 1, ...],
                    "post_rots": post_rots[0, si * roi_num:si * roi_num + 1, ...],
                    "post_trans": post_trans[0, si * roi_num:si * roi_num + 1, ...],
                    # "valid_mask": valid_mask_coo[si:si + 1, ...]
                }
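                # origin = bx - dx/2 is the corner of the voxel grid and
                # size = dx * (nx - 1) its extent; together with `data` (the
                # frame-si voxel volume) this defines the field the renderer samples.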
                all_rays = rays[0, si, idx, :].view(-1, rays.shape[-1]).to(device)  # select which frame's rays get rendered
                sample = {
                    "rays": torch.cat(
                        (all_rays[:, :8], all_rays[:, 9:11],all_rays[:, 15:17]), dim=-1
                    ),
                    "ts": all_rays[:,17],       # delta_t
                    # "ts": torch.ones_like(all_rays[:, -1]).long()*0.,
                    "rgbs": all_rays[:, -3:],     # 索引错的,但是不影响--rgb loss没用上
                    "semantics": all_rays[:, 8],
                }
                # pts_generate, depth_loss = neuconw_helper.generate_depth(sample, voxel_map, 0, args.local_rank)  # predicted points from the rendered depth
                # print(">>>>>>>>>>>>>>depth_loss:",depth_loss.mean())
                # if depth_loss.mean() > 0.2 : print('--imgname--', imgname)
                # # depth_loss_mean_list.append(depth_loss.mean().detach().cpu().numpy())
                # # count_list.append(count)

                # pts_pred = o3d.geometry.PointCloud()
                # pts_pred.points = o3d.utility.Vector3dVector(np.array(pts_generate.detach().cpu().numpy()))
                # pts_pred.paint_uniform_color([0, 0, 1]) 

                # idx_high_loss = np.where(depth_loss.cpu().numpy()>1.25)  #>0.5
                # idx_mid_loss = np.where((depth_loss.cpu().numpy()>0.2)*(depth_loss.cpu().numpy()<=1.25))  #0.2~0.5
                # idx_low_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2
                # # idx_lower_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2

                # np.asarray(pts_pred.colors)[idx_high_loss, :] = [1, 0, 0]
                # np.asarray(pts_pred.colors)[idx_mid_loss, :] = [1, 1, 0]
                # np.asarray(pts_pred.colors)[idx_low_loss, :] = [0, 1, 0]

                # # o3d.io.write_point_cloud(
                # #     f"/home/algo/1/1/debug_pts_gen_car_" + imgname.split('.jpg')[0] + ".ply", pts_pred)
                # o3d.io.write_point_cloud(os.path.join(to_occ_pred_dir + imgname.replace('.jpg', '_pred.ply')), pts_pred)

                if 1:
                    out_info = extract_alpha(
                        voxel_map, dim=512,  # np.int(np.round(self.scene_config["radius"]/(3**(1/3))/0.1))
                        chunk=16384,
                        with_color=False,
                        embedding_a=neuconw_helper.embedding_a((torch.ones(1).cuda() * 1).long()),
                        renderer=neuconw_helper.renderer,
                        # model=model
                    )
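                    # extract_alpha is assumed to densely query the renderer over the
                    # voxel map (dim-sized grid, `chunk` points per batch), returning
                    # per-point xyz plus static/transient alpha and a validity mask,
                    # which is how out_info is unpacked below.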

                    # mesh, out_info = extract_mesh2(voxel_map, renderer=neuconw_helper.renderer)
                    np.save(to_occ_pred_path, out_info)
                    occ_pred = out_info.numpy()
                    _, alpha_static, alpha_transient, valid_masks = occ_pred[:, :3], occ_pred[:, 3], occ_pred[:, 4], occ_pred[:,5]
                    # output_mask = valid_masks * np.logical_and((alpha_transient > 0.2), alpha_transient < 1)
                    output_mask = valid_masks * (alpha_transient > 0.2)
                    out_for_vis = occ_pred[output_mask > 0, :5]
                    np.savetxt(Path(to_occ_pred_path).with_suffix('.txt'), out_for_vis)
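                    # Keep only valid voxels whose transient alpha exceeds 0.2; the
                    # first five columns (xyz + both alphas) are also dumped as a
                    # .txt file for quick inspection.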

                    # mesh.export(to_mesh_path)
                    # mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(
                    # mesh.vertices.copy()),
                    # triangles=o3d.utility.Vector3iVector(
                    #     mesh.faces.copy()))
                    # mesh.compute_vertex_normals()

                    # for idx_v in range(n_view):
                    #     if idx_v == 0:
                    #         vis.add_geometry(mesh, True)
                    #         vis.add_geometry(pcd_gt, True)
                    #     else:
                    #         vis.add_geometry(mesh, True)

                    #     view_control = get_view_control(vis, idx_v)
                    #     vis.poll_events()
                    #     vis.update_renderer()
                    #     # vis.run()
                    #     mesh_capture_img = vis.capture_screen_float_buffer(True)
                    #     vis.clear_geometries()
                    #     mesh_capture_img = np.array(np.asarray(mesh_capture_img)[..., ::-1] * 255, dtype=np.uint8)
                    #     output_img[:, mesh_capture_img.shape[1] * idx_v:mesh_capture_img.shape[1] * (idx_v + 1),:] = mesh_capture_img
                    #     output_img_resize = cv2.resize(output_img, (out_shape[0], osr_hh))
                    #     output_img_merge[hh:, :] = output_img_resize

            cv2.imwrite(to_img_path, output_img_merge)
            # videoWriter.write(output_img_merge)
            # c = cv2.waitKey(1)%0x100
            # if c == 27:
            #     break
            # print(1)
            count += 1


if __name__ == '__main__':
    main()
