Commit 708dde03 by lvzhengyang

Initial commit

# Reimplementation for REST
DAC'21: REST: Constructing Rectilinear Steiner Minimum Tree via Reinforcement Learning
* this code mainly implements the network part (the Actor/Critic models); a minimal usage sketch is given below
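A minimal usage sketch (illustrative only: the decoder inputs below are zero placeholders and the masks are empty, rather than coming from a real construction rollout):

import torch
from model import Encoder, Decoder, Critic, D

batch_size, num_nodes = 4, 8
nodes = torch.rand(batch_size, num_nodes, 2)    # normalized point coordinates

encoder, decoder, critic = Encoder(), Decoder(), Critic()
e = encoder(nodes)                              # [batch, num_nodes, D]

zero_emb = torch.zeros(batch_size, D)           # placeholder last u/w/v/h embeddings
zero_subtree = torch.zeros(batch_size, 360)     # subtree(t=0) = 0
no_mask = torch.zeros(batch_size, num_nodes, dtype=torch.bool)

u, w, s = decoder(e, zero_emb, zero_emb, zero_emb, zero_emb,
                  zero_subtree, no_mask, no_mask)
b = critic(nodes)                               # baseline value, [batch]
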
from model import Encoder, Decoder, D
import torch
import torch.nn as nn
# TODO: finish the mask
# TODO: make the vectorized version
class Actor(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.encoder = Encoder(dim=dim_e, N=3)
        self.decoder = Decoder(dim_e, dim_q)

    def forward(self, nodes, last_u, last_w, last_v, last_h,
                last_subtree, mask_visited, mask_unvisited):
        """
        @param nodes: [#batch, #num_nodes, 2] in dtype float
        @param last_u/w/v/h: embeddings of the last u/w/v/h, [#batch, D]
        @param last_subtree: embedding of the last subtree, [#batch, 360]
        @return probabilities of u, w, s (see Decoder.forward)
        """
        e = self.encoder(nodes)
        u, w, s = self.decoder(e, last_u, last_w, last_v, last_h,
                               last_subtree, mask_visited, mask_unvisited)
        return u, w, s


class Critic(nn.Module):
    def __init__(self, dim_e=D, dim_c=256) -> None:
        super().__init__()
        self.dim_e = dim_e
        self.dim_c = dim_c
        self.encoder = Encoder(dim=dim_e, N=3)
        self.g = nn.Linear(dim_e, 1, bias=False)
        self.final = nn.Sequential(
            nn.Linear(dim_e, dim_c, bias=True),
            nn.ReLU(),
            nn.Linear(dim_c, 1, bias=True)
        )

    def forward(self, nodes):
        """
        @param nodes: [#batch, #num_nodes, 2] in dtype float
        @return baseline value b: [#batch]
        """
        e = self.encoder(nodes)  # [#batch_size, #num_nodes, D]
        # attention "glimpse" over the node embeddings
        glimpse_w = self.g(torch.tanh(e)).squeeze(-1)              # [#batch_size, #num_nodes]
        glimpse_w = torch.softmax(glimpse_w, dim=1).unsqueeze(1)   # [#batch_size, 1, #num_nodes]
        glimpse = torch.bmm(glimpse_w, e).squeeze(1)               # [#batch_size, D]
        b = self.final(glimpse).squeeze(-1)                        # [#batch_size]
        return b
"""
@brief: Network model of Actor and Critic
@author: Zhengyang Lyu
@date: 2022.8.30
"""
import torch
import torch.nn as nn
import pdb
D = 128
class EncoderLayer(nn.Module):
    def __init__(self, num_heads=16, dim_per_head=16, dim=D, h_dim=512) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.dim_per_head = dim_per_head  # currently unused
        self.dim = dim
        self.h_dim = h_dim
        # note: nn.MultiheadAttention applies its own input projections on top of these
        self.query = nn.Linear(self.dim, self.dim)
        self.key = nn.Linear(self.dim, self.dim)
        self.value = nn.Linear(self.dim, self.dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=self.dim,
            num_heads=self.num_heads,
            batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(self.dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.dim)
        )
        # separate norms for the attention and feed-forward residual branches
        self.norm1 = nn.BatchNorm1d(num_features=self.dim)
        self.norm2 = nn.BatchNorm1d(num_features=self.dim)

    def forward(self, x):
        """
        @param x: [#batch_size, #num_nodes, dim=128]
        """
        batch_size = x.shape[0]
        num_nodes = x.shape[1]
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        x1, _ = self.attention(query=q, key=k, value=v, need_weights=False)
        # residual connections; avoid in-place add_, which breaks autograd here
        x = x + x1
        x = self.norm1(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
        x1 = self.feed_forward(x)
        x = x + x1
        x = self.norm2(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
        return x
class Encoder(nn.Module):
    def __init__(self, dim=D, N=3) -> None:
        super().__init__()
        self.dim = dim
        self.N = N
        self.W_emb = nn.Linear(2, self.dim, bias=False)  # get E_0
        # TODO: implement batch norm
        self.encoder_layers = nn.Sequential()
        for i in range(self.N):
            self.encoder_layers.add_module('{}_{}'.format(EncoderLayer.__name__, i),
                                           EncoderLayer(num_heads=16))
        self.norm = nn.BatchNorm1d(num_features=self.dim)

    def forward(self, x):
        """
        @param x: [#batch, #num_nodes, 2] in dtype float
        """
        batch_size = x.shape[0]
        num_nodes = x.shape[1]
        x = self.W_emb(x)
        x = self.norm(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
        x = self.encoder_layers(x)
        return x
class PTM(nn.Module):
    # PTM: PoinTing Mechanism
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.W_e = nn.Linear(dim_e, dim_q, bias=False)
        self.W_q = nn.Linear(dim_q, dim_q, bias=False)
        self.W_g = nn.Sequential(
            nn.Tanh(),
            nn.Linear(dim_q, 1, bias=False),
        )
        self.C = 10.0

    def forward(self, e, q, mask=None, need_l=False):
        """
        @param e: [#num_batch, #num_nodes, D]
        @param q: [#num_batch, 360]
        @param mask: [#num_batch, #num_nodes], with masked points set to True and the rest set to False
                     (set visited points to True)
        @return [#num_batch, #num_nodes], [#num_batch, #num_nodes]
        """
        e = self.W_e(e)   # [#num_batch, #num_nodes, 360]
        q = self.W_q(q)   # [#num_batch, 360]
        # for each batch, compute e + q (q is broadcast over #num_nodes)
        e = e + q.unsqueeze(1)
        e = self.W_g(e).squeeze(-1)  # logits l, [#num_batch, #num_nodes]
        l = e.clone()
        # clip the logits first, then mask, so masked (visited) points get exactly zero probability
        e = torch.tanh(e) * self.C
        if mask is not None:
            e = e.masked_fill(mask, float('-inf'))
        e = nn.functional.softmax(e, dim=1)
        if need_l:
            return e, l
        return e, None
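
# A small, optional sanity-check helper (illustrative only, not referenced elsewhere):
# it assumes the shapes documented above and checks that masked points get zero probability.
def _ptm_mask_demo():
    ptm = PTM()
    e = torch.randn(2, 5, D)              # 2 batches, 5 nodes
    q = torch.zeros(2, 360)
    mask = torch.zeros(2, 5, dtype=torch.bool)
    mask[:, 0] = True                     # pretend node 0 is already visited
    p, _ = ptm(e, q, mask=mask)
    assert torch.all(p[:, 0] == 0.0)      # the visited node can never be picked
    return p
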
class EdgeGen(nn.Module):
    """
    @brief Generate the query for the t-th action
    """
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.W_u = nn.Linear(self.dim_e, self.dim_q, bias=False)
        self.W_w = nn.Linear(self.dim_e, self.dim_q, bias=False)
        self.W_v = nn.Linear(self.dim_e, self.dim_q, bias=False)
        self.W_h = nn.Linear(self.dim_e, self.dim_q, bias=False)

    def forward(self, u, w, v, h):
        """
        @param u, w, v, h: [#num_batch, D]
        """
        edge = self.W_u(u) + self.W_w(w) + self.W_v(v) + self.W_h(h)
        return edge


class SubTreeGen(nn.Module):
    def __init__(self, dim_q=360) -> None:
        super().__init__()
        self.dim_q = dim_q
        self.W_edge = nn.Linear(self.dim_q, self.dim_q, bias=False)

    def forward(self, last_subtree, cur_edge):
        # element-wise max keeps a running summary of the subtree built so far
        cur_subtree = torch.max(last_subtree, self.W_edge(cur_edge))
        return cur_subtree


class QGen(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.W_5 = nn.Linear(dim_e, dim_q, bias=False)

    def forward(self, last_edge, last_subtree, cur_u=None):
        """
        @param last_edge, last_subtree: [#num_batch, 360]
        @param cur_u: embedding of the node just chosen as u, [#num_batch, D] (optional)
        """
        # element-wise max(0, .); torch.max does not accept a scalar, so use relu
        if cur_u is None:
            cur_q = torch.relu(last_edge + last_subtree)
        else:
            cur_q = torch.relu(last_edge + last_subtree + self.W_5(cur_u))
        return cur_q
class EPTM(nn.Module):
    # EPTM: Extended PoinTing Mechanism
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
        self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
        self.C = 10.0

    def forward(self, e, q, mask=None):
        """
        @param e: [#num_batch, #num_nodes, D]
        @param q: [#num_batch, 360]
        @param mask: [#num_batch, #num_nodes], with masked points set to True and the rest set to False
                     (set unvisited points to True)
        @return p4w: [#num_batch, 2, #num_nodes], p4s: [#num_batch, #num_nodes, 2]
        @note p4w[i, 0] is for batch i and s = 0
        """
        _, l_0 = self.PTM_0(e, q, need_l=True)   # [#num_batch, #num_nodes]
        _, l_1 = self.PTM_1(e, q, need_l=True)   # [#num_batch, #num_nodes]
        l = torch.stack([l_0, l_1], dim=1)       # [#num_batch, 2, #num_nodes]
        l = torch.tanh(l) * self.C               # clip the logits
        # probability of s per node: the mask is identical for s = 0/1, so it is not
        # applied here (applying it would make masked rows all -inf and produce NaNs)
        p4s = torch.softmax(l.transpose(1, -1), dim=-1)  # [#num_batch, #num_nodes, 2]
        if mask is not None:
            if mask.dim() == 2:
                mask = torch.stack([mask, mask], dim=1)
            # clip first, then mask, so masked points get exactly zero probability as w
            l = l.masked_fill(mask, float('-inf'))
        p4w = torch.softmax(l, dim=-1)  # probability of w, [#num_batch, 2, #num_nodes]
        # p4s[i, j, 0]: in batch i, if node j is selected as w, the probability that s = 0
        return p4w, p4s
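
# A small, optional helper (illustrative only, not referenced elsewhere): once a node has
# been chosen as w (here greedily from p4w with s fixed to 0), p4s gives the probability
# of each s value for that node, per the note above.
def _eptm_readout_demo(p4w, p4s):
    batch = p4w.shape[0]
    w_idx = p4w[:, 0].argmax(dim=1)            # greedy w for s = 0, [#num_batch]
    s_prob = p4s[torch.arange(batch), w_idx]   # [#num_batch, 2]
    return w_idx, s_prob
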
class Decoder(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super().__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.edge_gen = EdgeGen(dim_e, dim_q)
        self.subtree_gen = SubTreeGen(dim_q)  # not used in forward; presumably updated by the caller between steps
        self.ptm = PTM(dim_e, dim_q)
        self.eptm = EPTM(dim_e, dim_q)
        self.q_gen = QGen(dim_e, dim_q)
        # TODO: finish the mask

    def forward(self, e, last_u, last_w, last_v, last_h,
                last_subtree, mask_visited=None, mask_unvisited=None):
        """
        @param e: input embeddings for all nodes, [#num_batch, #num_nodes, D]
        @param last_u/w/v/h: embeddings for the last u/w/v/h, [#num_batch, D]
        @param last_subtree: embedding for the last subtree, [#num_batch, 360]
        @note define subtree(t=0) = 0
        @return probabilities of u, w, s
                u: [#num_batch, #num_nodes]
                w: [#num_batch, 2, #num_nodes]
                s: [#num_batch, #num_nodes, 2]
        """
        last_edge = self.edge_gen(last_u, last_w, last_v, last_h)
        cur_q4u = self.q_gen(last_edge, last_subtree)
        u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
        # QGen expects the embedding of the chosen u; as a placeholder, take the greedy
        # (argmax) choice here -- the sampled index could be fed back in instead
        u_idx = u.argmax(dim=1)                                         # [#num_batch]
        u_emb = e[torch.arange(e.shape[0], device=e.device), u_idx]     # [#num_batch, D]
        cur_q4w = self.q_gen(last_edge, last_subtree, u_emb)
        w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
        return u, w, s
import torch
import pdb
from model import Encoder, PTM, EPTM, Critic

encoder = Encoder()
batch_size = 4
num_nodes = 8
x = torch.randn(batch_size, num_nodes, 2)
e = encoder(x)

ptm_0 = PTM()  # for the starting point
q = torch.zeros(batch_size, 360)  # query for the starting point
mask = torch.randint(0, 2, (batch_size, num_nodes)).bool()
mask[:, 0] = False  # keep at least one selectable point per batch
output, _ = ptm_0(e, q, mask)
start_points = output.argmax(dim=1)  # greedy choice of the starting point

eptm = EPTM()
p4w, p4s = eptm(e, q, mask)

critic = Critic()
b = critic(x)
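
# Additional smoke test (illustrative, not part of the original script): one decoding step.
# Zero embeddings stand in for the last u/w/v/h and subtree; the random mask above is reused
# for the visited points and its complement for the unvisited ones.
from model import Decoder, D
decoder = Decoder()
zero_emb = torch.zeros(batch_size, D)
zero_subtree = torch.zeros(batch_size, 360)
u, w, s = decoder(e, zero_emb, zero_emb, zero_emb, zero_emb, zero_subtree,
                  mask_visited=mask, mask_unvisited=~mask)
print(u.shape, w.shape, s.shape)  # expect [4, 8], [4, 2, 8], [4, 8, 2]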
pdb.set_trace()