Commit 50ae1f7d by Gaoyunkai

Add goal correction

parent 6ed87ec7
......@@ -8,5 +8,9 @@ The python dependencies are as follows.
* [Gym](https://gym.openai.com/)
* [Mujoco](https://www.roboti.us)
Run the code with ``python train_hier_sac.py``. The TensorBoard files are saved in the ``runs`` folder and the
trained models are saved in the ``saved_models`` folder.
The TensorBoard files are saved in the ``/lustre/S/gaoyunkai/RL/LESSON/runs/hier/`` folder,
and the trained models are saved in the folder given by ``save-dir`` in ``arguments_hier_sac.py``.
Run the original code with ``python train_hier_sac.py``.
Run the variant in which both the high-level and the low-level agent use SAC with ``python train_hier_double_sac.py``.
Run the double-SAC variant with goal correction with ``python train_hier_double_sac_goal_correct.py``.
\ No newline at end of file
......@@ -200,6 +200,23 @@ class GaussianPolicy(nn.Module):
mean = torch.tanh(mean) * self.action_scale + self.action_bias
return action, log_prob, mean
def correct(self, state_candidate, action):
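# Rank the candidate subgoal inputs by the log-likelihood this policy assigns to the low-level actions that were actually executed, and return the index of the best candidate for each batch element.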
candidate_num = state_candidate.shape[1]
mean, log_std = self.forward(state_candidate)
std = log_std.exp()
normal = Normal(mean, std)
x_t = torch.arctanh((action - self.action_bias) / self.action_scale)
x_t = x_t.unsqueeze(1).expand(-1, candidate_num, -1, -1)
# print("x_t", x_t.shape)
log_prob = normal.log_prob(x_t)
# print("log_prob:", log_prob.shape)
log_prob = log_prob.sum(-1).sum(-1)
# print("log_prob:", log_prob.shape)
correct_index = log_prob.argmax(1, keepdim=True)
# print("correct_index:", correct_index.shape)
return correct_index
def to(self, device):
self.action_scale = self.action_scale.to(device)
self.action_bias = self.action_bias.to(device)
......
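The shape flow of ``correct`` is easier to follow with dummy tensors. The sketch below is a standalone illustration, not the repository's code: the sizes and the direct use of ``torch.distributions.Normal`` in place of the policy network are assumptions chosen only to show how the candidate scores are reduced to one index per transition.

```python
import torch
from torch.distributions import Normal

# Illustrative sizes only: 4 transitions, 10 candidate subgoals,
# 5 low-level steps per high-level action, 3-dimensional low-level actions.
batch, candidate_num, c_step, act_dim = 4, 10, 5, 3

# Stand-ins for the mean/log_std the low-level policy would produce for every
# (transition, candidate, step) combination.
mean = torch.randn(batch, candidate_num, c_step, act_dim)
log_std = torch.zeros(batch, candidate_num, c_step, act_dim)

# Pre-tanh low-level actions that were actually executed, broadcast across candidates.
x_t = torch.randn(batch, c_step, act_dim)
x_t = x_t.unsqueeze(1).expand(-1, candidate_num, -1, -1)

log_prob = Normal(mean, log_std.exp()).log_prob(x_t)  # (batch, candidate_num, c_step, act_dim)
log_prob = log_prob.sum(-1).sum(-1)                   # sum over action dims and low-level steps
correct_index = log_prob.argmax(1, keepdim=True)      # (batch, 1): best candidate per transition
print(correct_index.shape)                            # torch.Size([4, 1])
```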
......@@ -2,21 +2,30 @@ import random
import numpy as np
class ReplayMemory:
def __init__(self, capacity):
def __init__(self, capacity, use_goal_correct=False):
self.capacity = capacity
self.buffer = []
self.position = 0
self.use_goal_correct = use_goal_correct
def push(self, state, action, reward, next_state, done, epoch):
def push(self, state, action, reward, next_state, done, epoch, state_c_step=None, low_action=None):
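# state_c_step / low_action: the low-level states and actions collected over the c steps executed under this high-level transition; they are only stored when use_goal_correct is True.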
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done, epoch+1)
if not self.use_goal_correct:
self.buffer[self.position] = (state, action, reward, next_state, done, epoch+1)
else:
assert low_action is not None and state_c_step is not None
self.buffer[self.position] = (state, action, reward, next_state, done, epoch+1, state_c_step, low_action)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
if not self.use_goal_correct:
state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
else:
state, action, reward, next_state, done, _, state_c_step, low_action = map(np.stack, zip(*batch))
return state, action, reward, next_state, done, state_c_step, low_action
def __len__(self):
return len(self.buffer)
......@@ -36,8 +45,12 @@ class ReplayMemory:
p_trajectory = p_trajectory.astype(np.float64)
idxs = np.random.choice(len(self.buffer), size=batch_size, replace=False, p=p_trajectory)
batch = [self.buffer[i] for i in idxs]
state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
if not self.use_goal_correct:
state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
else:
state, action, reward, next_state, done, _, state_c_step, low_action = map(np.stack, zip(*batch))
return state, action, reward, next_state, done, state_c_step, low_action
def random_sample(self, batch_size):
idxs = np.random.randint(0, len(self.buffer), batch_size)
......
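A usage sketch of the extended buffer may help; the import path and the array shapes below are assumptions picked only to show which fields change when ``use_goal_correct`` is enabled.

```python
import numpy as np
from algos.sac.replay_memory import ReplayMemory  # assumed module path

obs_dim, goal_dim, low_obs_dim, low_act_dim, c_step = 30, 3, 33, 8, 10

memory = ReplayMemory(capacity=100, use_goal_correct=True)
for epoch in range(8):
    memory.push(
        state=np.zeros(obs_dim + goal_dim),
        action=np.zeros(goal_dim),                     # high-level action (relative subgoal)
        reward=0.0,
        next_state=np.zeros(obs_dim + goal_dim),
        done=False,
        epoch=epoch,
        state_c_step=np.zeros((c_step, low_obs_dim)),  # low-level states over the c-step window
        low_action=np.zeros((c_step, low_act_dim)),    # low-level actions over the same window
    )

# With use_goal_correct=True, sample() returns the two extra fields as well.
state, action, reward, next_state, done, state_c_step, low_action = memory.sample(batch_size=4)
print(state_c_step.shape, low_action.shape)  # (4, 10, 33) (4, 10, 8)
```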
......@@ -4,10 +4,10 @@ import torch.nn.functional as F
from torch.optim import Adam
from algos.sac.utils import soft_update, hard_update
from algos.sac.model import GaussianPolicy, QNetwork, DeterministicPolicy, QNetwork_phi
import numpy as np
class SAC(object):
def __init__(self, num_inputs, action_space, args, pri_replay, goal_dim, gradient_flow_value, abs_range, tanh_output):
def __init__(self, num_inputs, action_space, args, pri_replay, goal_dim, gradient_flow_value, abs_range, tanh_output, use_goal_correct=False):
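# use_goal_correct: when True, stored high-level actions are relabeled during the critic update with the candidate subgoal that best explains the executed low-level actions.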
self.gamma = args.gamma
self.tau = args.tau
......@@ -20,6 +20,7 @@ class SAC(object):
self.device = args.device
self.gradient_flow_value = gradient_flow_value
self.use_goal_correct = use_goal_correct
if not gradient_flow_value:
self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
......@@ -64,18 +65,41 @@ class SAC(object):
_, _, action = self.policy.sample(state)
return action.detach().cpu().numpy()[0]
def update_parameters(self, memory, batch_size, env_params, hi_sparse, feature_data):
def select_num_action(self, state, num):
# Draw `num` candidate actions from the current policy for every state in the batch.
action_candidates = []
for _ in range(num):
action, _, _ = self.policy.sample(state)
action_candidates.append(action.detach().cpu().numpy())
# stack to (batch, num, action_dim) so candidate i of row b was sampled for state b
return np.stack(action_candidates, axis=1)
def update_parameters(self, memory, batch_size, env_params, hi_sparse, feature_data, low_policy=None, representation=None):
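# low_policy: the low-level GaussianPolicy whose correct() method scores candidate subgoals; representation: maps raw observations to the abstract goal space. Both are only required when use_goal_correct is True.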
# Sample a batch from memory
if self.pri_replay:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.pri_sample(batch_size=batch_size)
if not self.use_goal_correct:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.pri_sample(batch_size=batch_size)
else:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch, low_state_c_step_batch, low_action_batch = memory.pri_sample(batch_size=batch_size)
else:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
if not self.use_goal_correct:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
else:
state_batch, action_batch, reward_batch, next_state_batch, mask_batch, low_state_c_step_batch, low_action_batch = memory.sample(batch_size=batch_size)
state_batch = torch.FloatTensor(state_batch).to(self.device)
next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
action_batch = torch.FloatTensor(action_batch).to(self.device)
reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
if self.use_goal_correct:
low_state_c_step_batch = torch.FloatTensor(low_state_c_step_batch).to(self.device)
low_action_batch = torch.FloatTensor(low_action_batch).to(self.device)
# print("state_batch shape:", state_batch.shape)
# print("action_batch shape:", action_batch.shape)
# print("reward_batch:", reward_batch.shape)
# print("mask_batch:", mask_batch.shape)
# print("low_state_c_step_batch shape:", low_state_c_step_batch.shape)
# print("low_action_batch shape:", low_action_batch.shape)
with torch.no_grad():
next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
......@@ -86,7 +110,32 @@ class SAC(object):
if hi_sparse:
# clip target value
next_q_value = torch.clamp(next_q_value, -env_params['max_timesteps'], 0.)
qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step
if self.use_goal_correct:
hi_action_candidate_num = 10  # 8 freshly sampled actions + the stored action + the policy mean
c_step = low_state_c_step_batch.shape[1]
real_goal_dim = env_params["real_goal_dim"]
with torch.no_grad():
action_batch_candidate = torch.FloatTensor(self.select_num_action(state_batch, hi_action_candidate_num-2)).to(self.device)
# print("action_batch_candidate:", action_batch_candidate.shape)
mean, _ = self.policy(state_batch)
# print("mean:", mean.shape)
action_batch_candidate = torch.cat([action_batch_candidate, action_batch.unsqueeze(1), mean.unsqueeze(1)], dim=1)
# print("action_batch_candidate:", action_batch_candidate.shape)
ag = representation(state_batch[:, :env_params["obs"]]).unsqueeze(1)
# print("ag shape:", ag.shape)
goal_batch_candidate = action_batch_candidate + ag
low_state_batch_candidate = torch.cat([low_state_c_step_batch.unsqueeze(1).expand(-1, hi_action_candidate_num, -1, -1), goal_batch_candidate.unsqueeze(2).expand(-1, -1, c_step, -1)], dim=-1)
# print("low_state_batch_candidate:", low_state_batch_candidate.shape)
goal_correct_index = low_policy.correct(low_state_batch_candidate, low_action_batch)
goal_correct_index = goal_correct_index.expand(-1, hi_action_candidate_num * real_goal_dim).reshape(-1, hi_action_candidate_num, real_goal_dim)
action_batch_correct = torch.gather(action_batch_candidate, 1, goal_correct_index)[:,0,:]
# print("action_batch_candidate:", action_batch_candidate)
# print("action_batch_correct:", action_batch_correct)
# print("action_batch:", action_batch)
qf1, qf2 = self.critic(state_batch, action_batch_correct if self.use_goal_correct else action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
# print("qf1", qf1.shape)
# print("next_q", next_q_value.shape)
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
......
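The index expansion before ``torch.gather`` is the least obvious step, so here is a minimal sketch with made-up sizes (not the repository's tensors) showing how the per-sample ``correct_index`` pulls the chosen candidate out of ``action_batch_candidate``.

```python
import torch

batch, candidate_num, real_goal_dim = 4, 10, 3
action_batch_candidate = torch.randn(batch, candidate_num, real_goal_dim)
correct_index = torch.randint(0, candidate_num, (batch, 1))  # argmax returned by low_policy.correct

# Broadcast each sample's index over every (candidate, goal-dim) slot, gather along
# the candidate dimension, and keep one copy of the selected row.
index = correct_index.expand(-1, candidate_num * real_goal_dim)
index = index.reshape(-1, candidate_num, real_goal_dim)
action_batch_correct = torch.gather(action_batch_candidate, 1, index)[:, 0, :]

# Sanity check: the result is exactly the candidate each index points at.
expected = action_batch_candidate[torch.arange(batch), correct_index.squeeze(1)]
assert torch.allclose(action_batch_correct, expected)
print(action_batch_correct.shape)  # torch.Size([4, 3])
```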
......@@ -18,7 +18,7 @@ def get_args_ant():
parser.add_argument('--seed', type=int, default=125, help='random seed')
parser.add_argument('--replay-strategy', type=str, default='none', help='the HER strategy')
parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models')
parser.add_argument('--save-dir', type=str, default='/lustre/S/gaoyunkai/RL/LESSON/saved_models/', help='the path to save the models')
parser.add_argument('--noise-eps', type=float, default=0.2, help='noise factor for Gaussian')
parser.add_argument('--random-eps', type=float, default=0.2, help="prob for acting randomly")
......
......@@ -10,6 +10,8 @@ elif sys.argv[0].split('/')[-1] == "train_hier_sac.py":
from train_hier_sac import args
elif sys.argv[0].split('/')[-1] == "train_hier_double_sac.py":
from train_hier_double_sac import args
elif sys.argv[0].split('/')[-1] == "train_hier_double_sac_goal_correct.py":
from train_hier_double_sac_goal_correct import args
elif sys.argv[0].split('/')[-1] == "train_hier_ppo.py":
from train_hier_ppo import args
elif sys.argv[0].split('/')[-1] == "train_covering.py":
......
import numpy as np
import gym
from arguments.arguments_hier_sac import get_args_ant, get_args_chain
from algos.hier_double_sac_goal_correct import hier_sac_agent
from goal_env.mujoco import *
import random
import torch
def get_env_params(env):
obs = env.reset()
# close the environment
params = {'obs': obs['observation'].shape[0], 'goal': obs['desired_goal'].shape[0],
'action': env.action_space.shape[0], 'action_max': env.action_space.high[0],
'max_timesteps': env._max_episode_steps}
return params
def launch(args):
# create the ddpg_agent
env = gym.make(args.env_name)
test_env = gym.make(args.test)
# if args.env_name == "AntPush-v1":
# test_env1 = gym.make("AntPushTest1-v1")
# test_env2 = gym.make("AntPushTest2-v1")
# elif args.env_name == "AntMaze1-v1":
# test_env1 = gym.make("AntMaze1Test1-v1")
# test_env2 = gym.make("AntMaze1Test2-v1")
# else:
test_env1 = test_env2 = None
print("test_env", test_env1, test_env2)
# set random seeds for reproducibility
env.seed(args.seed)
if args.env_name != "NChain-v1":
env.env.env.wrapped_env.seed(args.seed)
test_env.env.env.wrapped_env.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.device != 'cpu':
torch.cuda.manual_seed(args.seed)
gym.spaces.prng.seed(args.seed)
# get the environment parameters
if args.env_name[:3] in ["Ant", "Poi", "Swi"]:
env.env.env.visualize_goal = args.animate
test_env.env.env.visualize_goal = args.animate
env_params = get_env_params(env)
env_params['max_test_timesteps'] = test_env._max_episode_steps
# create the ddpg agent to interact with the environment
sac_trainer = hier_sac_agent(args, env, env_params, test_env, test_env1, test_env2)
if args.eval:
if not args.resume:
print("random policy !!!")
# sac_trainer._eval_hier_agent(test_env)
# sac_trainer.vis_hier_policy()
# sac_trainer.cal_slow()
# sac_trainer.visualize_representation(100)
# sac_trainer.vis_learning_process()
# sac_trainer.picvideo('fig/final/', (1920, 1080))
else:
sac_trainer.learn()
# get the params
args = get_args_ant()
# args = get_args_chain()
# args = get_args_fetch()
# args = get_args_point()
if __name__ == '__main__':
launch(args)