Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LESSON
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Gaoyunkai
LESSON
Commits
6ed87ec7
Commit
6ed87ec7
authored
Jul 05, 2021
by
Gaoyunkai
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add double_sac
parent
2267b85e
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
109 additions
and
23 deletions
+109
-23
.gitignore
+5
-0
algos/hier_double_sac.py
+0
-0
algos/hier_sac.py
+1
-1
algos/sac/replay_memory.py
+4
-0
algos/sac/sac.py
+18
-17
arguments/arguments_hier_sac.py
+5
-3
goal_env/mujoco/__init__.py
+2
-0
models/networks.py
+3
-2
train_hier_double_sac.py
+71
-0
No files found.
.gitignore
View file @
6ed87ec7
...
@@ -123,3 +123,8 @@ venv.bak/
...
@@ -123,3 +123,8 @@ venv.bak/
*.avi
*.avi
.idea/
.idea/
runs/
runs/
#slurm
*.err
*.out
*.slurm
algos/hier_double_sac.py
0 → 100644
View file @
6ed87ec7
This diff is collapsed.
Click to expand it.
algos/hier_sac.py
View file @
6ed87ec7
...
@@ -129,7 +129,7 @@ class hier_sac_agent:
...
@@ -129,7 +129,7 @@ class hier_sac_agent:
if
args
.
save
:
if
args
.
save
:
current_time
=
datetime
.
now
()
.
strftime
(
'
%
b
%
d_
%
H-
%
M-
%
S'
)
current_time
=
datetime
.
now
()
.
strftime
(
'
%
b
%
d_
%
H-
%
M-
%
S'
)
self
.
log_dir
=
'runs/hier/'
+
str
(
args
.
env_name
)
+
'/RB_Decay_'
+
current_time
+
\
self
.
log_dir
=
'
/lustre/S/gaoyunkai/RL/LESSON/
runs/hier/'
+
str
(
args
.
env_name
)
+
'/RB_Decay_'
+
current_time
+
\
"_C_"
+
str
(
args
.
c
)
+
"_Image_"
+
str
(
args
.
image
)
+
\
"_C_"
+
str
(
args
.
c
)
+
"_Image_"
+
str
(
args
.
image
)
+
\
"_Seed_"
+
str
(
args
.
seed
)
+
"_Reward_"
+
str
(
args
.
low_reward_coeff
)
+
\
"_Seed_"
+
str
(
args
.
seed
)
+
"_Reward_"
+
str
(
args
.
low_reward_coeff
)
+
\
"_NoPhi_"
+
str
(
self
.
not_update_phi
)
+
"_LearnG_"
+
str
(
self
.
learn_goal_space
)
+
"_Early_"
+
str
(
self
.
early_stop_thres
)
+
str
(
args
.
early_stop
)
"_NoPhi_"
+
str
(
self
.
not_update_phi
)
+
"_LearnG_"
+
str
(
self
.
learn_goal_space
)
+
"_Early_"
+
str
(
self
.
early_stop_thres
)
+
str
(
args
.
early_stop
)
...
...
algos/sac/replay_memory.py
View file @
6ed87ec7
...
@@ -47,6 +47,10 @@ class ReplayMemory:
...
@@ -47,6 +47,10 @@ class ReplayMemory:
obs_next
=
np
.
array
(
obs_next
)
obs_next
=
np
.
array
(
obs_next
)
return
obs
,
obs_next
return
obs
,
obs_next
def
clear
(
self
):
self
.
buffer
=
[]
self
.
position
=
0
class
Array_ReplayMemory
:
class
Array_ReplayMemory
:
def
__init__
(
self
,
capacity
,
env_params
):
def
__init__
(
self
,
capacity
,
env_params
):
self
.
capacity
=
capacity
self
.
capacity
=
capacity
...
...
algos/sac/sac.py
View file @
6ed87ec7
...
@@ -92,12 +92,6 @@ class SAC(object):
...
@@ -92,12 +92,6 @@ class SAC(object):
qf1_loss
=
F
.
mse_loss
(
qf1
,
next_q_value
)
# JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf1_loss
=
F
.
mse_loss
(
qf1
,
next_q_value
)
# JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf2_loss
=
F
.
mse_loss
(
qf2
,
next_q_value
)
# JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
qf2_loss
=
F
.
mse_loss
(
qf2
,
next_q_value
)
# JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
pi
,
log_pi
,
_
=
self
.
policy
.
sample
(
state_batch
)
qf1_pi
,
qf2_pi
=
self
.
critic
(
state_batch
,
pi
)
min_qf_pi
=
torch
.
min
(
qf1_pi
,
qf2_pi
)
policy_loss
=
((
self
.
alpha
*
log_pi
)
-
min_qf_pi
)
.
mean
()
# Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
if
feature_data
is
not
None
:
if
feature_data
is
not
None
:
if
self
.
gradient_flow_value
:
if
self
.
gradient_flow_value
:
obs
,
obs_next
=
self
.
critic
.
phi
(
feature_data
[
0
]),
self
.
critic
.
phi
(
feature_data
[
1
])
obs
,
obs_next
=
self
.
critic
.
phi
(
feature_data
[
0
]),
self
.
critic
.
phi
(
feature_data
[
1
])
...
@@ -106,26 +100,33 @@ class SAC(object):
...
@@ -106,26 +100,33 @@ class SAC(object):
max_dist
=
torch
.
clamp
(
1
-
(
hi_obs
-
hi_obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
max_dist
=
torch
.
clamp
(
1
-
(
hi_obs
-
hi_obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
representation_loss
=
(
min_dist
+
max_dist
)
.
mean
()
representation_loss
=
(
min_dist
+
max_dist
)
.
mean
()
qf1_loss
=
qf1_loss
*
0.1
+
representation_loss
qf1_loss
=
qf1_loss
*
0.1
+
representation_loss
else
:
qf_loss
=
qf1_loss
+
qf2_loss
self
.
critic_optim
.
zero_grad
()
qf_loss
.
backward
()
self
.
critic_optim
.
step
()
pi
,
log_pi
,
_
=
self
.
policy
.
sample
(
state_batch
)
qf1_pi
,
qf2_pi
=
self
.
critic
(
state_batch
,
pi
)
min_qf_pi
=
torch
.
min
(
qf1_pi
,
qf2_pi
)
policy_loss
=
((
self
.
alpha
*
log_pi
)
-
min_qf_pi
)
.
mean
()
# Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
# print("log_pi:", log_pi)
# print("min_qf_pi:", min_qf_pi)
# print("policy_loss:", policy_loss)
if
feature_data
is
not
None
:
if
not
self
.
gradient_flow_value
:
obs
,
obs_next
=
self
.
policy
.
phi
(
feature_data
[
0
]),
self
.
policy
.
phi
(
feature_data
[
1
])
obs
,
obs_next
=
self
.
policy
.
phi
(
feature_data
[
0
]),
self
.
policy
.
phi
(
feature_data
[
1
])
min_dist
=
torch
.
clamp
((
obs
-
obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
min_dist
=
torch
.
clamp
((
obs
-
obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
hi_obs
,
hi_obs_next
=
self
.
policy
.
phi
(
feature_data
[
2
]),
self
.
policy
.
phi
(
feature_data
[
3
])
hi_obs
,
hi_obs_next
=
self
.
policy
.
phi
(
feature_data
[
2
]),
self
.
policy
.
phi
(
feature_data
[
3
])
max_dist
=
torch
.
clamp
(
1
-
(
hi_obs
-
hi_obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
max_dist
=
torch
.
clamp
(
1
-
(
hi_obs
-
hi_obs_next
)
.
pow
(
2
)
.
mean
(
dim
=
1
),
min
=
0.
)
representation_loss
=
(
min_dist
+
max_dist
)
.
mean
()
representation_loss
=
(
min_dist
+
max_dist
)
.
mean
()
policy_loss
+=
representation_loss
policy_loss
=
policy_loss
+
representation_loss
self
.
critic_optim
.
zero_grad
()
qf1_loss
.
backward
()
self
.
critic_optim
.
step
()
self
.
critic_optim
.
zero_grad
()
qf2_loss
.
backward
()
self
.
critic_optim
.
step
()
self
.
policy_optim
.
zero_grad
()
self
.
policy_optim
.
zero_grad
()
policy_loss
.
backward
()
policy_loss
.
backward
()
self
.
policy_optim
.
step
()
self
.
policy_optim
.
step
()
#print("policy_loss:", policy_loss)
if
self
.
automatic_entropy_tuning
:
if
self
.
automatic_entropy_tuning
:
alpha_loss
=
-
(
self
.
log_alpha
*
(
log_pi
+
self
.
target_entropy
)
.
detach
())
.
mean
()
alpha_loss
=
-
(
self
.
log_alpha
*
(
log_pi
+
self
.
target_entropy
)
.
detach
())
.
mean
()
...
...
arguments/arguments_hier_sac.py
View file @
6ed87ec7
...
@@ -35,7 +35,7 @@ def get_args_ant():
...
@@ -35,7 +35,7 @@ def get_args_ant():
parser
.
add_argument
(
'--n-test-rollouts'
,
type
=
int
,
default
=
10
,
help
=
'the number of tests'
)
parser
.
add_argument
(
'--n-test-rollouts'
,
type
=
int
,
default
=
10
,
help
=
'the number of tests'
)
parser
.
add_argument
(
'--metric'
,
type
=
str
,
default
=
'MLP'
,
help
=
'the metric for the distance embedding'
)
parser
.
add_argument
(
'--metric'
,
type
=
str
,
default
=
'MLP'
,
help
=
'the metric for the distance embedding'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
"cuda:
3
"
,
help
=
'cuda device'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
"cuda:
0
"
,
help
=
'cuda device'
)
parser
.
add_argument
(
'--lr-decay-actor'
,
type
=
int
,
default
=
3000
,
help
=
'actor learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-actor'
,
type
=
int
,
default
=
3000
,
help
=
'actor learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-critic'
,
type
=
int
,
default
=
3000
,
help
=
'critic learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-critic'
,
type
=
int
,
default
=
3000
,
help
=
'critic learning rate decay'
)
...
@@ -101,7 +101,7 @@ def get_args_chain():
...
@@ -101,7 +101,7 @@ def get_args_chain():
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
160
,
help
=
'random seed'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
160
,
help
=
'random seed'
)
parser
.
add_argument
(
'--replay-strategy'
,
type
=
str
,
default
=
'none'
,
help
=
'the HER strategy'
)
parser
.
add_argument
(
'--replay-strategy'
,
type
=
str
,
default
=
'none'
,
help
=
'the HER strategy'
)
parser
.
add_argument
(
'--save-dir'
,
type
=
str
,
default
=
'saved_models/'
,
help
=
'the path to save the models'
)
parser
.
add_argument
(
'--save-dir'
,
type
=
str
,
default
=
'
/lustre/S/gaoyunkai/RL/LESSON/
saved_models/'
,
help
=
'the path to save the models'
)
parser
.
add_argument
(
'--noise-eps'
,
type
=
float
,
default
=
0.2
,
help
=
'noise factor for Gaussian'
)
parser
.
add_argument
(
'--noise-eps'
,
type
=
float
,
default
=
0.2
,
help
=
'noise factor for Gaussian'
)
parser
.
add_argument
(
'--random-eps'
,
type
=
float
,
default
=
0.2
,
help
=
"prob for acting randomly"
)
parser
.
add_argument
(
'--random-eps'
,
type
=
float
,
default
=
0.2
,
help
=
"prob for acting randomly"
)
...
@@ -118,7 +118,7 @@ def get_args_chain():
...
@@ -118,7 +118,7 @@ def get_args_chain():
parser
.
add_argument
(
'--n-test-rollouts'
,
type
=
int
,
default
=
10
,
help
=
'the number of tests'
)
parser
.
add_argument
(
'--n-test-rollouts'
,
type
=
int
,
default
=
10
,
help
=
'the number of tests'
)
parser
.
add_argument
(
'--metric'
,
type
=
str
,
default
=
'MLP'
,
help
=
'the metric for the distance embedding'
)
parser
.
add_argument
(
'--metric'
,
type
=
str
,
default
=
'MLP'
,
help
=
'the metric for the distance embedding'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
"cuda:
8
"
,
help
=
'cuda device'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
"cuda:
0
"
,
help
=
'cuda device'
)
parser
.
add_argument
(
'--lr-decay-actor'
,
type
=
int
,
default
=
3000
,
help
=
'actor learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-actor'
,
type
=
int
,
default
=
3000
,
help
=
'actor learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-critic'
,
type
=
int
,
default
=
3000
,
help
=
'critic learning rate decay'
)
parser
.
add_argument
(
'--lr-decay-critic'
,
type
=
int
,
default
=
3000
,
help
=
'critic learning rate decay'
)
...
@@ -147,6 +147,8 @@ def get_args_chain():
...
@@ -147,6 +147,8 @@ def get_args_chain():
parser
.
add_argument
(
"--use_prediction"
,
type
=
bool
,
default
=
False
,
help
=
'use prediction error to learn feature'
)
parser
.
add_argument
(
"--use_prediction"
,
type
=
bool
,
default
=
False
,
help
=
'use prediction error to learn feature'
)
parser
.
add_argument
(
"--start_update_phi"
,
type
=
int
,
default
=
2
,
help
=
'use prediction error to learn feature'
)
parser
.
add_argument
(
"--start_update_phi"
,
type
=
int
,
default
=
2
,
help
=
'use prediction error to learn feature'
)
parser
.
add_argument
(
"--image"
,
type
=
bool
,
default
=
False
,
help
=
'use image input'
)
parser
.
add_argument
(
"--image"
,
type
=
bool
,
default
=
False
,
help
=
'use image input'
)
parser
.
add_argument
(
"--old_sample"
,
type
=
bool
,
default
=
False
,
help
=
'sample the absolute goal in the abs_range'
)
# args of sac (high-level learning)
# args of sac (high-level learning)
parser
.
add_argument
(
'--policy'
,
default
=
"Gaussian"
,
parser
.
add_argument
(
'--policy'
,
default
=
"Gaussian"
,
...
...
goal_env/mujoco/__init__.py
View file @
6ed87ec7
...
@@ -8,6 +8,8 @@ elif sys.argv[0].split('/')[-1] == "train_hier_ddpg.py":
...
@@ -8,6 +8,8 @@ elif sys.argv[0].split('/')[-1] == "train_hier_ddpg.py":
from
train_hier_ddpg
import
args
from
train_hier_ddpg
import
args
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_hier_sac.py"
:
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_hier_sac.py"
:
from
train_hier_sac
import
args
from
train_hier_sac
import
args
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_hier_double_sac.py"
:
from
train_hier_double_sac
import
args
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_hier_ppo.py"
:
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_hier_ppo.py"
:
from
train_hier_ppo
import
args
from
train_hier_ppo
import
args
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_covering.py"
:
elif
sys
.
argv
[
0
]
.
split
(
'/'
)[
-
1
]
==
"train_covering.py"
:
...
...
models/networks.py
View file @
6ed87ec7
...
@@ -152,7 +152,7 @@ class Critic_double(nn.Module):
...
@@ -152,7 +152,7 @@ class Critic_double(nn.Module):
def
__init__
(
self
,
env_params
,
args
):
def
__init__
(
self
,
env_params
,
args
):
super
(
Critic_double
,
self
)
.
__init__
()
super
(
Critic_double
,
self
)
.
__init__
()
self
.
max_action
=
env_params
[
'action_max'
]
self
.
max_action
=
env_params
[
'action_max'
]
self
.
inp_dim
=
env_params
[
'obs'
]
+
env_params
[
'action'
]
+
env_params
[
'
goal
'
]
self
.
inp_dim
=
env_params
[
'obs'
]
+
env_params
[
'action'
]
+
env_params
[
'
real_goal_dim
'
]
self
.
out_dim
=
1
self
.
out_dim
=
1
self
.
mid_dim
=
400
self
.
mid_dim
=
400
...
@@ -211,7 +211,8 @@ class doubleWrapper(nn.Module):
...
@@ -211,7 +211,8 @@ class doubleWrapper(nn.Module):
def
forward
(
self
,
obs
,
goal
,
actions
):
def
forward
(
self
,
obs
,
goal
,
actions
):
dist
,
dist1
=
self
.
base
(
obs
,
goal
,
actions
)
dist
,
dist1
=
self
.
base
(
obs
,
goal
,
actions
)
self
.
alpha
=
np
.
log
(
self
.
gamma
)
self
.
alpha
=
np
.
log
(
self
.
gamma
)
return
-
(
1
-
torch
.
exp
(
dist
*
self
.
alpha
))
/
(
1
-
self
.
gamma
),
-
(
1
-
torch
.
exp
(
dist1
*
self
.
alpha
))
/
(
1
-
self
.
gamma
)
#return -(1 - torch.exp(dist * self.alpha)) / (1 - self.gamma), -(1 - torch.exp(dist1 * self.alpha)) / (1 - self.gamma)
return
dist
,
dist1
def
Q1
(
self
,
obs
,
goal
,
actions
):
def
Q1
(
self
,
obs
,
goal
,
actions
):
dist
,
_
=
self
.
base
(
obs
,
goal
,
actions
)
dist
,
_
=
self
.
base
(
obs
,
goal
,
actions
)
...
...
train_hier_double_sac.py
0 → 100644
View file @
6ed87ec7
import
numpy
as
np
import
gym
from
arguments.arguments_hier_sac
import
get_args_ant
,
get_args_chain
from
algos.hier_double_sac
import
hier_sac_agent
from
goal_env.mujoco
import
*
import
random
import
torch
def
get_env_params
(
env
):
obs
=
env
.
reset
()
# close the environment
params
=
{
'obs'
:
obs
[
'observation'
]
.
shape
[
0
],
'goal'
:
obs
[
'desired_goal'
]
.
shape
[
0
],
'action'
:
env
.
action_space
.
shape
[
0
],
'action_max'
:
env
.
action_space
.
high
[
0
],
'max_timesteps'
:
env
.
_max_episode_steps
}
return
params
def
launch
(
args
):
# create the ddpg_agent
env
=
gym
.
make
(
args
.
env_name
)
test_env
=
gym
.
make
(
args
.
test
)
# if args.env_name == "AntPush-v1":
# test_env1 = gym.make("AntPushTest1-v1")
# test_env2 = gym.make("AntPushTest2-v1")
# elif args.env_name == "AntMaze1-v1":
# test_env1 = gym.make("AntMaze1Test1-v1")
# test_env2 = gym.make("AntMaze1Test2-v1")
# else:
test_env1
=
test_env2
=
None
print
(
"test_env"
,
test_env1
,
test_env2
)
# set random seeds for reproduce
env
.
seed
(
args
.
seed
)
if
args
.
env_name
!=
"NChain-v1"
:
env
.
env
.
env
.
wrapped_env
.
seed
(
args
.
seed
)
test_env
.
env
.
env
.
wrapped_env
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
)
if
args
.
device
is
not
'cpu'
:
torch
.
cuda
.
manual_seed
(
args
.
seed
)
gym
.
spaces
.
prng
.
seed
(
args
.
seed
)
# get the environment parameters
if
args
.
env_name
[:
3
]
in
[
"Ant"
,
"Poi"
,
"Swi"
]:
env
.
env
.
env
.
visualize_goal
=
args
.
animate
test_env
.
env
.
env
.
visualize_goal
=
args
.
animate
env_params
=
get_env_params
(
env
)
env_params
[
'max_test_timesteps'
]
=
test_env
.
_max_episode_steps
# create the ddpg agent to interact with the environment
sac_trainer
=
hier_sac_agent
(
args
,
env
,
env_params
,
test_env
,
test_env1
,
test_env2
)
if
args
.
eval
:
if
not
args
.
resume
:
print
(
"random policy !!!"
)
# sac_trainer._eval_hier_agent(test_env)
# sac_trainer.vis_hier_policy()
# sac_trainer.cal_slow()
# sac_trainer.visualize_representation(100)
# sac_trainer.vis_learning_process()
# sac_trainer.picvideo('fig/final/', (1920, 1080))
else
:
sac_trainer
.
learn
()
# get the params
args
=
get_args_ant
()
# args = get_args_chain()
# args = get_args_fetch()
# args = get_args_point()
if
__name__
==
'__main__'
:
launch
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment