Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
S
SCL-my
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
YuxuanGuo
SCL-my
Commits
80ee6423
Commit
80ee6423
authored
Mar 10, 2022
by
yuxguo
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix
parent
152d4af0
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
22 additions
and
20 deletions
+22
-20
main.py
+8
-8
model/dataset.py
+2
-2
model/model.py
+5
-5
model/modules.py
+1
-1
model/relation_model.py
+2
-2
model/visual_model.py
+2
-2
run.sh
+2
-0
No files found.
main.py
View file @
80ee6423
...
...
@@ -38,10 +38,10 @@ logging.basicConfig(format='[%(asctime)s, %(levelname)s]: %(message)s', level=lo
# from jactorch.parallel import JacDataParallel
# from jactorch.train import TrainerEnv
from
model.const
import
MAX_VALUE
from
model.const
import
MAX_VALUE
,
ORIGIN_IMAGE_SIZE
# from model.dataset import get_dataset_name_and_num_features, load_data
from
model.dataset
import
RAVENdataset
,
PGMdataset
,
ToTensor
from
model.
utils
import
Model
,
Observer
from
model.
model
import
Model
# from model.train import Trainer
# from model.utils import plot_curve, get_exp_name, get_image_title
...
...
@@ -343,17 +343,17 @@ def get_model(args):
def
get_dataloader
():
if
args
.
dataset
==
'PGM'
:
train
=
PGMdataset
(
args
.
dataset_path
,
"train"
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]),
shuffle
=
True
)
valid
=
PGMdataset
(
args
.
dataset_path
,
"val"
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
test
=
PGMdataset
(
args
.
dataset_path
,
"test"
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
train
=
PGMdataset
(
args
.
dataset_path
,
"train"
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]),
shuffle
=
True
)
valid
=
PGMdataset
(
args
.
dataset_path
,
"val"
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
test
=
PGMdataset
(
args
.
dataset_path
,
"test"
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
elif
args
.
dataset
==
'RAVEN'
:
# (I-)RAVEN
args
.
train_figure_configurations
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
args
.
val_figure_configurations
=
args
.
train_figure_configurations
args
.
test_figure_configurations
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
train
=
RAVENdataset
(
args
.
dataset_path
,
"train"
,
args
.
train_figure_configurations
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]),
shuffle
=
True
)
valid
=
RAVENdataset
(
args
.
dataset_path
,
"val"
,
args
.
val_figure_configurations
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
test
=
RAVENdataset
(
args
.
dataset_path
,
"test"
,
args
.
test_figure_configurations
,
args
.
img_size
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
train
=
RAVENdataset
(
args
.
dataset_path
,
"train"
,
args
.
train_figure_configurations
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]),
shuffle
=
True
)
valid
=
RAVENdataset
(
args
.
dataset_path
,
"val"
,
args
.
val_figure_configurations
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
test
=
RAVENdataset
(
args
.
dataset_path
,
"test"
,
args
.
test_figure_configurations
,
ORIGIN_IMAGE_SIZE
,
transform
=
transforms
.
Compose
([
ToTensor
()]))
trainloader
=
DataLoader
(
train
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
num_workers
=
args
.
load_workers
)
validloader
=
DataLoader
(
valid
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
num_workers
=
args
.
load_workers
)
...
...
model/dataset.py
View file @
80ee6423
...
...
@@ -43,7 +43,7 @@ class PGMdataset(Dataset):
resize_image
=
[]
for
idx
in
range
(
0
,
16
):
resize_image
.
append
(
misc
.
imresize
(
image
[
idx
,:,:],
(
self
.
img_size
,
self
.
img_size
)
))
resize_image
.
append
(
misc
.
imresize
(
image
[
idx
,:,:],
self
.
img_size
))
resize_image
=
np
.
stack
(
resize_image
)
if
meta_target
.
dtype
==
np
.
int8
:
...
...
@@ -98,7 +98,7 @@ class RAVENdataset(Dataset):
resize_image
=
[]
for
idx
in
range
(
0
,
16
):
resize_image
.
append
(
misc
.
imresize
(
image
[
idx
,:,:],
(
self
.
img_size
,
self
.
img_size
)
))
resize_image
.
append
(
misc
.
imresize
(
image
[
idx
,:,:],
self
.
img_size
))
resize_image
=
np
.
stack
(
resize_image
)
del
data
...
...
model/model.py
View file @
80ee6423
...
...
@@ -15,13 +15,13 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
logging
from
modules
import
MLPModel
from
.
modules
import
MLPModel
# from .baselines import SimpleModel, SharedModel
from
relation_model
import
RelationModel
from
utils
import
compute_entropy
,
compute_mi
from
visual_model
import
VisualModel
from
const
import
ORIGIN_IMAGE_SIZE
,
MAX_VALUE
from
.
relation_model
import
RelationModel
from
.
utils
import
compute_entropy
,
compute_mi
from
.
visual_model
import
VisualModel
from
.
const
import
ORIGIN_IMAGE_SIZE
,
MAX_VALUE
__all__
=
[
'Model'
]
...
...
model/modules.py
View file @
80ee6423
...
...
@@ -15,7 +15,7 @@ import torch.nn.functional as F
from
torchvision.models.resnet
import
BasicBlock
,
Bottleneck
import
logging
from
const
import
ORIGIN_IMAGE_SIZE
from
.
const
import
ORIGIN_IMAGE_SIZE
__all__
=
[
'FCResBlock'
,
'Expert'
,
'Scorer'
,
'ConvBlock'
,
'ConvNet'
,
...
...
model/relation_model.py
View file @
80ee6423
...
...
@@ -13,8 +13,8 @@ import torch.nn.functional as F
import
logging
from
modules
import
FCResBlock
,
SharedGroupMLP
,
Expert
,
Scorer
,
MLPModel
from
utils
import
transform
from
.
modules
import
FCResBlock
,
SharedGroupMLP
,
Expert
,
Scorer
,
MLPModel
from
.
utils
import
transform
__all__
=
[
'AnalogyModel'
]
...
...
model/visual_model.py
View file @
80ee6423
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
modules
import
ResNetWrapper
,
ConvNet
,
MLPModel
,
SharedGroupMLP
from
const
import
ORIGIN_IMAGE_SIZE
from
.
modules
import
ResNetWrapper
,
ConvNet
,
MLPModel
,
SharedGroupMLP
from
.
const
import
ORIGIN_IMAGE_SIZE
class
VisualModel
(
nn
.
Module
):
...
...
run.sh
0 → 100644
View file @
80ee6423
python main.py
--dataset
"RAVEN"
--dataset_path
"./dataset/I-RAVEN"
-t
center_single left_right up_down in_out distribute_four distribute_nine in_distri
--use-gpu
-v
-lr
0.005
-vlr
0.005
-wd
0.01
-tsd
80
-fg
10
-nf
80
-chd
16 16 32 32
-rb
-erb
-vhd
128
-hd
64 32
-lhd
128
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment