Commit 6ed87ec7 by Gaoyunkai

add double_sac

parent 2267b85e
......@@ -123,3 +123,8 @@ venv.bak/
*.avi
.idea/
runs/
#slurm
*.err
*.out
*.slurm
......@@ -129,7 +129,7 @@ class hier_sac_agent:
if args.save:
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
self.log_dir = 'runs/hier/' + str(args.env_name) + '/RB_Decay_' + current_time + \
self.log_dir = '/lustre/S/gaoyunkai/RL/LESSON/runs/hier/' + str(args.env_name) + '/RB_Decay_' + current_time + \
"_C_" + str(args.c) + "_Image_" + str(args.image) + \
"_Seed_" + str(args.seed) + "_Reward_" + str(args.low_reward_coeff) + \
"_NoPhi_" + str(self.not_update_phi) + "_LearnG_" + str(self.learn_goal_space) + "_Early_" + str(self.early_stop_thres) + str(args.early_stop)
......
......@@ -47,6 +47,10 @@ class ReplayMemory:
obs_next = np.array(obs_next)
return obs, obs_next
def clear(self):
self.buffer = []
self.position = 0
class Array_ReplayMemory:
def __init__(self, capacity, env_params):
self.capacity = capacity
......
......@@ -92,12 +92,6 @@ class SAC(object):
qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
pi, log_pi, _ = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
if feature_data is not None:
if self.gradient_flow_value:
obs, obs_next = self.critic.phi(feature_data[0]), self.critic.phi(feature_data[1])
......@@ -106,26 +100,33 @@ class SAC(object):
max_dist = torch.clamp(1 - (hi_obs - hi_obs_next).pow(2).mean(dim=1), min=0.)
representation_loss = (min_dist + max_dist).mean()
qf1_loss = qf1_loss * 0.1 + representation_loss
else:
qf_loss = qf1_loss + qf2_loss
self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
pi, log_pi, _ = self.policy.sample(state_batch)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
# print("log_pi:", log_pi)
# print("min_qf_pi:", min_qf_pi)
# print("policy_loss:", policy_loss)
if feature_data is not None:
if not self.gradient_flow_value:
obs, obs_next = self.policy.phi(feature_data[0]), self.policy.phi(feature_data[1])
min_dist = torch.clamp((obs - obs_next).pow(2).mean(dim=1), min=0.)
hi_obs, hi_obs_next = self.policy.phi(feature_data[2]), self.policy.phi(feature_data[3])
max_dist = torch.clamp(1 - (hi_obs - hi_obs_next).pow(2).mean(dim=1), min=0.)
representation_loss = (min_dist + max_dist).mean()
policy_loss += representation_loss
self.critic_optim.zero_grad()
qf1_loss.backward()
self.critic_optim.step()
self.critic_optim.zero_grad()
qf2_loss.backward()
self.critic_optim.step()
policy_loss = policy_loss + representation_loss
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()
#print("policy_loss:", policy_loss)
if self.automatic_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
......
......@@ -35,7 +35,7 @@ def get_args_ant():
parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of tests')
parser.add_argument('--metric', type=str, default='MLP', help='the metric for the distance embedding')
parser.add_argument('--device', type=str, default="cuda:3", help='cuda device')
parser.add_argument('--device', type=str, default="cuda:0", help='cuda device')
parser.add_argument('--lr-decay-actor', type=int, default=3000, help='actor learning rate decay')
parser.add_argument('--lr-decay-critic', type=int, default=3000, help='critic learning rate decay')
......@@ -101,7 +101,7 @@ def get_args_chain():
parser.add_argument('--seed', type=int, default=160, help='random seed')
parser.add_argument('--replay-strategy', type=str, default='none', help='the HER strategy')
parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models')
parser.add_argument('--save-dir', type=str, default='/lustre/S/gaoyunkai/RL/LESSON/saved_models/', help='the path to save the models')
parser.add_argument('--noise-eps', type=float, default=0.2, help='noise factor for Gaussian')
parser.add_argument('--random-eps', type=float, default=0.2, help="prob for acting randomly")
......@@ -118,7 +118,7 @@ def get_args_chain():
parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of tests')
parser.add_argument('--metric', type=str, default='MLP', help='the metric for the distance embedding')
parser.add_argument('--device', type=str, default="cuda:8", help='cuda device')
parser.add_argument('--device', type=str, default="cuda:0", help='cuda device')
parser.add_argument('--lr-decay-actor', type=int, default=3000, help='actor learning rate decay')
parser.add_argument('--lr-decay-critic', type=int, default=3000, help='critic learning rate decay')
......@@ -147,6 +147,8 @@ def get_args_chain():
parser.add_argument("--use_prediction", type=bool, default=False, help='use prediction error to learn feature')
parser.add_argument("--start_update_phi", type=int, default=2, help='use prediction error to learn feature')
parser.add_argument("--image", type=bool, default=False, help='use image input')
parser.add_argument("--old_sample", type=bool, default=False, help='sample the absolute goal in the abs_range')
# args of sac (high-level learning)
parser.add_argument('--policy', default="Gaussian",
......
......@@ -8,6 +8,8 @@ elif sys.argv[0].split('/')[-1] == "train_hier_ddpg.py":
from train_hier_ddpg import args
elif sys.argv[0].split('/')[-1] == "train_hier_sac.py":
from train_hier_sac import args
elif sys.argv[0].split('/')[-1] == "train_hier_double_sac.py":
from train_hier_double_sac import args
elif sys.argv[0].split('/')[-1] == "train_hier_ppo.py":
from train_hier_ppo import args
elif sys.argv[0].split('/')[-1] == "train_covering.py":
......
......@@ -152,7 +152,7 @@ class Critic_double(nn.Module):
def __init__(self, env_params, args):
super(Critic_double, self).__init__()
self.max_action = env_params['action_max']
self.inp_dim = env_params['obs'] + env_params['action'] + env_params['goal']
self.inp_dim = env_params['obs'] + env_params['action'] + env_params['real_goal_dim']
self.out_dim = 1
self.mid_dim = 400
......@@ -211,7 +211,8 @@ class doubleWrapper(nn.Module):
def forward(self, obs, goal, actions):
dist, dist1 = self.base(obs, goal, actions)
self.alpha = np.log(self.gamma)
return -(1 - torch.exp(dist * self.alpha)) / (1 - self.gamma), -(1 - torch.exp(dist1 * self.alpha)) / (1 - self.gamma)
#return -(1 - torch.exp(dist * self.alpha)) / (1 - self.gamma), -(1 - torch.exp(dist1 * self.alpha)) / (1 - self.gamma)
return dist, dist1
def Q1(self, obs, goal, actions):
dist, _ = self.base(obs, goal, actions)
......
import numpy as np
import gym
from arguments.arguments_hier_sac import get_args_ant, get_args_chain
from algos.hier_double_sac import hier_sac_agent
from goal_env.mujoco import *
import random
import torch
def get_env_params(env):
obs = env.reset()
# close the environment
params = {'obs': obs['observation'].shape[0], 'goal': obs['desired_goal'].shape[0],
'action': env.action_space.shape[0], 'action_max': env.action_space.high[0],
'max_timesteps': env._max_episode_steps}
return params
def launch(args):
# create the ddpg_agent
env = gym.make(args.env_name)
test_env = gym.make(args.test)
# if args.env_name == "AntPush-v1":
# test_env1 = gym.make("AntPushTest1-v1")
# test_env2 = gym.make("AntPushTest2-v1")
# elif args.env_name == "AntMaze1-v1":
# test_env1 = gym.make("AntMaze1Test1-v1")
# test_env2 = gym.make("AntMaze1Test2-v1")
# else:
test_env1 = test_env2 = None
print("test_env", test_env1, test_env2)
# set random seeds for reproduce
env.seed(args.seed)
if args.env_name != "NChain-v1":
env.env.env.wrapped_env.seed(args.seed)
test_env.env.env.wrapped_env.seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.device is not 'cpu':
torch.cuda.manual_seed(args.seed)
gym.spaces.prng.seed(args.seed)
# get the environment parameters
if args.env_name[:3] in ["Ant", "Poi", "Swi"]:
env.env.env.visualize_goal = args.animate
test_env.env.env.visualize_goal = args.animate
env_params = get_env_params(env)
env_params['max_test_timesteps'] = test_env._max_episode_steps
# create the ddpg agent to interact with the environment
sac_trainer = hier_sac_agent(args, env, env_params, test_env, test_env1, test_env2)
if args.eval:
if not args.resume:
print("random policy !!!")
# sac_trainer._eval_hier_agent(test_env)
# sac_trainer.vis_hier_policy()
# sac_trainer.cal_slow()
# sac_trainer.visualize_representation(100)
# sac_trainer.vis_learning_process()
# sac_trainer.picvideo('fig/final/', (1920, 1080))
else:
sac_trainer.learn()
# get the params
args = get_args_ant()
# args = get_args_chain()
# args = get_args_fetch()
# args = get_args_point()
if __name__ == '__main__':
launch(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment