Commit 3dc620ab by lvzhengyang

add env framework

parent 708dde03
"""
@brief: agent for RL
@author: Zhengyang Lyu
@date: 2022.9.1
"""
from model import Encoder, Decoder, D
import torch
import torch.nn as nn
# TODO: finish the mask
# leave the mask provided by the environment
# TODO: make the vectorized version
class Actor(nn.Module):
def __init__(self, dim_e=D, dim_q=360) -> None:
......
"""
@brief: environment for RL
@author: Zhengyang Lyu
@date: 2022.9.1
"""
import gym
from gym import spaces
class RSMTEnv(gym.Env):
def __init__(self, num_nodes, pos_l, pos_h, render_mode=None) -> None:
"""
@param pos_l/pos_h: min/max of the positions
"""
super().__init__()
self.render_mode = render_mode
self.num_nodes = num_nodes
self.pos_l = pos_l
self.pos_h = pos_h
self.observation_space = spaces.Dict({
"positions": spaces.Box(pos_l, pos_h, shape=(num_nodes, 2), dtype=float),
"mask_unvisited": spaces.MultiBinary(num_nodes),
"mask_visited": spaces.MultiBinary(num_nodes),
})
self.action_space = spaces.Dict({
"u": spaces.Discrete(num_nodes),
"w": spaces.Discrete(num_nodes),
"s": spaces.Discrete(2),
})
self.mask_unvisited = None
self.mask_visited = None
def _get_obs(self):
pass
def _get_info(self):
pass
def reset(self, seed=None):
super().reset(seed=seed)
pass
def step(self, action):
obs = self._get_obs()
info = self._get_info()
# return obs, reward, done, info
def close(self):
pass
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment