lvzhengyang / REST / Commits / 312301fd

Commit 312301fd authored Sep 09, 2022 by Yijun Tan
Garbage code: the intermediate state was never reset, which took a whole day to debug.
parent c01b9271
Showing 4 changed files with 793 additions and 0 deletions (+793 / -0):

a22c2.py    +274 -0
env2.py     +213 -0
log.py      +32  -0
model22.py  +274 -0
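The commit message points at the root cause: ResDecoder keeps a per-rollout stack of State tuples, and a22c2.py now calls model.actor_decoder.reset_state() at the start of every select_action so each episode begins from a fresh stack. Below is a minimal, hypothetical sketch (TinyStatefulDecoder is an illustration only, not part of this commit) of why that reset matters:

import torch

class TinyStatefulDecoder:
    # Hypothetical stand-in for ResDecoder: it accumulates one state entry per decoding step.
    def __init__(self, qdim=4):
        self.qdim = qdim
        self.reset_state()

    def reset_state(self):
        # Start every rollout from a single zero entry.
        self.state = [torch.zeros(self.qdim)]

    def step(self):
        # Each decoding step pushes a new entry derived from the previous one.
        self.state.append(self.state[-1] + 1)
        return self.state[-1]

dec = TinyStatefulDecoder()
for episode in range(3):
    dec.reset_state()               # omit this line and the stack keeps growing across episodes
    for _ in range(2):
        dec.step()
    print(episode, len(dec.state))  # stays at 3 only because of reset_state()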
a22c2.py  (new file, 0 → 100644)
from tkinter import W
import numpy as np
import matplotlib.pyplot as plt
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from env2 import *
from log import *
from model22 import *

env = Grid()

learning_rate = 0.01
gamma = 0.99
episodes = 20000
render = True
eps = np.finfo(np.float32).eps.item()
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
n = 5
LOG = 1


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(2 * n, 32)
        # debug(self.fc1.weight)
        self.actor_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
        # self.critic_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
        self.actor_decoder = ResDecoder(128, 360)
        '''
        self.actoru = nn.Sequential(
            nn.Linear(32, n)
        )
        self.actorw = nn.Sequential(
            nn.Linear(32, 2*n)
        )
        '''
        self.critic_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
        self.g = nn.Linear(128, 1, bias=False)
        self.critic = nn.Sequential(
            nn.Linear(128, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 1, bias=True)
        )
        self.save_actions = []
        self.rewards = []

    def get_actor_embedding(self, x):
        # debug("embedding: ", x)
        # x = self.fc1(x)
        # debug("embedding: ", self.fc1.weight)
        # x = F.relu(x)
        # action_score = self.actoru(x)
        # state_value = self.critic(x)
        # debug("embedding: ", x)
        am = torch.zeros(n, n)
        kpm = torch.zeros(1, n)
        actor_emb = self.actor_embedding(x, am, kpm)
        self.actor_emb = actor_emb
        return actor_emb

    def get_critic(self, x):
        am = torch.zeros(n, n)
        kpm = torch.zeros(1, n)
        # debug(self.critic_embedding)
        critic_emb = self.critic_embedding(x, am, kpm)
        debug(critic_emb.shape)
        glimpse = self.g(torch.tanh(critic_emb)).view(1, n)
        debug(glimpse.shape)
        glimpse = torch.softmax(glimpse, dim=1).unsqueeze(1)
        glimpse = torch.bmm(glimpse, critic_emb)
        debug(glimpse.shape)
        x = self.critic(glimpse)
        return x

    def getu(self, visited):
        # x = nn.ReLU(self.fc1(x))
        # debug("get u: ", x)
        # x = self.actoru(x)
        # debug("get u: ", x)
        x = self.actor_decoder.getu(self.actor_emb, visited)
        u_prob = torch.where(visited, -1e6 * torch.ones_like(x), x)
        return F.softmax(u_prob, dim=-1)

    def getw(self, u, unvisited):
        # x = nn.ReLU(self.fc1(x))
        # x = self.actorw(x)
        debug("getting w")
        debug(u)
        x = self.actor_decoder.getw(self.actor_emb, u, unvisited)
        debug(x.shape, x)
        debug(unvisited)
        w_prob = torch.where(unvisited, -1e6 * torch.ones_like(x), x)
        return F.softmax(w_prob, dim=-1)

    def update_decoder(self, u, w=-1, v=-1, h=-1):
        self.actor_decoder.update(self.actor_emb, u, w, v, h)


model = Policy()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for i in model.parameters():
    debug(i.shape)


def select_action(state):
    # debug(state)
    actions = []
    points = torch.tensor(state[0]) / 10
    points = points.view(1, n, 2)
    # points = torch.flatten(points)
    visited = torch.tensor(state[1]).to(torch.bool)
    visited = torch.logical_not(torch.tensor([True, False, False, False, False]))
    debug(points)
    # debug(visited, unvisited)
    points = torch.tensor(points).float()
    model.actor_decoder.reset_state()
    for i in range(n - 1):
        # embedding = model(points)
        # model.get_actor_embedding(points)
        model.actor_emb = torch.rand([1, 5, 128])
        # select u
        # debug(points)
        probs = model.getu(visited)
        p = Categorical(probs)
        u = p.sample()
        # debug("u: ", u, probs, visited)
        visited[u] = False
        # debug("u: ", u)
        # select w
        unvisited = torch.logical_not(visited)
        unvisited = torch.cat((unvisited, unvisited), dim=0)
        probs = model.getw(u, unvisited)
        p = Categorical(probs)
        w0 = p.sample()
        # debug("w: ", w, probs, unvisited)
        if w0 < n:
            v = u
            h = w0
            w = w0
        else:
            w = w0 - n
            v = w
            h = u
        visited[w] = False
        if i < n - 2:
            model.update_decoder(u, w, v, h)
        actions.append([u, w, v, h])
    debug(points)
    state_value = model.get_critic(points)
    state_value = state_value.view(1)
    debug(state_value)
    model.save_actions.append(SavedAction(p.log_prob(w0), state_value))
    # debug(actions)
    return actions


def finish_episode():
    R = 0
    save_actions = model.save_actions
    policy_loss = []
    value_loss = []
    rewards = []

    for r in model.rewards[::-1]:
        R = r + gamma * R
        rewards.insert(0, R)

    rewards = torch.tensor(rewards)
    # rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    debug()
    debug("training")
    debug(save_actions)
    debug(rewards)
    for (log_prob, value), r in zip(save_actions, rewards):
        reward = r - value.view(-1).item()
        policy_loss.append(-log_prob * reward)
        value_loss.append(F.smooth_l1_loss(value, torch.tensor([r])))

    # debug(policy_loss)
    # debug(value_loss)
    optimizer.zero_grad()
    loss = torch.stack(policy_loss).sum() + torch.stack(value_loss).sum()
    # debug(loss)
    loss.backward(retain_graph=True)
    optimizer.step()
    # optimizer.zero_grad()
    # debug("policy_loss", policy_loss)
    # loss = policy_loss[0]
    # loss.backward(retain_graph=True)
    # optimizer.step()

    del model.rewards[:]
    del model.save_actions[:]


def main():
    for i_episode in count(episodes):
        state = env.reset()
        debug(state)
        actions = select_action(state)
        debug("actions: ", actions)
        state, reward, done = env.multistep(actions)
        model.rewards.append(reward)
        # env.render()
        finish_episode()


if __name__ == '__main__':
    main()
env2.py  (new file, 0 → 100644)
from email.header import Header
from re import L
import torch
import random
import numpy as np
import os
from log import *

HEIGHT, WIDTH = [8, 8]
OBSTACLE = 1000
BLANK = 0
LINE = 1
NODE = 2
OBSTACLE = 3


class Interval:
    def __init__(self, low, high):
        self.low = min(low, high)
        self.high = max(low, high)

    def intersect(self, inter):
        return max(self.low, inter.low) < min(self.high, inter.high)

    def merge(self, inter):
        assert self.intersect(inter)
        self.low = min(self.low, inter.low)
        self.high = max(self.high, inter.high)


def log(a, k=0):
    # colored print: k selects the ANSI foreground color (30 + k)
    print("\033[3", k, "m", a, "\033[0m", end="", sep="")


class Point:
    def __init__(self, x=-1, y=-1):
        self.x = x
        self.y = y


class Grid:
    def __init__(self, h=10, w=10, n=5):
        self.h = h
        self.w = w
        self.n = n
        self.grid = [[BLANK] * self.w for i in range(self.h)]
        self.hlines = [[]] * h
        self.wlines = [[]] * w
        self.steps = 0
        random.seed(1)
        # self.generate_nodes()

    def reset(self):
        self.steps = 0
        self.grid = [[BLANK] * self.w for i in range(self.h)]
        self.hlines = [[]] * self.h
        self.wlines = [[]] * self.w
        self.generate_nodes()
        state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
        # print(state)
        return state

    def generate_nodes(self):
        random.seed(1)
        self.hashnodes = set()
        while len(self.hashnodes) < self.n:
            x = random.randint(0, self.h * self.w - 1)
            self.hashnodes.add(x)
            # print(x)
        self.hashnodes = [32, 97, 72, 17, 8]
        self.nodes = []
        for i in self.hashnodes:
            self.nodes.append(Point(i // self.w, i % self.w))
            # print(self.nodes[-1].x, self.nodes[-1].y)
            self.grid[self.nodes[-1].x][self.nodes[-1].y] = NODE
        self.nodes_used = [0] * self.n

    def occupy(self, x, y):
        if self.grid[x][y] == BLANK:
            self.grid[x][y] = LINE

    def connect(self, start, end, dir):
        # assert self.nodes_used[start] == 1
        # assert self.nodes_used[end] == 0
        self.nodes_used[start] = 1
        self.nodes_used[end] = 1
        n0 = self.nodes[start]
        n1 = self.nodes[end]
        inter0 = Interval(n0.x, n1.x)
        inter1 = Interval(n0.y, n1.y)
        if dir:
            # mid = Point(n0.x, n1.y)
            self.hlines[n0.y].append(inter0)
            ######################### merge #############
            self.wlines[n1.x].append(inter1)
            for i in range(inter0.low, inter0.high + 1, 1):
                self.occupy(i, n0.y)
            for i in range(inter1.low, inter1.high + 1, 1):
                self.occupy(n1.x, i)
        else:
            self.hlines[n1.y].append(inter0)
            ######################### merge #############
            self.wlines[n0.x].append(inter1)
            for i in range(inter0.low, inter0.high + 1, 1):
                self.occupy(i, n1.y)
            for i in range(inter1.low, inter1.high + 1, 1):
                self.occupy(n0.x, i)

    def wire_length(self):
        res = 0
        for i in range(self.h):
            for j in range(self.w):
                res += (self.grid[i][j] != BLANK)
        return res - 1

    def finish(self):
        return sum(self.nodes_used) == self.n

    def step(self, start, end, dir):
        self.connect(start, end, dir)
        if self.finish():
            reward = -self.wire_length()
            done = 1
        else:
            reward = 0
            done = 0
        state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
        return state, reward, done

    def multistep(self, actions):
        for i in actions:
            dir = int(i[0] == i[1])
            self.connect(i[0], i[1], dir)
        reward = -self.wire_length()
        state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
        return state, reward, 1

    def render(self):
        # os.system("cls")
        print()
        for i in range(self.h):
            log(str(self.h - i - 1) + " ", 4)
            for j in range(self.w):
                log("o ", self.grid[self.h - i - 1][j])
            print()
        print()
        log("  ")
        for i in range(self.w):
            log(str(i) + " ", 4)
        print()
        for i in range(self.n):
            print(i, "(", self.nodes[i].x, self.nodes[i].y, "): ", self.nodes_used[i])

    # NOTE: finish() and step() are defined a second time here; Python keeps these later definitions.
    def finish(self):
        # assert sum(self.nodes_used) == self.steps+1
        return self.steps + 1 == self.n

    def step(self, start, end, dir):
        self.steps += 1
        self.connect(start, end, dir)
        if self.finish():
            reward = -self.wire_length()
            done = 1
        else:
            reward = 0
            done = 0
        return np.array(self.nodes_used), reward, done, None


if __name__ == "__main__":
    env = Grid()
    print(env.reset())
    # _, reward, done = env.step(0, 1, 1)
    # _, reward, done = env.step(1, 2, 0)
    # _, reward, done = env.step(2, 3, 0)
    # _, reward, done = env.step(3, 4, 0)
    # print(reward, done)
    actions = [[0, 1, 1, 0],
               [1, 2, 1, 2],
               [2, 3, 2, 3],
               [3, 4, 3, 4],
               ]
    _, reward, done = env.multistep(actions)
    env.render()
log.py  (new file, 0 → 100644)
INFO = 0
DISPLAY = 1
DEBUG = 2
LOG_LEVEL = 2


def log(a, k=0):
    # colored print: k selects the ANSI foreground color (30 + k)
    print("\033[3", k, "m", a, "\033[0m", end="", sep="")


def debug(*v):
    # superseded by the level-aware debug() defined further down
    if debug:
        print(" ".join(v))


import builtins as __builtin__


def info(*args, **kwargs):
    if LOG_LEVEL >= INFO:
        return __builtin__.print(*args, **kwargs)


def display(*args, **kwargs):
    if LOG_LEVEL >= DISPLAY:
        return __builtin__.print(*args, **kwargs)


def debug(*args, **kwargs):
    if LOG_LEVEL >= DEBUG:
        return __builtin__.print(*args, **kwargs)


if __name__ == "__main__":
    debug("1", "2", "3")
model22.py  (new file, 0 → 100644)
from collections import namedtuple
import tkinter as tk
import torch
import torch.nn as nn
from log import *


class ResEncoder(nn.Module):
    def __init__(self, max_n, gdim=2, edim=128, nheads=16, hdim=512, num_layer=3):
        super(ResEncoder, self).__init__()
        self.max_n = max_n
        self.edim = edim
        self.nheads = nheads
        self.hdim = hdim
        self.num_layer = num_layer
        self.fc0 = nn.Linear(gdim, edim)
        # self.key_padding_mask = torch.zeros(batch_size, 50)
        # self.attn_mask = torch.zeros(50, 50)
        if torch.__version__[0:3] == "1.8":
            self.encoder_layer = nn.TransformerEncoderLayer(
                d_model=edim, nhead=nheads, dim_feedforward=hdim, dropout=0)
            # norm_first, batch_first=True)
        else:
            self.encoder_layer = nn.TransformerEncoderLayer(
                d_model=edim, nhead=nheads, dim_feedforward=hdim, dropout=0,
                norm_first=True, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layer)

    def forward(self, x, attn_mask, key_padding_mask):
        x = self.fc0(x)
        # print(x)
        # attn_mask = torch.zeros(self.max_n, self.max_n)
        # key_padding_mask = torch.zeros(10, self.50)
        x = self.encoder(x, mask=attn_mask, src_key_padding_mask=key_padding_mask)
        return x


class PTM(nn.Module):
    def __init__(self, edim, qdim):
        super(PTM, self).__init__()
        self.edim = qdim
        self.qdim = qdim
        self.efc = nn.Linear(edim, qdim, bias=False)
        self.qfc = nn.Linear(qdim, qdim, bias=False)
        self.th0 = nn.Tanh()
        self.gfc = nn.Linear(qdim, 1, bias=False)
        self.th1 = nn.Tanh()

    def forward(self, e, q, visited=None):
        # print(e.shape, q.shape)
        e = self.efc(e)
        q = self.qfc(q)
        e = self.gfc(self.th0(e + q))
        # print(e.shape)
        # print("visited: ", visited.shape)
        e = torch.squeeze(e)
        # if visited != None:
        #     l = torch.where(visited, -1e6 * torch.ones_like(e), e)
        # print(l.shape)
        l = e
        l = self.th1(l) * 10.0
        # l = torch.softmax(l, dim=1)
        return l


State = namedtuple("State", ["edge", "subtree", "eu", "ew", "ev", "eh"])


class ResDecoder(nn.Module):
    def __init__(self, edim, qdim):
        super(ResDecoder, self).__init__()
        self.edim = edim
        self.qdim = qdim
        self.ptm = PTM(edim, qdim)
        self.eptm0 = PTM(edim, qdim)
        self.eptm1 = PTM(edim, qdim)
        Edge = namedtuple("Edge", ["ufc", "wfc", "vfc", "hfc"])
        self.edgefc = Edge(
            nn.Linear(edim, qdim),
            nn.Linear(edim, qdim),
            nn.Linear(edim, qdim),
            nn.Linear(edim, qdim),
        )
        self.subtfc = nn.Linear(qdim, qdim)
        self.fc_w5 = nn.Linear(edim, qdim)
        self.state = []
        self.state.append(State(torch.zeros(self.qdim), torch.zeros(self.qdim), -1, -1, -1, -1))
        # self.subtree = torch.zeros(self.qdim)

    def reset_state(self):
        self.state = []
        self.state.append(State(torch.zeros(self.qdim), torch.zeros(self.qdim), -1, -1, -1, -1))

    def getu(self, e, visited, step=0):
        qu = torch.max(self.state[-1].edge + self.state[-1].subtree,
                       torch.zeros_like(self.state[-1].edge))
        # debug(qu)
        u = self.ptm(e, qu, visited)
        return u

    def getw(self, e, u, visited, step=1):
        qw = torch.max(torch.zeros_like(self.state[-1].edge),
                       self.state[-1].edge + self.state[-1].subtree
                       + self.fc_w5(e[:, self.state[-1].eu, :]))
        w0 = self.eptm0(e, qw, visited)
        w1 = self.eptm1(e, qw, visited)
        w = torch.cat([w0, w1]).view(1, -1)
        return w

    def update(self, e, u, w=-1, v=-1, h=-1):
        edg = self.edgefc.ufc(e[:, u, :].clone())
        subtr = torch.max(self.state[-1].subtree, self.subtfc(edg))
        self.state.append(State(edge=edg, subtree=subtr, eu=u, ew=w, ev=v, eh=h))
        return self.state[-1]


if __name__ == "__main__":
    ### encoder test
    print("encoder test")
    a = torch.zeros([10, 50, 2])
    if torch.__version__[0:3] == "1.8":
        a = torch.zeros([50, 10, 2])
    a = a
    print(a.shape)
    RE = ResEncoder(10, 2, 128, 16, 512, 3)
    amask = torch.zeros(50, 50)
    kpmask = torch.zeros(10, 50)
    b = RE(a, amask, kpmask)
    print(b.shape)
    '''
    a = torch.ones(2, 2)
    b = torch.ones(2, 4)
    ptm = PTM(2, 4)
    c = ptm(a, b, torch.tensor([0, 1, 0, 1])>0)
    print(c)
    '''
    ### decoder test
    a = torch.zeros([1, 50, 2])
    RE = ResEncoder(10, 2, 16, 16, 16, 1)
    RD = ResDecoder(16, 32)
    amask = torch.zeros(50, 50)
    kpmask = torch.zeros(1, 50)
    embedding = RE(a, amask, kpmask)
    embedding = torch.squeeze(embedding)
    print("embedding: ", embedding.shape)
    optimizer = torch.optim.Adam(list(RD.parameters()) + list(RE.parameters()), lr=0.1)
    ### decoder step 0
    print("decoder step 0: ")
    from torch.distributions import Categorical
    visited = torch.zeros([50]).to(torch.bool)
    u_prob = RD.getu(embedding, visited)
    # print(u_prob.shape)
    u_prob = torch.softmax(u_prob, dim=0)
    # print(u_prob.shape)
    catu = Categorical(u_prob)
    u = catu.sample()
    print(u)
    state = RD.update(embedding, u)
    print(state)
    ### decoder step 1
    print("decoder step 1: ")
    u_prob = RD.getu(embedding, visited)
    u_prob = torch.softmax(u_prob, dim=0)
    catu = Categorical(u_prob)
    u = catu.sample()
    print(u)
    w_prob = RD.getw(embedding, u, visited)
    w_prob = torch.softmax(w_prob, dim=0)
    print(w_prob.shape)
    catw = Categorical(w_prob)
    w = catw.sample()
    print(w)
    if w < 50:
        v = u
        h = w
    else:
        w = w - 50
        v = w
        h = u
    state = RD.update(embedding, u, w, v, h)
    print(state)
    ### decoder step 2
    print("decoder step 2: ")
    u_prob = RD.getu(embedding, visited)
    u_prob = torch.softmax(u_prob, dim=0)
    catu = Categorical(u_prob)
    u = catu.sample()
    print(u)
    w_prob = RD.getw(embedding, u, visited)
    w_prob = torch.softmax(w_prob, dim=0)
    print(w_prob.shape)
    catw = Categorical(w_prob)
    w = catw.sample()
    print(w)
    if w < 50:
        v = u
        h = w
    else:
        w = w - 50
        v = w
        h = u
    state = RD.update(embedding, u, w, v, h)
    print(state)
    reward = -30
    loss = -u_prob[u] * reward
    loss.backward()