Commit 312301fd by Yijun Tan

Garbage code: forgot to reset the intermediate state, spent a whole day debugging.

parent c01b9271
import numpy as np
import matplotlib.pyplot as plt
from itertools import count
from collections import namedtuple
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from env2 import *
from log import *
from model22 import *
env = Grid()
learning_rate = 0.01
gamma = 0.99
episodes = 20000
render = True
eps = np.finfo(np.float32).eps.item()
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
n = 5 # number of nodes to route (Grid() defaults to n=5)
LOG = 1
class Policy(nn.Module):
def __init__(self):
super(Policy, self).__init__()
self.fc1 = nn.Linear(2*n, 32)
#debug(self.fc1.weight)
self.actor_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
#self.critic_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
self.actor_decoder = ResDecoder(128, 360)
'''
self.actoru = nn.Sequential(
nn.Linear(32, n)
)
self.actorw = nn.Sequential(
nn.Linear(32, 2*n)
)
'''
self.critic_embedding = ResEncoder(10, 2, 128, 16, 512, 3)
self.g = nn.Linear(128, 1, bias=False)
self.critic = nn.Sequential(
nn.Linear(128, 256, bias=True),
nn.ReLU(),
nn.Linear(256, 1, bias=True)
)
self.save_actions = []
self.rewards = []
def get_actor_embedding(self, x):
#debug("embedding: ", x)
#x = self.fc1(x)
#debug("embedding: ", self.fc1.weight)
#x = F.relu(x)
#action_score = self.actoru(x)
#state_value = self.critic(x)
#debug("embedding: ", x)
am = torch.zeros(n, n)
kpm = torch.zeros(1, n)
actor_emb = self.actor_embedding(x, am, kpm)
self.actor_emb = actor_emb
return actor_emb
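# get_critic: attention "glimpse" baseline. Each node embedding gets a scalar score
# (linear layer on tanh), the softmax-normalized scores weight a sum over the embeddings
# (bmm), and a small MLP maps the pooled vector to the scalar state value.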
def get_critic(self, x):
am = torch.zeros(n, n)
kpm = torch.zeros(1, n)
#debug(self.critic_embedding)
critic_emb = self.critic_embedding(x, am, kpm)
debug(critic_emb.shape)
glimpse = self.g(torch.tanh(critic_emb)).view(1, n)
debug(glimpse.shape)
glimpse = torch.softmax(glimpse, dim=1).unsqueeze(1)
glimpse = torch.bmm(glimpse, critic_emb)
debug(glimpse.shape)
x = self.critic(glimpse)
return x
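# getu/getw: pointer-style action heads. Positions where the mask is True receive a
# -1e6 logit before the softmax, so their probability is effectively zero.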
def getu(self, visited):
#x = nn.ReLU(self.fc1(x))
#debug("get u: ", x)
#x = self.actoru(x)
#debug("get u: ", x)
#debug("get u: ", x)
x = self.actor_decoder.getu(self.actor_emb, visited)
u_prob = torch.where(visited, -1e6 * torch.ones_like(x), x)
return F.softmax(u_prob, dim=-1)
def getw(self, u, unvisited):
#x = nn.ReLU(self.fc1(x))
#x = self.actorw(x)
debug("getting w")
debug(u)
x = self.actor_decoder.getw(self.actor_emb, u, unvisited)
debug(x.shape, x)
debug(unvisited)
w_prob = torch.where(unvisited, -1e6 * torch.ones_like(x), x)
return F.softmax(w_prob, dim=-1)
def update_decoder(self, u, w=-1, v=-1, h=-1):
self.actor_decoder.update(self.actor_emb, u, w, v, h)
model = Policy()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for i in model.parameters():
debug(i.shape)
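# select_action: decodes a routing for the current episode. The decoder state is reset,
# then n-1 edges are sampled: first a node u from the pointer network, then a 2n-way
# choice w whose half encodes how the bend nodes (v, h) are assigned. The log-probability
# of each w pick and the critic's state value are saved for the actor-critic update, and
# the list of [u, w, v, h] actions is returned.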
def select_action(state):
#debug(state)
actions = []
points = torch.tensor(state[0])/10
points = points.view(1, n, 2)
#points = torch.flatten(points)
visited = torch.tensor(state[1]).to(torch.bool)
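# NOTE: debug override: the visited mask computed from the environment state above is
# discarded and replaced by a hard-coded mask on the next line.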
visited = torch.logical_not(torch.tensor([True, False, False, False, False]))
debug(points)
#debug(visited, unvisited)
points = points.float()
model.actor_decoder.reset_state()
for i in range(n-1):
#embedding = model(points)
#model.get_actor_embedding(points)
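# NOTE: debug override: the real transformer embedding (commented out above) is replaced
# with random noise, so the actor currently ignores the point coordinates.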
model.actor_emb = torch.rand([1, 5, 128])
#select u
#debug(points)
probs = model.getu(visited)
p = Categorical(probs)
u = p.sample()
#debug("u: ", u, probs, visited)
#debug()
visited[u] = False
#debug("u: ", u)
#select w
unvisited = torch.logical_not(visited)
unvisited = torch.cat((unvisited, unvisited), dim=0)
probs = model.getw(u, unvisited)
p = Categorical(probs)
w0 = p.sample()
#debug("w: ", w, probs, unvisited)
#debug()
if w0 < n:
v = u
h = w0
w = w0
else:
w = w0-n
v = w
h = u
visited[w] = False
if i < n-2:
model.update_decoder(u, w, v, h)
actions.append([u, w, v, h])
debug(points)
state_value = model.get_critic(points)
state_value = state_value.view(1)
debug(state_value)
model.save_actions.append(SavedAction(p.log_prob(w0), state_value))
#debug(actions)
return actions
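# finish_episode: advantage actor-critic update. Discounted returns R are accumulated
# backwards over the episode rewards; each policy term is weighted by the advantage
# R - V(s), the critic is fit to R with a smooth L1 loss, and the summed loss is
# backpropagated in a single optimizer step.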
def finish_episode():
R = 0
save_actions = model.save_actions
policy_loss = []
value_loss = []
rewards = []
for r in model.rewards[::-1]:
R = r + gamma * R
rewards.insert(0, R)
rewards = torch.tensor(rewards)
#rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
debug()
debug("training")
debug(save_actions)
debug(rewards)
for (log_prob, value), r in zip(save_actions, rewards):
reward = r - value.view(-1).item()
policy_loss.append(-log_prob * reward)
value_loss.append(F.smooth_l1_loss(value, torch.tensor([r])))
#debug(policy_loss)
#debug(value_loss)
optimizer.zero_grad()
loss = torch.stack(policy_loss).sum() + torch.stack(value_loss).sum()
#debug(loss)
loss.backward(retain_graph=True)
optimizer.step()
#optimizer.zero_grad()
#debug("policy_loss", policy_loss)
#loss = policy_loss[0]
#loss.backward(retain_graph=True)
#optimizer.step()
del model.rewards[:]
del model.save_actions[:]
def main():
for i_episode in range(episodes):
state = env.reset()
debug(state)
actions = select_action(state)
debug("actions: ", actions)
state, reward, done = env.multistep(actions)
model.rewards.append(reward)
#env.render()
finish_episode()
if __name__ == '__main__':
main()
\ No newline at end of file
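env2.py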
import torch
import random
import numpy as np
import os
from log import *
HEIGHT, WIDTH = [8,8]
OBSTACLE = 1000
BLANK=0
LINE=1
NODE=2
OBSTACLE=3
class Interval:
def __init__(self, low, high):
self.low = min(low,high)
self.high = max(low, high)
def intersect(self, inter):
return max(self.low, inter.low) < min(self.high, inter.high)
def merge(self, inter):
assert self.intersect(inter)
self.low = min(self.low, inter.low)
self.high = max(self.high, inter.high)
def log(a, k=0):
print("\033[3", k, "m", a, "\033[0m", end="", sep="")
class Point:
def __init__(self, x=-1, y=-1):
self.x = x
self.y = y
class Grid:
def __init__(self, h=10, w=10, n=5):
self.h = h
self.w = w
self.n = n
self.grid = [[BLANK]*self.w for i in range(self.h)]
self.hlines = [[] for _ in range(h)] # avoid [[]] * h, which aliases one shared list
self.wlines = [[] for _ in range(w)]
self.steps = 0
random.seed(1)
#self.generate_nodes()
def reset(self):
self.steps = 0
self.grid = [[BLANK]*self.w for i in range(self.h)]
self.hlines = [[] for _ in range(self.h)]
self.wlines = [[] for _ in range(self.w)]
self.generate_nodes()
state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
#print(state)
return state
def generate_nodes(self):
random.seed(1)
self.hashnodes = set()
while len(self.hashnodes) < self.n:
x = random.randint(0, self.h*self.w - 1)
self.hashnodes.add(x)
#print(x)
self.hashnodes = [32, 97, 72, 17, 8] # hard-coded node set; overrides the random sample generated above
self.nodes = []
for i in self.hashnodes:
self.nodes.append(Point(i // self.w, i % self.w))
#print(self.nodes[-1].x, self.nodes[-1].y)
self.grid[self.nodes[-1].x][self.nodes[-1].y] = NODE
self.nodes_used = [0] * self.n
def occupy(self, x, y):
if self.grid[x][y] == BLANK:
self.grid[x][y] = LINE
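# connect: routes an L-shaped (single-bend) rectilinear wire between nodes `start` and
# `end`. `dir` picks which of the two possible bend corners is used; the spanned
# row/column intervals are recorded in hlines/wlines and the crossed cells are marked
# via occupy().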
def connect(self, start, end, dir):
#assert self.nodes_used[start] == 1
#assert self.nodes_used[end] == 0
self.nodes_used[start] = 1
self.nodes_used[end] = 1
n0 = self.nodes[start]
n1 = self.nodes[end]
inter0 = Interval(n0.x, n1.x)
inter1 = Interval(n0.y, n1.y)
if dir:
#mid = Point(n0.x, n1.y)
self.hlines[n0.y].append(inter0) ######################### merge #############
self.wlines[n1.x].append(inter1)
for i in range(inter0.low, inter0.high+1, 1):
self.occupy(i, n0.y)
for i in range(inter1.low, inter1.high+1, 1):
self.occupy(n1.x, i)
else:
self.hlines[n1.y].append(inter0) ######################### merge #############
self.wlines[n0.x].append(inter1)
for i in range(inter0.low, inter0.high+1, 1):
self.occupy(i, n1.y)
for i in range(inter1.low, inter1.high+1, 1):
self.occupy(n0.x, i)
def wire_length(self):
res = 0
for i in range(self.h):
for j in range(self.w):
res += (self.grid[i][j] != BLANK)
return res-1
def finish(self):
return sum(self.nodes_used) == self.n
def step(self, start, end, dir):
self.connect(start, end, dir)
if self.finish():
reward = -self.wire_length()
done = 1
else:
reward = 0
done = 0
state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
return state, reward, done
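# multistep: applies a whole list of edge actions at once and returns the resulting
# state, the negative total wire length as reward, and done=1.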
def multistep(self, actions):
for i in actions:
dir = int(i[0] == i[1])
self.connect(i[0], i[1], dir)
reward= -self.wire_length()
state = [[[p.x, p.y] for p in self.nodes], self.nodes_used]
return state, reward, 1
def render(self):
#os.system("cls")
print()
for i in range(self.h):
log(str(self.h-i-1)+" ", 4)
for j in range(self.w):
log("o ", self.grid[self.h-i-1][j])
print()
print()
log(" ")
for i in range(self.w):
log(str(i)+" ", 4)
print()
for i in range(self.n):
print(i, "(", self.nodes[i].x, self.nodes[i].y, "): ", self.nodes_used[i])
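# NOTE: the finish() and step() below redefine the earlier versions in this class;
# only these later definitions take effect.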
def finish(self):
#assert sum(self.nodes_used) == self.steps+1
return self.steps+1 == self.n
def step(self, start, end, dir):
self.steps += 1
self.connect(start, end, dir)
if self.finish():
reward = -self.wire_length()
done = 1
else:
reward = 0
done = 0
return np.array(self.nodes_used), reward, done, None
if __name__ == "__main__":
env = Grid()
print(env.reset())
#_, reward, done = env.step(0, 1, 1)
#_, reward, done = env.step(1, 2, 0)
#_, reward, done = env.step(2, 3, 0)
#_, reward, done = env.step(3, 4, 0)
#print(reward, done)
actions = [[0, 1, 1, 0],
[1, 2, 1, 2],
[2, 3, 2, 3],
[3, 4, 3, 4],
]
_, reward, done = env.multistep(actions)
env.render()
\ No newline at end of file
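log.py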
INFO = 0
DISPLAY = 1
DEBUG = 2
LOG_LEVEL = 2
def log(a, k=0):
print("\033[3", k, "m", a, "\033[0m", end="", sep="")
def debug(*v):
if LOG_LEVEL >= DEBUG:
print(" ".join(str(x) for x in v))
import builtins as __builtin__
def info(*args, **kwargs):
if LOG_LEVEL >= INFO:
return __builtin__.print(*args, **kwargs)
def display(*args, **kwargs):
if LOG_LEVEL >= DISPLAY:
return __builtin__.print(*args, **kwargs)
def debug(*args, **kwargs):
if LOG_LEVEL >= DEBUG:
return __builtin__.print(*args, **kwargs)
if __name__ == "__main__":
debug("1", "2", "3")
\ No newline at end of file
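model22.py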
from collections import namedtuple
import torch
import torch.nn as nn
from log import *
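# ResEncoder: embeds each 2-D point with a linear layer, then runs a TransformerEncoder
# over the point set; the 1.8 branch drops batch_first/norm_first, which that PyTorch
# version does not support.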
class ResEncoder(nn.Module):
def __init__(self, max_n, gdim=2, edim=128, nheads=16, hdim=512, num_layer=3):
super(ResEncoder, self).__init__()
self.max_n = max_n
self.edim = edim
self.nheads = nheads
self.hdim = hdim
self.num_layer = num_layer
self.fc0 = nn.Linear(gdim, edim)
#self.key_padding_mask = torch.zeros(batch_size, 50)
#self.attn_mask = torch.zeros(50, 50)
if torch.__version__[0:3] == "1.8":
self.encoder_layer = nn.TransformerEncoderLayer(d_model=edim, nhead=nheads, dim_feedforward=hdim, dropout=0)#norm_first, batch_first=True)
else:
self.encoder_layer = nn.TransformerEncoderLayer(d_model=edim, nhead=nheads, dim_feedforward=hdim, dropout=0, norm_first=True, batch_first=True)
self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layer)
def forward(self, x, attn_mask, key_padding_mask):
x = self.fc0(x)
#print(x)
#attn_mask = torch.zeros(self.max_n, self.max_n)
#key_padding_mask = torch.zeros(10, self.50)
x = self.encoder(x, mask=attn_mask, src_key_padding_mask=key_padding_mask)
return x
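# PTM: additive-attention pointer scoring. Embeddings and query are projected to a
# common size, summed, passed through tanh and a final linear layer to one score per
# node; scores are clipped to [-10, 10] via 10 * tanh. Masking and softmax are left to
# the caller.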
class PTM(nn.Module):
def __init__(self, edim, qdim):
super(PTM, self).__init__()
self.edim = edim
self.qdim = qdim
self.efc = nn.Linear(edim, qdim, bias=False)
self.qfc = nn.Linear(qdim, qdim, bias=False)
self.th0 = nn.Tanh()
self.gfc = nn.Linear(qdim, 1, bias=False)
self.th1 = nn.Tanh()
def forward(self, e, q, visited=None):
#print(e.shape, q.shape)
e = self.efc(e)
q = self.qfc(q)
e = self.gfc(self.th0(e+q))
#print(e.shape)
#print("visited: ", visited.shape)
e = torch.squeeze(e)
#if visited != None:
# l = torch.where(visited, -1e6 * torch.ones_like(e), e)
#print(l.shape)
l = e
l = self.th1(l) * 10.0
#l = torch.softmax(l, dim=1)
return l
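# Decoder state carried between steps: `edge` is the projected embedding of the last
# selected node, `subtree` an element-wise running maximum over those projections, and
# eu/ew/ev/eh the indices of the last [u, w, v, h] action.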
State = namedtuple("State", ["edge", "subtree", "eu", "ew", "ev", "eh"])
class ResDecoder(nn.Module):
def __init__(self, edim, qdim):
super(ResDecoder, self).__init__()
self.edim = edim
self.qdim = qdim
self.ptm = PTM(edim, qdim)
self.eptm0 = PTM(edim, qdim)
self.eptm1 = PTM(edim, qdim)
Edge = namedtuple("Edge", ["ufc", "wfc", "vfc", "hfc"])
self.edgefc = Edge(nn.Linear(edim, qdim),
nn.Linear(edim, qdim),
nn.Linear(edim, qdim),
nn.Linear(edim, qdim),
)
self.subtfc = nn.Linear(qdim, qdim)
self.fc_w5 = nn.Linear(edim, qdim)
self.state = []
self.state.append(State(
torch.zeros(self.qdim),
torch.zeros(self.qdim),
-1,
-1,
-1,
-1
))
#self.subtree = torch.zeros(self.qdim)
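# reset_state: clears the step history back to a single zero state. It must be called
# before decoding each new episode, otherwise the decoder keeps conditioning on the
# previous episode's picks (the missing reset referred to in the commit message).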
def reset_state(self):
self.state = []
self.state.append(State(
torch.zeros(self.qdim),
torch.zeros(self.qdim),
-1,
-1,
-1,
-1
))
def getu(self, e, visited, step=0):
qu = torch.max(self.state[-1].edge + self.state[-1].subtree, torch.zeros_like(self.state[-1].edge))
#debug(qu)
u = self.ptm(e, qu, visited)
return u
def getw(self, e, u, visited, step=1):
qw = torch.max(torch.zeros_like(self.state[-1].edge), self.state[-1].edge + self.state[-1].subtree + self.fc_w5(e[:, self.state[-1].eu, :]))
w0 = self.eptm0(e, qw, visited)
w1 = self.eptm1(e, qw, visited)
w = torch.cat([w0, w1]).view(1, -1)
return w
def update(self, e, u, w=-1, v=-1, h=-1):
edg = self.edgefc.ufc(e[:, u, :].clone())
subtr = torch.max(self.state[-1].subtree, self.subtfc(edg))
self.state.append(State(edge=edg, subtree=subtr, eu=u, ew=w, ev=v, eh=h))
return self.state[-1]
if __name__ == "__main__":
### encoder test
print("encoder test")
a = torch.zeros([10, 50, 2])
if torch.__version__[0:3] == "1.8":
a = torch.zeros([50, 10, 2])
print(a.shape)
RE = ResEncoder(10, 2, 128, 16, 512, 3)
amask = torch.zeros(50, 50)
kpmask = torch.zeros(10, 50)
b = RE(a, amask, kpmask)
print(b.shape)
'''
a = torch.ones(2, 2)
b = torch.ones(2, 4)
ptm = PTM(2, 4)
c = ptm(a, b, torch.tensor([0, 1, 0, 1])>0)
print(c)
'''
### decoder test
a = torch.zeros([1, 50, 2])
RE = ResEncoder(10, 2, 16, 16, 16, 1)
RD = ResDecoder(16, 32)
amask = torch.zeros(50, 50)
kpmask = torch.zeros(1, 50)
embedding = RE(a, amask, kpmask)
embedding = torch.squeeze(embedding)
print("embedding: ", embedding.shape)
optimizer = torch.optim.Adam(list(RD.parameters()) + list(RE.parameters()), lr=0.1)
### decoder step 0
print("decoder step 0: ")
from torch.distributions import Categorical
visited = torch.zeros([50]).to(torch.bool)
u_prob = RD.getu(embedding, visited)
#print(u_prob.shape)
u_prob = torch.softmax(u_prob, dim=0)
#print(u_prob.shape)
catu = Categorical(u_prob)
u = catu.sample()
print(u)
state = RD.update(embedding, u)
print(state)
### decoder step 1
print("decoder step 1: ")
u_prob = RD.getu(embedding, visited)
u_prob = torch.softmax(u_prob, dim=0)
catu = Categorical(u_prob)
u = catu.sample()
print(u)
w_prob = RD.getw(embedding, u, visited)
w_prob = torch.softmax(w_prob, dim=0)
print(w_prob.shape)
catw = Categorical(w_prob)
w = catw.sample()
print(w)
if w < 50:
v = u
h = w
else:
w = w-50
v = w
h = u
state = RD.update(embedding, u, w, v, h)
print(state)
### decoder step 2
print("decoder step 2: ")
u_prob = RD.getu(embedding, visited)
u_prob = torch.softmax(u_prob, dim=0)
catu = Categorical(u_prob)
u = catu.sample()
print(u)
w_prob = RD.getw(embedding, u, visited)
w_prob = torch.softmax(w_prob, dim=0)
print(w_prob.shape)
catw = Categorical(w_prob)
w = catw.sample()
print(w)
if w < 50:
v = u
h = w
else:
w = w-50
v = w
h = u
state = RD.update(embedding, u, w, v, h)
print(state)
reward = -30
loss = -u_prob[u] * reward
loss.backward()
\ No newline at end of file