Commit 982474b1 by yuxguo

init

# SCL-my
Rewrite SCL without third-party libraries.
VALUES = [8, 10, 6, 6]
TOTAL_VALUES = sum(VALUES)
MAX_VALUE = max(VALUES)
NUM_ATTRS = len(VALUES)
ORIGIN_IMAGE_SIZE = (160, 160)
import os
import glob
import numpy as np
from scipy import misc  # misc.imresize requires SciPy < 1.3 (it was removed in later releases)
import torch
from torch.utils.data import Dataset
from torchvision import transforms, utils
class ToTensor(object):
def __call__(self, sample):
return torch.tensor(sample, dtype=torch.float32)
class PGMdataset(Dataset):
def __init__(self, root_dir, dataset_type, img_size, transform=None, shuffle=False):
self.root_dir = root_dir
self.transform = transform
self.file_names = [f for f in glob.glob(os.path.join(root_dir, "*.npz")) if dataset_type in os.path.basename(f)]
self.img_size = img_size
self.shuffle = shuffle
def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
data_path = self.file_names[idx]
data = np.load(data_path)
image = data["image"].reshape(16, 160, 160)
target = data["target"]
meta_target = data["meta_target"]
if self.shuffle:
context = image[:8, :, :]
choices = image[8:, :, :]
indices = np.arange(8)
np.random.shuffle(indices)
new_target = np.where(indices == target)[0][0]  # position of the original answer among the shuffled choices
new_choices = choices[indices, :, :]
image = np.concatenate((context, new_choices))
target = new_target
resize_image = []
for idx in range(0, 16):
resize_image.append(misc.imresize(image[idx,:,:], (self.img_size, self.img_size)))
resize_image = np.stack(resize_image)
if meta_target.dtype == np.int8:
meta_target = meta_target.astype(np.uint8)
del data
if self.transform:
resize_image = self.transform(resize_image)
target = torch.tensor(target, dtype=torch.long)
meta_target = self.transform(meta_target)
return resize_image, target, meta_target
figure_configuration_names = ['center_single', 'distribute_four', 'distribute_nine', 'in_center_single_out_center_single', 'in_distribute_four_out_center_single', 'left_center_single_right_center_single', 'up_center_single_down_center_single']
class RAVENdataset(Dataset):
def __init__(self, root_dir, dataset_type, figure_configurations, img_size, transform=None, shuffle=False):
self.root_dir = root_dir
self.transform = transform
self.file_names = []
for idx in figure_configurations:
tmp = [f for f in glob.glob(os.path.join(root_dir, figure_configuration_names[idx], "*.npz")) if dataset_type in os.path.basename(f)]
self.file_names += tmp
self.img_size = img_size
self.shuffle = shuffle
self.switch = [3,4,5,0,1,2,6,7]  # permutation that swaps the first and second rows of the 3x3 context grid
def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
data_path = self.file_names[idx]
data = np.load(data_path)
image = data["image"].reshape(16, 160, 160)
target = data["target"]
meta_target = data["meta_target"]
if self.shuffle:
context = image[:8, :, :]
choices = image[8:, :, :]
indices = np.arange(8)
np.random.shuffle(indices)
new_target = np.where(indices == target)[0][0]
new_choices = choices[indices, :, :]
switch_2_rows = np.random.rand()  # with probability 0.5, also swap the first two context rows (data augmentation)
if switch_2_rows < 0.5:
context = context[self.switch, :, :]
image = np.concatenate((context, new_choices))
target = new_target
resize_image = []
for idx in range(0, 16):
resize_image.append(misc.imresize(image[idx,:,:], (self.img_size, self.img_size)))
resize_image = np.stack(resize_image)
del data
if self.transform:
resize_image = self.transform(resize_image)
target = torch.tensor(target, dtype=torch.long)
meta_target = self.transform(meta_target)
return resize_image, target, meta_target
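# Minimal usage sketch (illustration only, not part of the original file): the dataset
# root is a hypothetical placeholder; batch size and image size are arbitrary example values.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    train_set = RAVENdataset("/path/to/RAVEN", "train",
                             figure_configurations=list(range(7)),
                             img_size=80, transform=ToTensor(), shuffle=True)
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    for images, targets, meta_targets in train_loader:
        # images: (32, 16, 80, 80) float tensor, targets: (32,) long tensor
        break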
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : model.py
# Author : Honghua Dong, Tony Wu
# Email : dhh19951@gmail.com, tonywu0206@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from modules import FCResBlock, SharedGroupMLP, Expert, Scorer, MLPModel
from utils import transform
__all__ = ['AnalogyModel']
class RelationModel(nn.Module):
def __init__(self,
nr_features,
nr_experts=5,
shared_group_mlp=True,
expert_output_dim=1,
not_use_softmax=False,
embedding_dim=None,
embedding_hidden_dims=[],
enable_residual_block=False,
enable_rb_after_experts=False,
feature_embedding_dim=1,
hidden_dims=[16],
reduction_groups=[3],
sum_as_reduction=0,
lastmlp_hidden_dims=[],
use_ordinary_mlp=False,
nr_context=8,
nr_candidates=8,
collect_inter_key=None):
super().__init__()
self.nr_context = nr_context
self.nr_candidates = nr_candidates
self.nr_images = nr_context + nr_candidates
self.nr_input_images = nr_context + 1
self.nr_experts = nr_experts
self.feature_embedding_dim = feature_embedding_dim
self.not_use_softmax = not_use_softmax
self.nr_features = nr_features
self.nr_candidates = nr_candidates
self.collect_inter_key = collect_inter_key
current_dim = nr_features
self.enable_residual_block = enable_residual_block
if self.enable_residual_block:
self.FCblock = FCResBlock(current_dim)
self.embedding = None
self.embedding_dim = embedding_dim
if embedding_dim is not None:
self.embedding = MLPModel(current_dim, embedding_dim,
hidden_dims=embedding_hidden_dims)
current_dim = embedding_dim
assert feature_embedding_dim > 0
assert current_dim % feature_embedding_dim == 0, (
'feature embedding dim should divide current dim '
'(nr_features or embedding_dim)')
current_dim = current_dim // feature_embedding_dim
self.nr_ind_features = current_dim
self.group_input_dim = self.nr_input_images * feature_embedding_dim
# experts = [Expert(self.group_input_dim, hidden_dims)
# for i in range(nr_experts)]
# self.experts = nn.ModuleList(experts)
assert expert_output_dim == 1, 'only supports expert_output_dim == 1'
groups = current_dim
# group_size = feature_embedding_dim
group_output_dim = expert_output_dim
if use_ordinary_mlp:
groups = 1
# group_size = current_dim * feature_embedding_dim
self.group_input_dim *= current_dim
group_output_dim = self.nr_ind_features * expert_output_dim
self.experts = SharedGroupMLP(
groups=groups,
group_input_dim=self.group_input_dim,
group_output_dim=group_output_dim,
hidden_dims=hidden_dims,
add_res_block=enable_rb_after_experts,
nr_mlps=nr_experts,
shared=shared_group_mlp)
self.scorer = Scorer(current_dim, nr_experts,
hidden_dims=lastmlp_hidden_dims,
reduction_groups=reduction_groups,
sum_as_reduction=sum_as_reduction)
def forward(self, x):
current_dim = self.nr_features
x = x.view(-1, self.nr_images, current_dim).float()
# x.shape: (batch, nr_images, current_dim)
if self.enable_residual_block:
x = self.FCblock(x)
if self.embedding:
x = x.view(-1, current_dim)
# x.shape: (batch * nr_images, current_dim)
x = self.embedding(x)
# x.shape: (batch * nr_images, embedding_dims)
current_dim = self.embedding_dim
x = x.view(-1, self.nr_images, current_dim)
# x.shape: (batch, nr_images, current_dim)
fe_dim = self.feature_embedding_dim
ind_features = x
# x holds the features extracted from the inputs;
# the scorers treat x as independent features
x = transform(x,
nr_context=self.nr_context, nr_candidates=self.nr_candidates)
# x.shape: (batch, nr_candidates, nr_context + 1, current_dim)
nr_ind_features = self.nr_ind_features
# Using SharedGroupMLP
nr_input_images = self.nr_input_images
x = x.view(-1, nr_input_images, current_dim)
# x.shape: (batch * nr_candidates, nr_input_images, current_dim)
# experts: split groups over the last dim, with group size fed
# each group corresponding to a feature
container = None
inter_layer = 0
ci_key = self.collect_inter_key
if ci_key is not None and ci_key.startswith('sgm_inter'):
container = []
inter_layer = int(ci_key[-1])
latent_logits = self.experts(x,
inter_results_container=container, inter_layer=inter_layer)
latent_logits = latent_logits.view(-1, nr_ind_features, self.nr_experts)
# latent_logits.shape: (batch * nr_candidates,
# nr_ind_features * nr_experts * expert_output_dim)
# # Using Expert
# x = x.view(-1, self.nr_candidates, nr_input_images, nr_ind_features, fe_dim)
# # x.shape: (batch, nr_candidates, nr_input_images, nr_ind_features, fe_dim)
# x = x.permute(0, 1, 3, 2, 4).contiguous()
# # x.shape: (batch, nr_candidates, nr_ind_features, nr_input_images, fe_dim)
# x = x.view(-1, self.group_input_dim)
# # x.shape: (batch * nr_candidates * nr_ind_features, group_input_dim)
# latent_logits = torch.cat([
# expert(x) for expert in self.experts], dim=-1)
# # latent_logits/x.shape: (batch * nr_candidates * nr_ind_features, nr_experts)
if self.not_use_softmax:
x = latent_logits
else:
x = F.softmax(latent_logits, dim=-1)
latent_logits = latent_logits.view(-1,
self.nr_candidates, nr_ind_features, self.nr_experts)
# latent_logits.shape: (batch, nr_candidates, nr_ind_features, nr_experts)
x = x.view(-1, nr_ind_features, self.nr_experts)
# x.shape: (batch * nr_candidates, nr_ind_features, nr_experts)
x = self.scorer(x)
x = x.view(-1, self.nr_candidates)
results = dict(logits=x,
latent_logits=latent_logits,
ind_features=ind_features)
if ci_key is not None and ci_key.startswith('sgm_inter'):
sgm_inter = torch.cat(container, dim=-1)
num = sgm_inter.size(-1)
sgm_inter = sgm_inter.view(
-1, self.nr_candidates, nr_ind_features, num)
results[ci_key] = sgm_inter
return results
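# Usage sketch (illustration only; hyper-parameters are arbitrary example values and the
# construction depends on the SharedGroupMLP/Scorer defaults defined in modules.py):
#   model = RelationModel(nr_features=80)
#   out = model(torch.rand(4, 16, 80))      # 8 context + 8 candidate feature vectors
#   # out['logits'].shape: (4, 8) -- one score per candidate answer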
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : utils.py
# Author : Honghua Dong, Tony Wu
# Email : dhh19951@gmail.com, tonywu0206@gmail.com
# Date : 11/06/2019
#
# Distributed under terms of the MIT license.
# TODO: compute mutual information instead of entropy
import torch
import torch.nn.functional as F
__all__ = ['compute_entropy', 'compute_mi', 'transform', 'vis_transform']
def compute_mi(logits, eps=1e-8):
# logits.shape: (batch, current_dim, nr_experts)
logits = logits.permute(1, 0, 2).contiguous()
# logits.shape: (current_dim, batch, nr_experts)
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
entropy = -(policy * log_policy).sum(dim=-1)
H_expert_given_x = entropy.mean(dim=-1)
avg_policy = policy.mean(dim=-2)
log_avg_policy = (avg_policy + eps).log()
H_expert = -(avg_policy * log_avg_policy).sum(dim=-1)
return H_expert - H_expert_given_x
def compute_entropy(logits):
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
entropy = -(policy * log_policy).sum(dim=-1)
return entropy.mean()
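# Shape note (added for clarity): compute_mi estimates, per feature dimension, the mutual
# information I(expert; x) = H(expert) - H(expert | x) over the batch, e.g.
#   logits = torch.randn(32, 10, 5)   # (batch, current_dim, nr_experts)
#   mi = compute_mi(logits)           # shape (10,): one value per feature dimension, in nats
# compute_entropy(logits) instead returns a scalar: the mean entropy of the expert
# distribution over all batch elements and feature dimensions.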
'''
Pair each candidate with the shared context so the candidates can be treated
as an extra batch dimension downstream.
'''
def transform(inputs, nr_context=8, nr_candidates=8):
context = inputs.narrow(1, 0, nr_context)
# context.shape: (batch, nr_context, nr_features)
candidates = inputs.narrow(1, nr_context, nr_candidates)
# candidates.shape: (batch, nr_candidates, nr_features)
context = context.unsqueeze(1)
# context.shape: (batch, 1, nr_context, nr_features)
context = context.expand(-1, nr_candidates, -1, -1)
# context.shape: (batch, nr_candidates, nr_context, nr_features)
candidates = candidates.unsqueeze(2)
# candidates.shape: (batch, nr_candidates, 1, nr_features)
merged = torch.cat([context, candidates], dim=2)
# merged.shape: (batch, nr_candidates, nr_context + 1, nr_features)
return merged
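# Shape note (added for clarity): with a batch of 2 problems, 16 panels each, and an
# 80-dim feature per panel,
#   transform(torch.zeros(2, 16, 80)).shape == (2, 8, 9, 80)
# i.e. every one of the 8 candidates is paired with the 8 context panels plus itself.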
def vis_transform(inputs, nr_context=8, nr_candidates=8):
context = inputs.narrow(1, 0, nr_context)
# context.shape: (batch, nr_context, IMG_SIZE, IMG_SIZE)
candidates = inputs.narrow(1, nr_context, nr_candidates)
# candidates.shape: (batch, nr_candidates, IMG_SIZE, IMG_SIZE)
context = context.unsqueeze(1)
# context.shape: (batch, 1, nr_context, IMG_SIZE, IMG_SIZE)
context = context.expand(-1, nr_candidates, -1, -1, -1)
# context.shape: (batch, nr_candidates, nr_context, IMG_SIZE, IMG_SIZE)
candidates = candidates.unsqueeze(2)
# candidates.shape: (batch, nr_candidates, 1, IMG_SIZE, IMG_SIZE)
merged = torch.cat([context, candidates], dim=2)
# merged.shape: (batch, nr_candidates, nr_context + 1, IMG_SIZE, IMG_SIZE)
return merged
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import ResNetWrapper, ConvNet, MLPModel, SharedGroupMLP
from const import ORIGIN_IMAGE_SIZE
class VisualModel(nn.Module):
def __init__(self,
conv_hidden_dims,
output_dim,
input_dim=1,
use_resnet=False,
conv_repeats=None,
conv_kernels=3,
conv_residual_link=True,
transformed_spatial_dim=None,
mlp_transform_hidden_dims=[],
image_size=ORIGIN_IMAGE_SIZE,
use_layer_norm=False,
shared_group_mlp=True,
nr_visual_experts=1,
mlp_hidden_dims=[],
groups=1,
split_channel=False,
):
super().__init__()
if use_resnet:
self.cnn = ResNetWrapper(
repeats=conv_repeats,
inplanes=conv_hidden_dims[0],
channels=conv_hidden_dims,
image_size=image_size)
else:
self.cnn = ConvNet(
input_dim=input_dim,
hidden_dims=conv_hidden_dims,
repeats=conv_repeats,
kernels=conv_kernels,
residual_link=conv_residual_link,
image_size=image_size,
use_layer_norm=use_layer_norm)
# self.cnn_output_size = self.cnn.output_size
self.cnn_output_dim = self.cnn.output_dim
h, w = self.cnn.output_image_size
current_dim = h * w
self.spatial_dim = current_dim
self.transformed_spatial_dim = transformed_spatial_dim
self.mlp_transform = None
if transformed_spatial_dim is not None and transformed_spatial_dim > 0:
self.mlp_transform = MLPModel(
current_dim, transformed_spatial_dim,
hidden_dims=mlp_transform_hidden_dims)
current_dim = transformed_spatial_dim
total_dim = self.cnn_output_dim * current_dim
self.split_channel = split_channel
if split_channel:
current_dim = self.cnn_output_dim
assert current_dim % groups == 0, ('the spatial dim {} should be '
'divisible by the number of groups {}').format(current_dim, groups)
assert output_dim % (groups * nr_visual_experts) == 0, (
'the output dim {} should be divisible by the product of the number of '
'groups {} and the number of visual experts {}').format(
output_dim, groups, nr_visual_experts)
self.shared_group_mlp = SharedGroupMLP(
groups=groups,
group_input_dim=total_dim // groups,
group_output_dim=output_dim // (groups * nr_visual_experts),
hidden_dims=mlp_hidden_dims,
nr_mlps=nr_visual_experts,
shared=shared_group_mlp)
self.output_dim = output_dim
def forward(self, x):
x = x.float().contiguous()
nr_images, h, w = x.size()[1:]
# x.shape: (batch, nr_img, h, w)
x = x.view(-1, 1, h, w)
# x.shape: (batch * nr_img, 1, h, w)
x = self.cnn(x)
# x.shape: (batch * nr_img, cnn_output_dim, h', w')
current_dim = self.spatial_dim
x = x.view(-1, current_dim)
# x.shape: (batch * nr_img * cnn_output_dim, current_dim)
if self.mlp_transform:
x = self.mlp_transform(x)
current_dim = self.transformed_spatial_dim
# x.shape: (batch * nr_img * cnn_output_dim, current_dim)
x = x.view(-1, self.cnn_output_dim, current_dim)
# x.shape: (batch * nr_img, cnn_output_dim, current_dim)
if self.split_channel:
x = x.permute(0, 2, 1).contiguous()
# x.shape: (batch * nr_img, current_dim, cnn_output_dim)
current_dim = self.cnn_output_dim
x = self.shared_group_mlp(x)
# x.shape: (batch * nr_img, output_dim)
x = x.view(-1, nr_images, self.output_dim)
# x.shape: (batch, nr_img, output_dim)
return x
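# Usage sketch (illustration only; hyper-parameters are arbitrary example values and the
# construction depends on the ConvNet/SharedGroupMLP definitions in modules.py):
#   encoder = VisualModel(conv_hidden_dims=[16, 32, 32, 64], output_dim=80,
#                         image_size=(80, 80))
#   feats = encoder(torch.rand(4, 16, 80, 80))   # (batch, nr_images, h, w)
#   # feats.shape: (4, 16, 80), i.e. one 80-dim feature vector per panel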