Commit 982474b1 by yuxguo

init

# SCL-my
Rewrite of SCL without third-party libraries (drops the jacinle/jactorch dependency).
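A minimal invocation sketch (flags and values are illustrative; `--dataset`, `--dataset_path` and `--task` are the required arguments defined in `main.py`):

    python main.py --dataset RAVEN --dataset_path /path/to/RAVEN \
        -t center_single --num-features 80 --use-visual-inputs --use-gpu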
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : main.py
# Author : Honghua Dong
# Email : dhh19951@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
'''
# Usage
python main.py --dataset $DATASET --dataset_path $DATASET_DIR -t $TASK
'''
import argparse
import collections
import numpy as np
import json
import os
import os.path as osp
import pickle
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import DataParallel
import logging
logging.basicConfig(format='[%(asctime)s, %(levelname)s]: %(message)s', level=logging.DEBUG)
# from jactorch.cli import dump_metainfo
# from jactorch.data.dataloader import JacDataLoader
# from jactorch.parallel import JacDataParallel
# from jactorch.train import TrainerEnv
from model.const import MAX_VALUE
# from model.dataset import get_dataset_name_and_num_features, load_data
from model.dataset import RAVENdataset, PGMdataset, ToTensor
from model.model import Model
# from model.train import Trainer
# from model.utils import plot_curve, get_exp_name, get_image_title
# import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE
parser = argparse.ArgumentParser()
seeds = parser.add_argument_group('Random Seeds')
seeds.add_argument('--random-seed', '-seed', type=int, default=None,
help='The random seed for random')
seeds.add_argument('--numpy-random-seed', '-nseed', type=int, default=None,
help='The random seed for np.random')
seeds.add_argument('--torch-random-seed', '-tseed', type=int, default=None,
help='The random seed for torch.random')
dataset_args = parser.add_argument_group('Dataset Related Args')
# Dataset Spec
dataset_args.add_argument('--dataset', type=str, required=True,
help='the dataset: [PGM, RAVEN]')
dataset_args.add_argument('--dataset_path', type=str, required=True,
help='the dataset path')
dataset_args.add_argument('--task', '-t', nargs='+', type=str, required=True,
choices=['center_single', 'up_down', 'left_right', 'in_out', 'in_distri',
'distribute_four', 'distribute_nine'], help='the task')
dataset_args.add_argument('--index-file-dir', '-ifd', type=str, default=None,
help='the dir containing index files')
dataset_args.add_argument('--train-index-file', '-tif', type=str, default=None,
help='the training index file defining the training dataset')
dataset_args.add_argument('--val-index-file', '-vif', type=str, default=None,
help='the validation index file defining the validation dataset')
dataset_args.add_argument('--dataset-size', '-ds', type=int, nargs='+',
default=None, help='the size of training/val dataset')
dataset_args.add_argument('--num-context', '-nco', type=int, default=None,
help='the number of context images')
dataset_args.add_argument('--num-candidates', '-nca', type=int, default=None,
help='the number of candidate images')
# dataset_args.add_argument('--split', '-sp', type=str, default='',
# help='the string specify the dataset split, empty for original split')
dataset_args.add_argument('--use-test-as-val', '-tv', action='store_true',
help='Use the test set during each validation time')
# Dataset Manipulation
dataset_args.add_argument('--trunc-train-data', '-ttd', type=int, default=None,
help='the truncated size of training dataset (default: None)')
dataset_args.add_argument('--trunc-val-data', '-tvd', type=int, default=None,
help='the truncated size of validation dataset (default: None)')
dataset_args.add_argument('--adjust-size', '-adj', action='store_true',
help='adjust size attr (+1) of the inputs')
# Misc
dataset_args.add_argument('--num-workers', '-w', type=int, default=None,
help='The number of workers for data loader')
trainer_args = parser.add_argument_group('Trainer Related Args')
# Basic
trainer_args.add_argument('--use-gpu', action='store_true',
help='use GPU or not')
trainer_args.add_argument('--epochs', '-e', type=int, default=200,
help='the number of epochs')
trainer_args.add_argument('--obs-epochs', '-oe', type=int, default=5,
help='the number of sub epochs for observation stage')
trainer_args.add_argument('--batch-size', '-bs', type=int, default=128,
help='input batch size for training (default: 128)')
trainer_args.add_argument('--eval-batch-size', '-ebs', type=int, default=32,
help='input batch size for evaluation (default: 32)')
trainer_args.add_argument('--lr', '-lr', type=float, default=0.005,
help='learning rate (default: 0.005)')
trainer_args.add_argument('--obs-lr', '-olr', type=float, default=None,
help='learning rate for observation (default: None, recommend 0.2)')
trainer_args.add_argument('--observe-interval', '-oi', type=int, default=None,
help='The interval of doing observation for feature space (default: None)')
trainer_args.add_argument('--v2s-lr', '-vlr', type=float, default=None,
help='learning rate (default: None, the same as lr)')
trainer_args.add_argument('--lr-anneal-start', '-lrs', type=int, default=None,
help='The epoch when learning rate start annealing (default: None)')
trainer_args.add_argument('--lr-anneal-interval', '-lri', type=int, default=1,
help='The interval of learning rate annealing (default: 1)')
trainer_args.add_argument('--lr-anneal-ratio', '-lrr', type=float, default=1.0,
help='learning rate annealing ratio (default: 1.0)')
trainer_args.add_argument('--weight-decay', '-wd', type=float, default=0,
help='The weight decay factor')
trainer_args.add_argument('--obs-val-only', '-ov', action='store_true',
help='observe the val dataset only if True')
trainer_args.add_argument('--obs-thresh', '-ot', type=float, default=0.25,
help='The thresh used by observation module')
# Save & Load
trainer_args.add_argument('--load', '-l', type=str, default=None,
help='load the weights from a pretrained model (default: none)')
trainer_args.add_argument('--save-interval', '-si', type=int, default=50,
help='model save interval (epochs) (default: 50)')
# Utils
trainer_args.add_argument('--monitor-grads', '-mg', action='store_true',
help='monitor the grad of weights if True')
trainer_args.add_argument('--test-only', '-test', action='store_true',
help='Test only')
trainer_args.add_argument('--test-obs', '-to', action='store_true',
help='Test obs')
trainer_args.add_argument('--dump-dir', '-du', type=str, default='dumps',
help='The dir to dump meters/fail-cases/fig/checkpoints')
trainer_args.add_argument('--resume-dir', '-rd', type=str, default=None,
help='The dir to dump last epoch info that used for resume')
trainer_args.add_argument('--disable-resume', '-dr', action='store_true',
help='disable resume from saved checkpoint')
trainer_args.add_argument('--dump-fail-cases', '-dfc', action='store_true',
help='dump fail cases if True')
trainer_args.add_argument('--image-title', '-it', type=str, default=None,
help='the title of the plot image')
trainer_args.add_argument('--extra', '-ex', type=str, default='',
help='the extra string for the name of the experiment')
# Misc
trainer_args.add_argument('--exclude-angle-attr', '-ea', action='store_true',
help='exclude the angle attr when predicting symbolic representation')
trainer_args.add_argument('--key-attr-only', '-ko', action='store_true',
help='only the key attrs(color/size/type) are considered during observation')
trainer_args.add_argument('--tsne-key', '-tk', type=str, default=None,
choices=['latent_logits', 'sgm_inter_1', 'sgm_inter_2'],
help='key name to be visualized by tsne')
trainer_args.add_argument('--tsne-thresh', '-tst', type=float, default=None,
help='the thresh of ind features to be considered')
trainer_args.add_argument('--tsne-fit-output-file', '-tso', type=str,
default='tsne_fit_results.pkl',
help='the name of saved tsne fit results file')
trainer_args.add_argument('--tsne-positive-k', '-tsp', action='store_true',
help='record positive k only for tsne if True')
model_args = parser.add_argument_group('Model Related Args')
# Basic
model_args.add_argument('--model', '-m', type=str, default='analogy',
choices=['simple', 'shared', 'analogy'], help='which model to be used')
model_args.add_argument('--normal-group-mlp', '-ngm', action='store_true',
help='use normal group mlp instead of shared group mlp if True')
model_args.add_argument('--num-features', '-nf', type=int, default=None,
help='The number of feature dimensions in the given task')
model_args.add_argument('--one-hot', '-oh', action='store_true',
help='use one hot representation for attributes inputs if True')
model_args.add_argument('--embedding-hidden-dims', '-ehd', type=int, nargs='+',
default=[],
help='the hidden dimensions of the mlp model for embedding (default: [])')
model_args.add_argument('--embedding-dim', '-ed', type=int, default=None,
help='the dimension of the embedding of features')
model_args.add_argument('--enable-residual-block', '-rb', action='store_true',
help='use the residual block after embedding if True')
model_args.add_argument('--use-ordinary-mlp', '-om', action='store_true',
help='use the ordinary mlp instead of the shared group mlp if True')
model_args.add_argument('--enable-rb-after-experts', '-erb', action='store_true',
help='use the residual block after experts if True')
# make sure it divs num features or embedding dims (when ed is not None)
model_args.add_argument('--feature-embedding-dim', '-fed', type=int, default=1,
help='The dimension of embedding space of features')
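# e.g. (illustrative): with --num-features 80 and no --embedding-dim, -fed must divide 80,
# so 1, 2, 4 or 5 are valid values while 3 is not.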
model_args.add_argument('--num-experts', '-ne', type=int, default=5,
help='The number of experts used')
# In simple/shared model, this is used as hidden dims of the mlp layer.
model_args.add_argument('--hidden-dims', '-hd', type=int, nargs='+',
default=[32, 16],
help='the hidden dimensions of the expert model (default: [32, 16])')
model_args.add_argument('--v2s-softmax', '-vss', action='store_true',
help='use the softmax after the output of visual mlp if True')
model_args.add_argument('--not-use-softmax', '-ns', action='store_true',
help='do not apply the softmax after the output of experts if True')
model_args.add_argument('--reduction-groups', '-rg', type=int, nargs='+',
default=[2], help='the reduction groups')
model_args.add_argument('--sum-as-reduction', '-sum', type=int, default=0,
help='how many reductions use sum (only a suffix of reductions)')
model_args.add_argument('--lastmlp-hidden-dims', '-lhd', type=int, nargs='+',
default=[],
help='the hidden dimensions of the last mlp model (default: [])')
# Visual
model_args.add_argument('--use-visual-inputs', '-v', action='store_true',
help='Use visual inputs if True')
model_args.add_argument('--image-size', '-is', type=int, nargs='+',
default=[160], help='the size of the image')
model_args.add_argument('--use-resnet', '-ur', action='store_true',
help='use resnet instead of convnets')
model_args.add_argument('--conv-hidden-dims', '-chd', type=int, nargs='+',
default=[8, 16, 32, 32],
help='the hidden dimensions of the conv model (default: [8, 16, 32, 32])')
model_args.add_argument('--conv-repeats', '-cr', type=int, nargs='+',
default=None,
help='the repeat times of conv layers in each block (default: None)')
model_args.add_argument('--conv-kernels', '-ck', type=int, nargs='+', default=3,
help='the kernel size of conv layers in each block (default: 3)')
model_args.add_argument('--conv-residual-link', '-crl', action='store_true',
help='enable the residual links in the conv block')
model_args.add_argument('--use-layer-norm', '-ln', action='store_true',
help='Use layer norm instead of batch norm for convs')
model_args.add_argument('--visual-mlp-hidden-dims', '-vhd', type=int, nargs='+',
default=[],
help='the hidden dimensions of the visual mlp model (default: [])')
model_args.add_argument('--num-visual-experts', '-nve', type=int, default=1,
help='the number of experts for visual feature extraction')
model_args.add_argument('--factor-groups', '-fg', type=int, default=1,
help='the groups used in visual model object-feature decomposition')
model_args.add_argument('--split-channel', '-sc', action='store_true',
help='split the channel dim instead of spatial dim for the output of conv')
model_args.add_argument('--transformed-spatial-dim', '-tsd', type=int,
default=None, help='the output dim of mlp_transform in visual model')
model_args.add_argument('--mlp-transform-hidden-dims', '-thd',
type=int, nargs='+', default=[],
help='the hidden dimensions of the mlp transform model (default: [])')
# Aux Loss
model_args.add_argument('--entropy-beta', '-en', type=float, default=0.0,
help='To control the weight of entropy loss')
model_args.add_argument('--symbolic-beta', '-sb', type=float, default=0.0,
help='To control the weight of symbolic loss')
model_args.add_argument('--prediction-beta', '-pb', type=float, default=1.0,
help='To control the weight of prediction loss')
args = parser.parse_args()
if args.random_seed is not None:
random.seed(args.random_seed)
if args.numpy_random_seed is not None:
np.random.seed(args.numpy_random_seed)
if args.torch_random_seed is not None:
torch.random.manual_seed(args.torch_random_seed)
torch.cuda.random.manual_seed_all(args.torch_random_seed)
# Control Randomness
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# To get reproducible results, set the random seed for torch and numpy
# If there are still inconsistencies, set worker_init_fn for the data loader
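# A minimal worker_init_fn sketch (illustrative, not wired in by default):
#   def _seed_worker(worker_id):
#       seed = (torch.initial_seed() + worker_id) % 2 ** 32
#       np.random.seed(seed)
#       random.seed(seed)
# Pass worker_init_fn=_seed_worker to the DataLoader calls in get_dataloader().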
args.use_gpu = args.use_gpu and torch.cuda.is_available()
# The name 'split' is not that proper here,
# used to specify different variation of the dataset.
if len(args.image_size) == 1:
args.image_size.append(args.image_size[0])
args.image_size = args.image_size[:2]
args.tsne = args.tsne_key is not None
args.is_in_out_structure = len(args.task) == 1 and args.task[0] in [
'in_out', 'in_distri']
args.is_two_part_structure = len(args.task) == 1 and args.task[0] in [
'left_right', 'up_down', 'in_out', 'in_distri']
if args.num_context is None:
args.num_context = 8
if args.num_candidates is None:
args.num_candidates = 8
args.shared_group_mlp = not args.normal_group_mlp
# Better to use w=2 and cpu=4 for visual inputs
# And w=0 and cpu=2 for symbolic inputs
if args.num_workers is None:
if args.use_visual_inputs:
args.num_workers = 2
else:
args.num_workers = 0
def get_model(args):
return Model(
model_name=args.model,
nr_features=args.num_features,
nr_experts=args.num_experts,
shared_group_mlp=args.shared_group_mlp,
one_hot=args.one_hot,
v2s_softmax=args.v2s_softmax,
not_use_softmax=args.not_use_softmax,
# visual related
visual_inputs=args.use_visual_inputs,
factor_groups=args.factor_groups,
split_channel=args.split_channel,
image_size=args.image_size,
use_layer_norm=args.use_layer_norm,
use_resnet=args.use_resnet,
conv_hidden_dims=args.conv_hidden_dims,
conv_repeats=args.conv_repeats,
conv_kernels=args.conv_kernels,
conv_residual_link=args.conv_residual_link,
nr_visual_experts=args.num_visual_experts,
visual_mlp_hidden_dims=args.visual_mlp_hidden_dims,
transformed_spatial_dim=args.transformed_spatial_dim,
mlp_transform_hidden_dims=args.mlp_transform_hidden_dims,
exclude_angle_attr=args.exclude_angle_attr,
symbolic_beta=args.symbolic_beta,
prediction_beta=args.prediction_beta,
# embedding
embedding_dim=args.embedding_dim,
embedding_hidden_dims=args.embedding_hidden_dims,
enable_residual_block=args.enable_residual_block,
use_ordinary_mlp=args.use_ordinary_mlp,
enable_rb_after_experts=args.enable_rb_after_experts,
feature_embedding_dim=args.feature_embedding_dim,
# experts/simple/shared
hidden_dims=args.hidden_dims,
# reduction
reduction_groups=args.reduction_groups,
sum_as_reduction=args.sum_as_reduction,
lastmlp_hidden_dims=args.lastmlp_hidden_dims,
# input format
nr_context=args.num_context,
nr_candidates=args.num_candidates,
# TSNE
collect_inter_key=args.tsne_key)
def get_dataloader():
if args.dataset == 'PGM':
train = PGMdataset(args.dataset_path, "train", args.image_size[0], transform=transforms.Compose([ToTensor()]), shuffle=True)
valid = PGMdataset(args.dataset_path, "val", args.image_size[0], transform=transforms.Compose([ToTensor()]))
test = PGMdataset(args.dataset_path, "test", args.image_size[0], transform=transforms.Compose([ToTensor()]))
elif args.dataset == 'RAVEN':
# (I-)RAVEN
args.train_figure_configurations = [0,1,2,3,4,5,6]
args.val_figure_configurations = args.train_figure_configurations
args.test_figure_configurations = [0,1,2,3,4,5,6]
train = RAVENdataset(args.dataset_path, "train", args.train_figure_configurations, args.image_size[0], transform=transforms.Compose([ToTensor()]), shuffle=True)
valid = RAVENdataset(args.dataset_path, "val", args.val_figure_configurations, args.image_size[0], transform=transforms.Compose([ToTensor()]))
test = RAVENdataset(args.dataset_path, "test", args.test_figure_configurations, args.image_size[0], transform=transforms.Compose([ToTensor()]))
trainloader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
validloader = DataLoader(valid, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
testloader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
return trainloader, validloader, testloader
def train(trainloader, optimizer, model, epoch):
model.train()
train_loss = 0
accuracy = 0
loss_all = 0.0
acc_all = 0.0
counter = 0
for batch_idx, (images, label, meta_target) in enumerate(trainloader):
counter += 1
if args.use_gpu:
images = images.cuda()
label = label.cuda()
meta_target = meta_target.cuda()
optimizer.zero_grad()
inputs = {
"images": images,
"label": label
}
loss, logits = model(inputs)
loss.backward()
optimizer.step()
pred = logits.argmax(dim=-1)
correct = pred.eq(label.data).cpu().sum().numpy()
accuracy = correct * 100. / label.size()[0]
logging.info('Train: Epoch:{}, Batch:{}, Loss:{:.6f}, Acc:{:.4f}.'.format(epoch, batch_idx, loss.item(), accuracy))
loss_all += loss.item()
acc_all += accuracy
if counter > 0:
logging.info("Avg Training Loss: {:.6f}".format(loss_all/float(counter)))
return loss_all/float(counter), acc_all/float(counter)
def validate(validloader, model, epoch):
model.eval()
accuracy = 0
acc_all = 0.0
counter = 0
with torch.no_grad():
for batch_idx, (images, label, meta_target) in enumerate(validloader):
counter += 1
if args.use_gpu:
images = images.cuda()
label = label.cuda()
meta_target = meta_target.cuda()
inputs = {
"images": images,
"label": label
}
loss, logits = model(inputs)
pred = logits.argmax(dim=-1)
correct = pred.eq(label.data).cpu().sum().numpy()
accuracy = correct * 100. / label.size()[0]
acc_all += accuracy
if counter > 0:
logging.info("Total Validation Acc: {:.4f}".format(acc_all/float(counter)))
return acc_all/float(counter)
def test(testloader, model, epoch):
model.eval()
accuracy = 0
acc_all = 0.0
counter = 0
with torch.no_grad():
for batch_idx, (images, label, meta_target) in enumerate(testloader):
counter += 1
if args.use_gpu:
images = images.cuda()
label = label.cuda()
meta_target = meta_target.cuda()
inputs = {
"images": images,
"label": label
}
loss, logits = model(inputs)
pred = logits.argmax(dim=-1)
correct = pred.eq(label.data).cpu().sum().numpy()
accuracy = correct * 100. / label.size()[0]
acc_all += accuracy
if counter > 0:
logging.info("Total Testing Acc: {:.4f}".format(acc_all / float(counter)))
return acc_all/float(counter)
def main():
trainloader, validloader, testloader = get_dataloader()
model = get_model(args)
if args.use_gpu:
model = DataParallel(model).cuda()
if args.weight_decay == 0:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else:
optimizer = optim.AdamW(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
for epoch in range(args.epochs):
train(trainloader, optimizer, model, epoch)
validate(validloader, model, epoch)
test(testloader, model, epoch)
if __name__ == '__main__':
main()
VALUES = [8, 10, 6, 6]
TOTAL_VALUES = sum(VALUES)
MAX_VALUE = max(VALUES)
NUM_ATTRS = len(VALUES)
ORIGIN_IMAGE_SIZE = (160, 160)
import os
import glob
import numpy as np
from scipy import misc
import torch
from torch.utils.data import Dataset
from torchvision import transforms, utils
class ToTensor(object):
def __call__(self, sample):
return torch.tensor(sample, dtype=torch.float32)
class PGMdataset(Dataset):
def __init__(self, root_dir, dataset_type, img_size, transform=None, shuffle=False):
self.root_dir = root_dir
self.transform = transform
self.file_names = [f for f in glob.glob(os.path.join(root_dir, "*.npz")) if dataset_type in os.path.basename(f)]
self.img_size = img_size
self.shuffle = shuffle
def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
data_path = self.file_names[idx]
data = np.load(data_path)
image = data["image"].reshape(16, 160, 160)
target = data["target"]
meta_target = data["meta_target"]
if self.shuffle:
context = image[:8, :, :]
choices = image[8:, :, :]
indices = np.arange(8)
np.random.shuffle(indices)
new_target = np.where(indices == target)[0][0]
new_choices = choices[indices, :, :]
image = np.concatenate((context, new_choices))
target = new_target
resize_image = []
for idx in range(0, 16):
resize_image.append(misc.imresize(image[idx,:,:], (self.img_size, self.img_size)))
resize_image = np.stack(resize_image)
if meta_target.dtype == np.int8:
meta_target = meta_target.astype(np.uint8)
del data
if self.transform:
resize_image = self.transform(resize_image)
target = torch.tensor(target, dtype=torch.long)
meta_target = self.transform(meta_target)
return resize_image, target, meta_target
figure_configuration_names = ['center_single', 'distribute_four', 'distribute_nine', 'in_center_single_out_center_single', 'in_distribute_four_out_center_single', 'left_center_single_right_center_single', 'up_center_single_down_center_single']
class RAVENdataset(Dataset):
def __init__(self, root_dir, dataset_type, figure_configurations, img_size, transform=None, shuffle=False):
self.root_dir = root_dir
self.transform = transform
self.file_names = []
for idx in figure_configurations:
tmp = [f for f in glob.glob(os.path.join(root_dir, figure_configuration_names[idx], "*.npz")) if dataset_type in os.path.basename(f)]
self.file_names += tmp
self.img_size = img_size
self.shuffle = shuffle
self.switch = [3,4,5,0,1,2,6,7]
def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
data_path = self.file_names[idx]
data = np.load(data_path)
image = data["image"].reshape(16, 160, 160)
target = data["target"]
meta_target = data["meta_target"]
if self.shuffle:
context = image[:8, :, :]
choices = image[8:, :, :]
indices = np.arange(8)
np.random.shuffle(indices)
new_target = np.where(indices == target)[0][0]
new_choices = choices[indices, :, :]
switch_2_rows = np.random.rand()
if switch_2_rows < 0.5:
context = context[self.switch, :, :]
image = np.concatenate((context, new_choices))
target = new_target
resize_image = []
for idx in range(0, 16):
resize_image.append(misc.imresize(image[idx,:,:], (self.img_size, self.img_size)))
resize_image = np.stack(resize_image)
del data
if self.transform:
resize_image = self.transform(resize_image)
target = torch.tensor(target, dtype=torch.long)
meta_target = self.transform(meta_target)
return resize_image, target, meta_target
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File : model.py
# Author : Honghua Dong
# Email : dhh19951@gmail.com
# Date : 11/21/2019
#
# Distributed under terms of the MIT license.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from modules import MLPModel
# from .baselines import SimpleModel, SharedModel
from relation_model import RelationModel
from utils import compute_entropy, compute_mi
from visual_model import VisualModel
from const import ORIGIN_IMAGE_SIZE, MAX_VALUE
__all__ = ['Model']
"""
The model is composed of two parts, one from visual to symbolic, another
from symbolic to prediction. The losses are computed for both symbolic
format and the final prediction, so it can fit all three kinds of tasks:
from visual to symbolic, from visual to prediction, and from symbolic to
prediction.
Use prediction_beta (default 1.0) and symbolic_beta (default 0.0)
to control the loss.
"""
class Model(nn.Module):
def __init__(self,
nr_features,
model_name='analogy',
nr_experts=5,
shared_group_mlp=True,
one_hot=False,
v2s_softmax=False,
max_value=MAX_VALUE,
not_use_softmax=False,
visual_inputs=False,
factor_groups=1,
split_channel=False,
image_size=ORIGIN_IMAGE_SIZE,
use_layer_norm=False,
use_resnet=False,
conv_hidden_dims=[],
conv_repeats=None,
conv_kernels=3,
conv_residual_link=True,
nr_visual_experts=1,
visual_mlp_hidden_dims=[],
transformed_spatial_dim=None,
mlp_transform_hidden_dims=[],
exclude_angle_attr=False,
prediction_beta=1.0,
symbolic_beta=0.0,
embedding_dim=None,
embedding_hidden_dims=[],
enable_residual_block=False,
use_ordinary_mlp=False,
enable_rb_after_experts=False,
feature_embedding_dim=1,
hidden_dims=[16],
reduction_groups=[3],
sum_as_reduction=0,
lastmlp_hidden_dims=[],
nr_context=8,
nr_candidates=8,
collect_inter_key=None):
super().__init__()
# nr_features is the number of features of the symbolic representation.
# In the visual-inputs case, nr_features can also be the output dimension
# of the visual model. If this dimension matches the symbolic
# representation, the symbolic_loss is enabled.
# In the one-hot case, nr_features is multiplied by max_value.
# The symbolic loss can be weighted by symbolic_beta, as an aux loss.
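# e.g. (illustrative): nr_features=4 with the default max_value=10 (from const.VALUES)
# yields 4 * 10 = 40 input dims in the one-hot case.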
self.original_nr_feature = nr_features
if one_hot:
logging.warning(('The nr_feature has been multiplied by {} to fit '
'the one hot representation. before: {}, after: {}').format(
max_value, nr_features, nr_features * max_value))
nr_features *= max_value
self.visual_inputs = visual_inputs
if visual_inputs:
self.v2s = VisualModel(
input_dim=1,
shared_group_mlp=shared_group_mlp,
conv_hidden_dims=conv_hidden_dims,
output_dim=nr_features,
use_resnet=use_resnet,
conv_repeats=conv_repeats,
conv_kernels=conv_kernels,
conv_residual_link=conv_residual_link,
nr_visual_experts=nr_visual_experts,
mlp_hidden_dims=visual_mlp_hidden_dims,
groups=factor_groups,
split_channel=split_channel,
transformed_spatial_dim=transformed_spatial_dim,
mlp_transform_hidden_dims=mlp_transform_hidden_dims,
image_size=image_size,
use_layer_norm=use_layer_norm)
self.symbolic_loss = nn.MSELoss()
if one_hot:
self.symbolic_loss = nn.CrossEntropyLoss()
def get_relation_model(name):
assert name in ['analogy', 'simple', 'shared'], \
'Unknown model name: {}'.format(name)
if name == 'analogy':
return RelationModel(
nr_features=nr_features,
nr_experts=nr_experts,
shared_group_mlp=shared_group_mlp,
not_use_softmax=not_use_softmax,
embedding_dim=embedding_dim,
embedding_hidden_dims=embedding_hidden_dims,
enable_residual_block=enable_residual_block,
enable_rb_after_experts=enable_rb_after_experts,
feature_embedding_dim=feature_embedding_dim,
hidden_dims=hidden_dims,
reduction_groups=reduction_groups,
sum_as_reduction=sum_as_reduction,
lastmlp_hidden_dims=lastmlp_hidden_dims,
use_ordinary_mlp=use_ordinary_mlp,
nr_context=nr_context,
nr_candidates=nr_candidates,
collect_inter_key=collect_inter_key)
# SymbolicModel = SimpleModel if name == 'simple' else SharedModel
# return SymbolicModel(
# nr_features=nr_features,
# hidden_dims=hidden_dims,
# nr_context=nr_context,
# nr_candidates=nr_candidates)
self.nr_features = nr_features
self.one_hot = one_hot
self.v2s_softmax = v2s_softmax
self.max_value = max_value
self.symbolic_model = get_relation_model(model_name)
self.nr_ind_features = nr_features
if model_name == 'analogy':
self.nr_ind_features = self.symbolic_model.nr_ind_features
self.collect_inter_key = collect_inter_key
# self.nr_experts = nr_experts
# self.not_use_softmax = not_use_softmax
# self.embedding_dim = embedding_dim
# self.embedding_hidden_dims = embedding_hidden_dims
# self.enable_residual_block = enable_residual_block
# self.hidden_dims = hidden_dims
# self.reduction_groups = reduction_groups
# self.sum_as_reduction = sum_as_reduction
# self.lastmlp_hidden_dims = lastmlp_hidden_dims
self.exclude_angle_attr = exclude_angle_attr
self.prediction_beta = prediction_beta
self.symbolic_beta = symbolic_beta
self.pred_loss = nn.CrossEntropyLoss()
self.have_shown_mismatch_warning = False
# TODO: get name by args
# def get_name(short=False):
# names = ['model']
# if self.visual_inputs:
# names.append(self.v2s.get_model_name(short))
# names.append('nf{}'.format(nr_features))
# if one_hot:
# names.append('oh')
# if prediction_beta != 1.0:
# names.append('pb{}'.format(prediction_beta))
# if symbolic_beta != 0.0:
# names.append('sb{}'.format(symbolic_beta))
# if symbolic_beta != 0.0 and exclude_angle_attr:
# names.append('ea')
# names.append(self.symbolic_model.get_model_name(short))
# if short:
# return ','.join(names)
# return '_'.join(names)
# self.name = get_name()
# self.short_name = get_name(short=True)
def compute_symbolic_loss(self, output, target):
if self.symbolic_beta == 0.0:
return None
batch, nr_images, nr_parts, nr_attrs = target.size()
if self.one_hot:
target = target.flatten().long()
else:
output = output.flatten()
target = target.flatten().float()
if self.exclude_angle_attr:
flag = np.zeros([batch, nr_images, nr_parts, nr_attrs])
flag[:, :, :, 0] = 1 # the angle attr is index 0 in nr_attrs dim
flag = flag.reshape(-1)
ind = np.where(flag == 0)[0]
output = output[ind]
target = target[ind]
# print(output.shape, target.shape)
return self.symbolic_loss(output, target)
def forward(self, inputs):
# images: [B, 16, h, w]
pred_symbol = self.v2s(inputs["images"])
# pred_symbol = F.softmax(pred_symbol, dim=-1)
output = self.symbolic_model(pred_symbol)
logits = output['logits']
label = inputs['label']
pred_loss = self.pred_loss(logits, label)
return pred_loss, logits
# # pred = logits.argmax(dim=-1)
# output_dict = {}
# if self.visual_inputs:
# visual = inputs['image'].float()
# pred_symbol = self.v2s(visual)
# if self.one_hot:
# pred_symbol = pred_symbol.view(-1, self.max_value)
# # Both loss are calculated using original symbolic ground truth.
# symbolic_loss = self.compute_symbolic_loss(pred_symbol, symbol)
# if self.one_hot and self.v2s_softmax:
# pred_symbol = F.softmax(pred_symbol, dim=-1)
# output_dict['pred_symbol'] = pred_symbol
# output_dict['ind_features'] = pred_symbol
# symbol = pred_symbol
# # NOTE: the shape of symbol is not consistent for different setup
# # resize the inputs in the symbolic model
# monitors = {}
# output = self.symbolic_model(symbol)
# trans_keys = ['ind_features', self.collect_inter_key]
# for k in trans_keys:
# if k is not None and k in output:
# output_dict[k] = output[k]
# logits = output['logits']
# label = inputs['label']
# pred = logits.argmax(dim=-1)
# # print(pred.shape)
# pred_result = pred.eq(label).float()
# output_dict['pred_result'] = pred_result
# monitors['acc'] = pred_result.mean()
# pred_loss = self.pred_loss(logits, label)
# output_dict['logits'] = logits
# output_dict['pred_loss'] = pred_loss
# # if symbolic_loss:
# # monitors['symbolic_loss'] = symbolic_loss
# # loss = pred_loss * self.prediction_beta + \
# # symbolic_loss * self.symbolic_beta
# # else:
# # loss = pred_loss
# # if 'latent_logits' in output:
# # latent_logits = output['latent_logits']
# # output_dict['latent_logits'] = latent_logits
# # entropy = compute_entropy(latent_logits)
# # # print(label, latent_logits.shape)
# # batch_size = label.size(0)
# # filtered = latent_logits[np.arange(batch_size), label]
# # mi = compute_mi(filtered)
# # mi = mi.mean()
# # monitors['entropy'] = entropy
# # monitors['mi'] = mi
# return loss, monitors, output_dict
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : modules.py
# Author : Honghua Dong, Tony Wu
# Email : dhh19951@gmail.com, tonywu0206@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock, Bottleneck
import logging
from const import ORIGIN_IMAGE_SIZE
__all__ = ['MLPModel', 'FCResBlock', 'Expert', 'Scorer', 'ConvBlock', 'ConvNet',
'ResNet', 'ResNetWrapper', 'GroupMLP', 'SharedGroupMLP', 'Conv1DLayer']
class MLPModel(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dims):
super().__init__()
if hidden_dims is None:
hidden_dims = []
elif type(hidden_dims) is int:
hidden_dims = [hidden_dims]
layers = []
dims = [in_dim] + hidden_dims + [out_dim]
for i in range(len(dims) - 1):
layers.append(nn.Linear(dims[i], dims[i + 1]))
self.mlp = nn.Sequential(*layers)
def forward(self, x):
return self.mlp(x)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class FCResBlock(nn.Module):
def __init__(self, nn_dim, use_layer_norm=True,):
self.use_layer_norm = use_layer_norm
super(FCResBlock, self).__init__()
self.norm_in = nn.LayerNorm(nn_dim)
self.norm_out = nn.LayerNorm(nn_dim)
self.transform1 = torch.nn.Linear(nn_dim, nn_dim)
torch.nn.init.normal_(self.transform1.weight, std=0.005)
self.transform2 = torch.nn.Linear(nn_dim, nn_dim)
torch.nn.init.normal_(self.transform2.weight, std=0.005)
def forward(self, x):
if self.use_layer_norm:
x_branch = self.norm_in(x)
else:
x_branch = x
x_branch = self.transform1(F.relu(x_branch))
if self.use_layer_norm:
x_branch = self.norm_out(x_branch)
x_out = x + self.transform2(F.relu(x_branch))
#x_out = self.transform2(F.relu(x_branch))
#x_out = F.relu(self.transform2(x_branch))
# return F.relu(x_out)
return x_out
class Expert(nn.Module):
def __init__(self, input_dim, hidden_dims):
super().__init__()
self.mlp = MLPModel(input_dim, 1, hidden_dims)
def forward(self, x):
return self.mlp(x)
"""
The reduction groups specify the axes to be reduced together, in the order of
(nr_features, nr_experts), from the end to the front.
groups = [2] means reduce both at once.
groups = [1, 1] means first reduce nr_experts, then reduce nr_features.
"""
class Scorer(nn.Module):
def __init__(self, nr_features, nr_experts,
hidden_dims=[], reduction_groups=[2], sum_as_reduction=0):
super().__init__()
assert sum(reduction_groups) == 2
if len(reduction_groups) == 1 and sum_as_reduction > 0:
logging.warning('The scorer is just a summation.')
dims = [nr_features, nr_experts]
self.nr_reductions = len(reduction_groups)
self.sum_as_reduction = sum_as_reduction
self.input_dims = []
mlps = []
for g in reduction_groups:
input_dim = np.prod(dims[-g:])
output_dim = 1
dims = dims[:-g]
mlp = MLPModel(input_dim, output_dim, hidden_dims=hidden_dims)
self.input_dims.append(input_dim)
mlps.append(mlp)
self.mlps = nn.ModuleList(mlps)
def forward(self, x):
# x.shape: (batch * nr_candidates, nr_features, nr_experts, num)
for i in range(self.nr_reductions):
x = x.view(-1, self.input_dims[i])
s = self.sum_as_reduction
if s is not None and i + s >= self.nr_reductions:
x = x.sum(dim=-1)
else:
x = self.mlps[i](x)
return x
class ConvBlock(nn.Module):
def __init__(self, input_dim, output_dim, h, w, repeats=1,
kernel_size=3, padding=1, residual_link=True, use_layer_norm=False):
super().__init__()
convs = []
norms = []
if type(kernel_size) is int:
kh, kw = kernel_size, kernel_size
else:
kh, kw = kernel_size
current_dim = input_dim
for i in range(repeats):
stride = 1
if i == 0:
# The reduction conv
stride = 2
h = (h + 2 * padding - kh + stride) // stride
w = (w + 2 * padding - kw + stride) // stride
convs.append(nn.Conv2d(current_dim, output_dim,
kernel_size=kernel_size, stride=stride, padding=padding))
current_dim = output_dim
if use_layer_norm:
norms.append(nn.LayerNorm([current_dim, h, w]))
else:
norms.append(nn.BatchNorm2d(current_dim))
self.residual_link = residual_link
self.convs = nn.ModuleList(convs)
self.norms = nn.ModuleList(norms)
self.output_size = (h, w)
def forward(self, x):
is_reduction = True
for conv, norm in zip(self.convs, self.norms):
# ConvNormReLU
_ = x
_ = conv(_)
_ = norm(_)
_ = F.relu(_)
if is_reduction or not self.residual_link:
x = _
else:
x = x + _
is_reduction = False
return x
class ConvNet(nn.Module):
def __init__(self,
input_dim,
hidden_dims,
repeats=None,
kernels=3,
residual_link=True,
image_size=ORIGIN_IMAGE_SIZE,
flatten=False,
use_layer_norm=False):
super().__init__()
h, w = image_size
if type(kernels) is list:
if len(kernels) == 1:
kernel_size = kernels[0]
else:
kernel_size = tuple(kernels)
else:
kernel_size = kernels
if repeats is None:
repeats = [1 for i in range(len(hidden_dims))]
else:
assert len(repeats) == len(hidden_dims)
conv_blocks = []
current_dim = input_dim
# NOTE: The last hidden dim is the output dim
for rep, hidden_dim in zip(repeats, hidden_dims):
block = ConvBlock(current_dim, hidden_dim, h, w,
repeats=rep,
kernel_size=kernel_size,
residual_link=residual_link,
use_layer_norm=use_layer_norm)
current_dim = hidden_dim
conv_blocks.append(block)
h, w = block.output_size
self.conv_blocks = nn.ModuleList(conv_blocks)
self.flatten = flatten
self.output_dim = hidden_dims[-1]
self.output_image_size = (h, w)
# self.output_size = hidden_dims[-1] * h * w
def forward(self, x):
for conv_block in self.conv_blocks:
x = conv_block(x)
# default: image_size = (80, 80)
# batch, input_dim, 80, 80
# batch, hidden_dim[0], 40, 40
# batch, hidden_dim[1], 20, 20
# batch, hidden_dim[2], 10, 10
# batch, hidden_dim[3], 5, 5
if self.flatten:
x = x.flatten(1, -1)
# batch, hidden_dim[4] * 5 * 5
return x
# adapt from https://pytorch.org/docs/master/_modules/torchvision/models/resnet.html
class ResNet(nn.Module):
def __init__(self,
block,
repeats,
inplanes=64,
channels=[64, 128, 256, 512],
input_dim=3,
zero_init_residual=False,
norm_layer=None,
enable_maxpool=True):
super(ResNet, self).__init__()
assert repeats is not None
nr_layers = len(repeats)
assert len(channels) == nr_layers
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = inplanes
self.conv1 = nn.Conv2d(input_dim, self.inplanes,
kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = None
if enable_maxpool:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
layers = []
for i in range(nr_layers):
stride = 2 if i > 0 else 1
layers.append(self._make_layer(
block, channels[i], repeats[i], stride=stride))
self.layers = nn.Sequential(*layers)
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.maxpool is not None:
x = self.maxpool(x)
x = self.layers(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
return x
def forward(self, x):
return self._forward_impl(x)
class ResNetWrapper(nn.Module):
def __init__(self, repeats=[2, 2, 2], inplanes=8, channels=[8, 16, 32],
input_dim=1, image_size=ORIGIN_IMAGE_SIZE):
super().__init__()
self.resnet = ResNet(
block=BasicBlock,
repeats=repeats,
inplanes=inplanes,
channels=channels,
input_dim=input_dim,
enable_maxpool=False)
h, w = image_size
for i in range(len(repeats)):
h = h // 2
w = w // 2
self.output_dim = channels[-1]
self.output_image_size = (h, w)
def forward(self, x):
# input.shape: (batch, 1, h, w)
# after conv1: (batch, inplanes, h//2, w//2)
# with the default repeats=[2, 2, 2]:
# after layer1: (batch, channels[0], h//2, w//2)
# after layer2: (batch, channels[1], h//4, w//4)
# after layer3: (batch, channels[2], h//8, w//8)
return self.resnet(x)
# Use conv1D with group param as group mlp
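# e.g. (illustrative): GroupMLP(input_dim=8, output_dim=4, groups=2) applies two
# independent 4 -> 2 linear maps, one per group of the last dimension.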
class GroupMLP(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dims=[], groups=1,
batch_norm=None, dropout=None, activation='relu', flatten=True):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert input_dim % groups == 0
dims = [input_dim]
dims.extend(hidden_dims)
dims.append(output_dim)
layers = []
nr_layers = len(dims) - 1
for i in range(nr_layers):
assert dims[i + 1] % groups == 0
if i + 1 < nr_layers:
layers.append(Conv1DLayer(dims[i], dims[i + 1],
kernel_size=1, groups=groups, batch_norm=batch_norm,
dropout=dropout, activation=activation))
else: # last layer
layers.append(Conv1DLayer(dims[i], dims[i + 1], kernel_size=1,
groups=groups, bias=True))
self.mlp = nn.ModuleList(layers)
# self.flatten = flatten
def forward(self, x):
shape = list(x.size())
# x.shape: (batch, *, input_dim)
x = x.flatten(0, -2)
# x.shape: (batch', input_dim)
x = x.unsqueeze(-1)
# x.shape: (batch', input_dim, 1)
for layer in self.mlp:
x = layer(x)
# x.shape: (batch', output_dim, 1)
shape[-1] = self.output_dim
x = x.view(*shape)
# x.shape: (batch, *, output_dim)
return x
"""
SharedGroupMLP: split over the last dimension
and put to the batch dimension, then apply MLP.
"""
class SharedGroupMLP(nn.Module):
# shared mlp over groups split over the last dim
# take the last two dims as input dims (while the last one is split)
def __init__(self, groups, group_input_dim, group_output_dim,
hidden_dims=[], add_res_block=True, nr_mlps=1, flatten=True,
shared=True):
super().__init__()
self.shared = shared
self.groups = groups
self.group_input_dim = group_input_dim
self.group_output_dim = group_output_dim
# mlps indicates different experts
if shared:
mlps = [MLPModel(group_input_dim, group_output_dim,
hidden_dims=hidden_dims) for i in range(nr_mlps)]
else:
mlps = [GroupMLP(group_input_dim * groups, group_output_dim * groups,
hidden_dims=hidden_dims) for i in range(nr_mlps)]
self.mlps = nn.ModuleList(mlps)
self.FCblocks = None
if shared and add_res_block:
FCblocks = [FCResBlock(group_output_dim) for i in range(nr_mlps)]
self.FCblocks = nn.ModuleList(FCblocks)
self.flatten = flatten
def forward(self, x, inter_results_container=None, inter_layer=0):
assert x.size(-1) % self.groups == 0
group_size = x.size(-1) // self.groups
xs = x.split(group_size, dim=-1)
new_xs = []
for i in xs:
# apply on the last two axes; the last axis is split
x = i.flatten(-2, -1)
# x.shape: (batch, *, group_input_dim)
new_xs.append(x)
x = torch.stack(new_xs, dim=-2)
# x.shape: (batch, *, groups, group_input_dim)
if not self.shared:
x = x.flatten(-2, -1)
# x.shape: (batch, *, groups * group_input_dim)
ys = []
for ind, mlp in enumerate(self.mlps):
if inter_results_container is not None:
nr_layers = min(inter_layer, len(mlp.mlp))
t = x
for j in range(nr_layers):
t = mlp.mlp[j](t)
# t.shape: (batch, *, groups, hidden_dims[nr_layers - 1])
inter_results_container.append(t)
y = mlp(x)
if self.FCblocks:
y = self.FCblocks[ind](y)
ys.append(y)
x = torch.cat(ys, dim=-1)
# [no-share] x.shape: (batch, *, nr_mlps * groups * group_output_dim)
# [shared] x.shape: (batch, *, groups, nr_mlps * group_output_dim)
if self.shared and self.flatten:
x = x.flatten(-2, -1)
# x.shape: (batch, *, groups * nr_mlps * group_output_dim)
return x
class Conv1DLayer(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, padding_mode='zeros', dilation=1, groups=1,
batch_norm=None, dropout=None, bias=None, activation=None):
if bias is None:
bias = (batch_norm is None)
modules = [
nn.Conv1d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, padding_mode=padding_mode,
dilation=dilation, groups=groups, bias=bias)
]
if batch_norm is not None and batch_norm is not False:
modules.append(nn.BatchNorm1d(out_channels))
if dropout is not None and dropout is not False:
modules.append(nn.Dropout(0.5, True))
if activation is not None and activation is not False:
modules.append(get_activation(activation))
super().__init__(*modules)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
self.dilation = dilation
self.groups = groups
def reset_parameters(self):
for module in self.modules():
if 'Conv' in module.__class__.__name__:
module.reset_parameters()
@property
def input_dim(self):
return self.in_channels
@property
def output_dim(self):
return self.out_channels
def get_activation(act):
if isinstance(act, nn.Module):
return act
assert type(act) is str, 'Unknown type of activation: {}.'.format(act)
act_lower = act.lower()
if act_lower == 'identity':
return nn.Identity()
elif act_lower == 'relu':
return nn.ReLU(True)
elif act_lower == 'sigmoid':
return nn.Sigmoid()
elif act_lower == 'tanh':
return nn.Tanh()
else:
try:
return getattr(nn, act)()
except AttributeError:
raise ValueError('Unknown activation function: {}.'.format(act))
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : relation_model.py
# Author : Honghua Dong, Tony Wu
# Email : dhh19951@gmail.com, tonywu0206@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from modules import FCResBlock, SharedGroupMLP, Expert, Scorer, MLPModel
from utils import transform
__all__ = ['RelationModel']
class RelationModel(nn.Module):
def __init__(self,
nr_features,
nr_experts=5,
shared_group_mlp=True,
expert_output_dim=1,
not_use_softmax=False,
embedding_dim=None,
embedding_hidden_dims=[],
enable_residual_block=False,
enable_rb_after_experts=False,
feature_embedding_dim=1,
hidden_dims=[16],
reduction_groups=[3],
sum_as_reduction=0,
lastmlp_hidden_dims=[],
use_ordinary_mlp=False,
nr_context=8,
nr_candidates=8,
collect_inter_key=None):
super().__init__()
self.nr_context = nr_context
self.nr_candidates = nr_candidates
self.nr_images = nr_context + nr_candidates
self.nr_input_images = nr_context + 1
self.nr_experts = nr_experts
self.feature_embedding_dim = feature_embedding_dim
self.not_use_softmax = not_use_softmax
self.nr_features = nr_features
self.nr_candidates = nr_candidates
self.collect_inter_key = collect_inter_key
current_dim = nr_features
self.enable_residual_block = enable_residual_block
if self.enable_residual_block:
self.FCblock = FCResBlock(current_dim)
self.embedding = None
self.embedding_dim = embedding_dim
if embedding_dim is not None:
self.embedding = MLPModel(current_dim, embedding_dim,
hidden_dims=embedding_hidden_dims)
current_dim = embedding_dim
assert feature_embedding_dim > 0
assert current_dim % feature_embedding_dim == 0, (
'feature_embedding_dim should divide current_dim '
'(nr_features or embedding_dim)')
current_dim = current_dim // feature_embedding_dim
self.nr_ind_features = current_dim
self.group_input_dim = self.nr_input_images * feature_embedding_dim
# experts = [Expert(self.group_input_dim, hidden_dims)
# for i in range(nr_experts)]
# self.experts = nn.ModuleList(experts)
assert expert_output_dim == 1, 'only supports expert_output_dim == 1'
groups = current_dim
# group_size = feature_embedding_dim
group_output_dim = expert_output_dim
if use_ordinary_mlp:
groups = 1
# group_size = current_dim * feature_embedding_dim
self.group_input_dim *= current_dim
group_output_dim = self.nr_ind_features * expert_output_dim
self.experts = SharedGroupMLP(
groups=groups,
group_input_dim=self.group_input_dim,
group_output_dim=group_output_dim,
hidden_dims=hidden_dims,
add_res_block=enable_rb_after_experts,
nr_mlps=nr_experts,
shared=shared_group_mlp)
self.scorer = Scorer(current_dim, nr_experts,
hidden_dims=lastmlp_hidden_dims,
reduction_groups=reduction_groups,
sum_as_reduction=sum_as_reduction)
def forward(self, x):
current_dim = self.nr_features
x = x.view(-1, self.nr_images, current_dim).float()
# x.shape: (batch, nr_images, current_dim)
if self.enable_residual_block:
x = self.FCblock(x)
if self.embedding:
x = x.view(-1, current_dim)
# x.shape: (batch * nr_images, current_dim)
x = self.embedding(x)
# x.shape: (batch * nr_images, embedding_dims)
current_dim = self.embedding_dim
x = x.view(-1, self.nr_images, current_dim)
# x.shape: (batch, nr_images, current_dim)
fe_dim = self.feature_embedding_dim
ind_features = x
# x holds the features extracted from the inputs;
# the scorer regards x as independent features.
x = transform(x,
nr_context=self.nr_context, nr_candidates=self.nr_candidates)
# x.shape: (batch, nr_candidates, nr_context + 1, current_dim)
nr_ind_features = self.nr_ind_features
# Using SharedGroupMLP
nr_input_images = self.nr_input_images
x = x.view(-1, nr_input_images, current_dim)
# x.shape: (batch * nr_candidates, nr_input_images, current_dim)
# experts: split groups over the last dim, with group size fed
# each group corresponding to a feature
container = None
inter_layer = 0
ci_key = self.collect_inter_key
if ci_key is not None and ci_key.startswith('sgm_inter'):
container = []
inter_layer = int(ci_key[-1])
latent_logits = self.experts(x,
inter_results_container=container, inter_layer=inter_layer)
latent_logits = latent_logits.view(-1, nr_ind_features, self.nr_experts)
# latent_logits.shape: (batch * nr_candidates, nr_ind_features, nr_experts)
# # Using Expert
# x = x.view(-1, self.nr_candidates, nr_input_images, nr_ind_features, fe_dim)
# # x.shape: (batch, nr_candidates, nr_input_images, nr_ind_features, fe_dim)
# x = x.permute(0, 1, 3, 2, 4).contiguous()
# # x.shape: (batch, nr_candidates, nr_ind_features, nr_input_images, fe_dim)
# x = x.view(-1, self.group_input_dim)
# # x.shape: (batch * nr_candidates * nr_ind_features, group_input_dim)
# latent_logits = torch.cat([
# expert(x) for expert in self.experts], dim=-1)
# # latent_logits/x.shape: (batch * nr_candidates * nr_ind_features, nr_experts)
if self.not_use_softmax:
x = latent_logits
else:
x = F.softmax(latent_logits, dim=-1)
latent_logits = latent_logits.view(-1,
self.nr_candidates, nr_ind_features, self.nr_experts)
# x.shape: (batch * nr_candidates * nr_ind_features, nr_experts)
x = x.view(-1, nr_ind_features, self.nr_experts)
# x.shape: (batch * nr_candidates, nr_ind_features, nr_experts)
x = self.scorer(x)
x = x.view(-1, self.nr_candidates)
results = dict(logits=x,
latent_logits=latent_logits,
ind_features=ind_features)
if ci_key is not None and ci_key.startswith('sgm_inter'):
sgm_inter = torch.cat(container, dim=-1)
num = sgm_inter.size(-1)
sgm_inter = sgm_inter.view(
-1, self.nr_candidates, nr_ind_features, num)
results[ci_key] = sgm_inter
return results
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : utils.py
# Author : Honghua Dong, Tony Wu
# Email : dhh19951@gmail.com, tonywu0206@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
# TODO: compute mutual information instead of entropy
import torch
import torch.nn.functional as F
__all__ = ['compute_entropy', 'compute_mi', 'transform', 'vis_transform']
def compute_mi(logits, eps=1e-8):
# logits.shape: (batch, current_dim, nr_experts)
logits = logits.permute(1, 0, 2).contiguous()
# logits.shape: (current_dim, batch, nr_experts)
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
entropy = -(policy * log_policy).sum(dim=-1)
H_expert_given_x = entropy.mean(dim=-1)
avg_policy = policy.mean(dim=-2)
log_avg_policy = (avg_policy + eps).log()
H_expert = -(avg_policy * log_avg_policy).sum(dim=-1)
return H_expert - H_expert_given_x
def compute_entropy(logits):
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
entropy = -(policy * log_policy).sum(dim=-1)
return entropy.mean()
'''
To fill the candidates in and regard them as batch
'''
def transform(inputs, nr_context=8, nr_candidates=8):
context = inputs.narrow(1, 0, nr_context)
# context.shape: (batch, nr_context, nr_features)
candidates = inputs.narrow(1, nr_context, nr_candidates)
# candidates.shape: (batch, nr_candidates, nr_features)
context = context.unsqueeze(1)
# context.shape: (batch, 1, nr_context, nr_features)
context = context.expand(-1, nr_candidates, -1, -1)
# context.shape: (batch, nr_candidates, nr_context, nr_features)
candidates = candidates.unsqueeze(2)
# candidates.shape: (batch, nr_candidates, 1, nr_features)
merged = torch.cat([context, candidates], dim=2)
# merged.shape: (batch, nr_candidates, nr_context + 1, nr_features)
return merged
def vis_transform(inputs, nr_context=8, nr_candidates=8):
context = inputs.narrow(1, 0, nr_context)
# context.shape: (batch, nr_context, IMG_SIZE, IMG_SIZE)
candidates = inputs.narrow(1, nr_context, nr_candidates)
# candidates.shape: (batch, nr_candidates, IMG_SIZE, IMG_SIZE)
context = context.unsqueeze(1)
# context.shape: (batch, 1, nr_context, IMG_SIZE, IMG_SIZE)
context = context.expand(-1, nr_candidates, -1, -1, -1)
# context.shape: (batch, nr_candidates, nr_context, IMG_SIZE, IMG_SIZE)
candidates = candidates.unsqueeze(2)
# candidates.shape: (batch, nr_candidates, 1, IMG_SIZE, IMG_SIZE)
merged = torch.cat([context, candidates], dim=2)
# merged.shape: (batch, nr_candidates, nr_context + 1, IMG_SIZE, IMG_SIZE)
return merged
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import ResNetWrapper, ConvNet, MLPModel, SharedGroupMLP
from const import ORIGIN_IMAGE_SIZE
class VisualModel(nn.Module):
def __init__(self,
conv_hidden_dims,
output_dim,
input_dim=1,
use_resnet=False,
conv_repeats=None,
conv_kernels=3,
conv_residual_link=True,
transformed_spatial_dim=None,
mlp_transform_hidden_dims=[],
image_size=ORIGIN_IMAGE_SIZE,
use_layer_norm=False,
shared_group_mlp=True,
nr_visual_experts=1,
mlp_hidden_dims=[],
groups=1,
split_channel=False,
):
super().__init__()
if use_resnet:
self.cnn = ResNetWrapper(
repeats=conv_repeats,
inplanes=conv_hidden_dims[0],
channels=conv_hidden_dims,
image_size=image_size)
else:
self.cnn = ConvNet(
input_dim=input_dim,
hidden_dims=conv_hidden_dims,
repeats=conv_repeats,
kernels=conv_kernels,
residual_link=conv_residual_link,
image_size=image_size,
use_layer_norm=use_layer_norm)
# self.cnn_output_size = self.cnn.output_size
self.cnn_output_dim = self.cnn.output_dim
h, w = self.cnn.output_image_size
current_dim = h * w
self.spatial_dim = current_dim
self.transformed_spatial_dim = transformed_spatial_dim
self.mlp_transform = None
if transformed_spatial_dim is not None and transformed_spatial_dim > 0:
self.mlp_transform = MLPModel(
current_dim, transformed_spatial_dim,
hidden_dims=mlp_transform_hidden_dims)
current_dim = transformed_spatial_dim
total_dim = self.cnn_output_dim * current_dim
self.split_channel = split_channel
if split_channel:
current_dim = self.cnn_output_dim
assert current_dim % groups == 0, ('the spatial dim {} should be '
'divisible by the number of groups {}').format(current_dim, groups)
assert output_dim % (groups * nr_visual_experts) == 0, (
'the output dim {} should be divisible by the product of the number of '
'groups {} and the number of visual experts {}').format(
output_dim, groups, nr_visual_experts)
self.shared_group_mlp = SharedGroupMLP(
groups=groups,
group_input_dim=total_dim // groups,
group_output_dim=output_dim // (groups * nr_visual_experts),
hidden_dims=mlp_hidden_dims,
nr_mlps=nr_visual_experts,
shared=shared_group_mlp)
self.output_dim = output_dim
def forward(self, x):
x = x.float().contiguous()
nr_images, h, w = x.size()[1:]
# x.shape: (batch, nr_img, h, w)
x = x.view(-1, 1, h, w)
# x.shape: (batch * nr_img, 1, h, w)
x = self.cnn(x)
# x.shape: (batch * nr_img, cnn_output_dim, h', w')
current_dim = self.spatial_dim
x = x.view(-1, current_dim)
# x.shape: (batch * nr_img * cnn_output_dim, current_dim)
if self.mlp_transform:
x = self.mlp_transform(x)
current_dim = self.transformed_spatial_dim
# x.shape: (batch * nr_img * cnn_output_dim, current_dim)
x = x.view(-1, self.cnn_output_dim, current_dim)
# x.shape: (batch * nr_img, cnn_output_dim, current_dim)
if self.split_channel:
x = x.permute(0, 2, 1).contiguous()
# x.shape: (batch * nr_img, current_dim, cnn_output_dim)
current_dim = self.cnn_output_dim
x = self.shared_group_mlp(x)
# x.shape: (batch * nr_img, output_dim)
x = x.view(-1, nr_images, self.output_dim)
# x.shape: (batch, nr_img, output_dim)
return x