Gaoyunkai / LESSON · commit 50ae1f7d
Authored Jul 07, 2021 by Gaoyunkai
Commit message: add goal correct
Parent: 6ed87ec7

Showing 8 changed files with 173 additions and 16 deletions (+173 / -16)
README.md                                  +7   -2
algos/hier_double_sac_goal_correct.py      +0   -0
algos/sac/model.py                         +17  -0
algos/sac/replay_memory.py                 +20  -7
algos/sac/sac.py                           +55  -6
arguments/arguments_hier_sac.py            +1   -1
goal_env/mujoco/__init__.py                +2   -0
train_hier_double_sac_goal_correct.py      +71  -0
README.md

@@ -8,5 +8,9 @@ The python dependencies are as follows.
 * [Gym](https://gym.openai.com/)
 * [Mujoco](https://www.roboti.us)
-Run the codes with ``python train_hier_sac.py``. The tensorboard files are saved in the ``runs`` folder and the
-trained models are saved in the ``saved_models`` folder.
+The tensorboard files are saved in the ``/lustre/S/gaoyunkai/RL/LESSON/runs/hier/`` folder and
+the trained models are saved in the ``save-dir`` set in arguments_hier_sac.py.
+Run the original code with ``python train_hier_sac.py``.
+Run the variant in which both the high-level and low-level agents use SAC with ``python train_hier_double_sac.py``.
+Run the double-SAC variant with goal correction with ``python train_hier_double_sac_goal_correct.py``.
\ No newline at end of file
algos/hier_double_sac_goal_correct.py (new file, mode 100644)

This diff is collapsed on the original page; contents not shown.
algos/sac/model.py

@@ -200,6 +200,23 @@ class GaussianPolicy(nn.Module):
         mean = torch.tanh(mean) * self.action_scale + self.action_bias
         return action, log_prob, mean
 
+    def correct(self, state_candidate, action):
+        candidate_num = state_candidate.shape[1]
+        mean, log_std = self.forward(state_candidate)
+        std = log_std.exp()
+        normal = Normal(mean, std)
+        x_t = torch.arctanh((action - self.action_bias) / self.action_scale)
+        x_t = x_t.unsqueeze(1).expand(-1, candidate_num, -1, -1)
+        # print("x_t", x_t.shape)
+        log_prob = normal.log_prob(x_t)
+        # print("log_prob:", log_prob.shape)
+        log_prob = log_prob.sum(-1).sum(-1)
+        # print("log_prob:", log_prob.shape)
+        correct_index = log_prob.argmax(1, keepdim=True)
+        # print("correct_index:", correct_index.shape)
+        return correct_index
+
     def to(self, device):
         self.action_scale = self.action_scale.to(device)
         self.action_bias = self.action_bias.to(device)
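The new correct method scores a set of candidate goal-conditioned low-level states by how likely they make the low-level actions that were actually executed, and returns the index of the best candidate per batch element. For illustration only (not part of the commit), here is a minimal, self-contained sketch of that scoring step; the tensor shapes and the random "policy head" outputs below are assumptions made for the example.

import torch
from torch.distributions import Normal

batch, candidates, steps, act_dim = 4, 10, 5, 2
action_scale, action_bias = 1.0, 0.0

# Pretend the policy head produced a mean and log_std for every candidate state at every step.
mean = torch.randn(batch, candidates, steps, act_dim)
log_std = torch.zeros(batch, candidates, steps, act_dim)

# Low-level actions actually stored in the replay buffer (already squashed by tanh).
action = torch.tanh(torch.randn(batch, steps, act_dim)) * action_scale + action_bias

# Undo the tanh squashing, then broadcast the actions against every candidate.
x_t = torch.arctanh((action - action_bias) / action_scale)
x_t = x_t.unsqueeze(1).expand(-1, candidates, -1, -1)

# Log-likelihood of the stored actions under each candidate, summed over steps and action dims.
log_prob = Normal(mean, log_std.exp()).log_prob(x_t).sum(-1).sum(-1)

# Index of the candidate that best explains the observed low-level behaviour, per batch element.
correct_index = log_prob.argmax(1, keepdim=True)
print(correct_index.shape)  # torch.Size([4, 1])

Picking the maximum-likelihood candidate in this way resembles the off-policy relabelling idea used in HIRO-style hierarchical RL: when the low-level policy has drifted since the transition was collected, the stored subgoal is replaced by the candidate that best explains the stored low-level behaviour.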
algos/sac/replay_memory.py

@@ -2,21 +2,30 @@ import random
 import numpy as np
 
 
 class ReplayMemory:
-    def __init__(self, capacity):
+    def __init__(self, capacity, use_goal_correct=False):
         self.capacity = capacity
         self.buffer = []
         self.position = 0
+        self.use_goal_correct = use_goal_correct
 
-    def push(self, state, action, reward, next_state, done, epoch):
+    def push(self, state, action, reward, next_state, done, epoch, state_c_step=None, low_action=None):
         if len(self.buffer) < self.capacity:
             self.buffer.append(None)
-        self.buffer[self.position] = (state, action, reward, next_state, done, epoch + 1)
+        if not self.use_goal_correct:
+            self.buffer[self.position] = (state, action, reward, next_state, done, epoch + 1)
+        else:
+            assert low_action is not None
+            self.buffer[self.position] = (state, action, reward, next_state, done, epoch + 1, state_c_step, low_action)
         self.position = (self.position + 1) % self.capacity
 
     def sample(self, batch_size):
         batch = random.sample(self.buffer, batch_size)
-        state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
-        return state, action, reward, next_state, done
+        if not self.use_goal_correct:
+            state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
+            return state, action, reward, next_state, done
+        else:
+            state, action, reward, next_state, done, _, state_c_step, low_action = map(np.stack, zip(*batch))
+            return state, action, reward, next_state, done, state_c_step, low_action
 
     def __len__(self):
         return len(self.buffer)

@@ -36,8 +45,12 @@ class ReplayMemory:
         p_trajectory = p_trajectory.astype(np.float64)
         idxs = np.random.choice(len(self.buffer), size=batch_size, replace=False, p=p_trajectory)
         batch = [self.buffer[i] for i in idxs]
-        state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
-        return state, action, reward, next_state, done
+        if not self.use_goal_correct:
+            state, action, reward, next_state, done, _ = map(np.stack, zip(*batch))
+            return state, action, reward, next_state, done
+        else:
+            state, action, reward, next_state, done, _, state_c_step, low_action = map(np.stack, zip(*batch))
+            return state, action, reward, next_state, done, state_c_step, low_action
 
     def random_sample(self, batch_size):
         idxs = np.random.randint(0, len(self.buffer), batch_size)
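For illustration only (not part of the commit): a sketch of how the extended ReplayMemory might be used when goal correction is enabled, assuming it is run from the repository root. Each high-level transition additionally stores the low-level states and actions from its c-step segment; the dimensions below are made up for the example.

import numpy as np
from algos.sac.replay_memory import ReplayMemory

obs_dim, goal_dim, low_act_dim, c_step = 6, 2, 3, 10
memory = ReplayMemory(capacity=1000, use_goal_correct=True)

for epoch in range(64):
    state = np.random.randn(obs_dim)
    goal = np.random.randn(goal_dim)                    # high-level "action" is a subgoal
    next_state = np.random.randn(obs_dim)
    reward, done = -1.0, 0.0
    state_c_step = np.random.randn(c_step, obs_dim)     # low-level states over the c-step segment
    low_action = np.random.randn(c_step, low_act_dim)   # low-level actions over the same segment
    memory.push(state, goal, reward, next_state, done, epoch,
                state_c_step=state_c_step, low_action=low_action)

# With use_goal_correct=True, sample() also returns the stored low-level segments.
(state, goal, reward, next_state, done,
 state_c_step, low_action) = memory.sample(batch_size=32)
print(state_c_step.shape, low_action.shape)  # (32, 10, 6) (32, 10, 3)

Storing the low-level segment alongside each high-level transition is what later lets update_parameters re-evaluate candidate subgoals against the behaviour the low-level policy actually produced.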
algos/sac/sac.py

@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from torch.optim import Adam
 from algos.sac.utils import soft_update, hard_update
 from algos.sac.model import GaussianPolicy, QNetwork, DeterministicPolicy, QNetwork_phi
+import numpy as np
 
 
 class SAC(object):
-    def __init__(self, num_inputs, action_space, args, pri_replay, goal_dim, gradient_flow_value, abs_range, tanh_output):
+    def __init__(self, num_inputs, action_space, args, pri_replay, goal_dim, gradient_flow_value, abs_range, tanh_output, use_goal_correct=False):
         self.gamma = args.gamma
         self.tau = args.tau

@@ -20,6 +20,7 @@ class SAC(object):
         self.device = args.device
         self.gradient_flow_value = gradient_flow_value
+        self.use_goal_correct = use_goal_correct
         if not gradient_flow_value:
             self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)

@@ -64,18 +65,41 @@ class SAC(object):
             _, _, action = self.policy.sample(state)
         return action.detach().cpu().numpy()[0]
 
-    def update_parameters(self, memory, batch_size, env_params, hi_sparse, feature_data):
+    def select_num_action(self, state, num):
+        action_candidate = np.array([])
+        batch = state.shape[0]
+        for i in range(num):
+            action, _, _ = self.policy.sample(state)
+            action_candidate = np.append(action_candidate, action.detach().cpu().numpy())
+        return action_candidate.reshape(batch, num, -1)
+
+    def update_parameters(self, memory, batch_size, env_params, hi_sparse, feature_data, low_policy=None, representation=None):
         # Sample a batch from memory
         if self.pri_replay:
-            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.pri_sample(batch_size=batch_size)
+            if not self.use_goal_correct:
+                state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.pri_sample(batch_size=batch_size)
+            else:
+                state_batch, action_batch, reward_batch, next_state_batch, mask_batch, low_state_c_step_batch, low_action_batch = memory.pri_sample(batch_size=batch_size)
         else:
-            state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
+            if not self.use_goal_correct:
+                state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
+            else:
+                state_batch, action_batch, reward_batch, next_state_batch, mask_batch, low_state_c_step_batch, low_action_batch = memory.sample(batch_size=batch_size)
 
         state_batch = torch.FloatTensor(state_batch).to(self.device)
         next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
         action_batch = torch.FloatTensor(action_batch).to(self.device)
         reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
         mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
+        low_state_c_step_batch = torch.FloatTensor(low_state_c_step_batch).to(self.device)
+        low_action_batch = torch.FloatTensor(low_action_batch).to(self.device)
+        # print("state_batch shape:", state_batch.shape)
+        # print("action_batch shape:", action_batch.shape)
+        # print("reward_batch:", reward_batch.shape)
+        # print("mask_batch:", mask_batch.shape)
+        # print("low_state_c_step_batch shape:", low_state_c_step_batch.shape)
+        # print("low_action_batch shape:", low_action_batch.shape)
 
         with torch.no_grad():
             next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)

@@ -86,7 +110,32 @@ class SAC(object):
             if hi_sparse:
                 # clip target value
                 next_q_value = torch.clamp(next_q_value, -env_params['max_timesteps'], 0.)
-        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
+        hi_action_candidate_num = 10
+        c_step = low_state_c_step_batch.shape[1]
+        real_goal_dim = env_params["real_goal_dim"]
+        if self.use_goal_correct:
+            with torch.no_grad():
+                action_batch_candidate = torch.FloatTensor(self.select_num_action(state_batch, hi_action_candidate_num - 2)).to(self.device)
+                # print("action_batch_candidate:", action_batch_candidate.shape)
+                mean, _ = self.policy(state_batch)
+                # print("mean:", mean.shape)
+                action_batch_candidate = torch.cat([action_batch_candidate, action_batch.unsqueeze(1), mean.unsqueeze(1)], dim=1)
+                # print("action_batch_candidate:", action_batch_candidate.shape)
+                ag = representation(state_batch[:, :env_params["obs"]]).unsqueeze(1)
+                # print("ag shape:", ag.shape)
+                goal_batch_candidate = action_batch_candidate + ag
+                low_state_batch_candidate = torch.cat([low_state_c_step_batch.unsqueeze(1).expand(-1, hi_action_candidate_num, -1, -1), goal_batch_candidate.unsqueeze(2).expand(-1, -1, c_step, -1)], dim=-1)
+                # print("low_state_batch_candidate:", low_state_batch_candidate.shape)
+                goal_correct_index = low_policy.correct(low_state_batch_candidate, low_action_batch)
+                goal_correct_index = goal_correct_index.expand(-1, hi_action_candidate_num * real_goal_dim).reshape(-1, hi_action_candidate_num, real_goal_dim)
+                action_batch_correct = torch.gather(action_batch_candidate, 1, goal_correct_index)[:, 0, :]
+                # print("action_batch_candidate:", action_batch_candidate)
+                # print("action_batch_correct:", action_batch_correct)
+                # print("action_batch:", action_batch)
+        qf1, qf2 = self.critic(state_batch, action_batch_correct)  # Two Q-functions to mitigate positive bias in the policy improvement step
         # print("qf1", qf1.shape)
         # print("next_q", next_q_value.shape)
         qf1_loss = F.mse_loss(qf1, next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
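The goal-correct branch of update_parameters builds a pool of candidate subgoals (fresh policy samples plus the stored high-level action and the current policy mean), asks the low-level policy which candidate best explains the stored low-level actions, and then gathers that candidate per batch element before computing the critic loss. For illustration only (not part of the commit), a standalone sketch of the candidate construction and the torch.gather indexing; the random tensors and the argmax "scores" below stand in for the real policy, the representation network, and low_policy.correct().

import torch

batch, real_goal_dim, hi_action_candidate_num = 4, 2, 10

# Candidate high-level actions: (num - 2) fresh policy samples, plus the stored action
# and the current policy mean, concatenated along a new candidate dimension.
sampled = torch.randn(batch, hi_action_candidate_num - 2, real_goal_dim)
stored_action = torch.randn(batch, 1, real_goal_dim)
policy_mean = torch.randn(batch, 1, real_goal_dim)
action_batch_candidate = torch.cat([sampled, stored_action, policy_mean], dim=1)

# Stand-in for low_policy.correct(): one winning candidate index per batch element.
scores = torch.randn(batch, hi_action_candidate_num)
goal_correct_index = scores.argmax(1, keepdim=True)          # shape (batch, 1)

# Replicate the index across candidates and goal dims so torch.gather can pull whole rows.
index = goal_correct_index.expand(-1, hi_action_candidate_num * real_goal_dim)
index = index.reshape(-1, hi_action_candidate_num, real_goal_dim)
action_batch_correct = torch.gather(action_batch_candidate, 1, index)[:, 0, :]

print(action_batch_correct.shape)  # torch.Size([4, 2])

Expanding the (batch, 1) winner index to (batch, num_candidates, goal_dim) before the gather simply repeats the same index everywhere, so every gathered row is the selected candidate; taking [:, 0, :] afterwards keeps a single copy.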
arguments/arguments_hier_sac.py

@@ -18,7 +18,7 @@ def get_args_ant():
     parser.add_argument('--seed', type=int, default=125, help='random seed')
     parser.add_argument('--replay-strategy', type=str, default='none', help='the HER strategy')
-    parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models')
+    parser.add_argument('--save-dir', type=str, default='/lustre/S/gaoyunkai/RL/LESSON/saved_models/', help='the path to save the models')
     parser.add_argument('--noise-eps', type=float, default=0.2, help='noise factor for Gaussian')
     parser.add_argument('--random-eps', type=float, default=0.2, help="prob for acting randomly")
goal_env/mujoco/__init__.py

@@ -10,6 +10,8 @@ elif sys.argv[0].split('/')[-1] == "train_hier_sac.py":
     from train_hier_sac import args
 elif sys.argv[0].split('/')[-1] == "train_hier_double_sac.py":
     from train_hier_double_sac import args
+elif sys.argv[0].split('/')[-1] == "train_hier_double_sac_goal_correct.py":
+    from train_hier_double_sac import args
 elif sys.argv[0].split('/')[-1] == "train_hier_ppo.py":
     from train_hier_ppo import args
 elif sys.argv[0].split('/')[-1] == "train_covering.py":
train_hier_double_sac_goal_correct.py (new file, mode 100644)

import numpy as np
import gym
from arguments.arguments_hier_sac import get_args_ant, get_args_chain
from algos.hier_double_sac_goal_correct import hier_sac_agent
from goal_env.mujoco import *
import random
import torch


def get_env_params(env):
    obs = env.reset()
    # close the environment
    params = {'obs': obs['observation'].shape[0],
              'goal': obs['desired_goal'].shape[0],
              'action': env.action_space.shape[0],
              'action_max': env.action_space.high[0],
              'max_timesteps': env._max_episode_steps}
    return params


def launch(args):
    # create the training and test environments
    env = gym.make(args.env_name)
    test_env = gym.make(args.test)
    # if args.env_name == "AntPush-v1":
    #     test_env1 = gym.make("AntPushTest1-v1")
    #     test_env2 = gym.make("AntPushTest2-v1")
    # elif args.env_name == "AntMaze1-v1":
    #     test_env1 = gym.make("AntMaze1Test1-v1")
    #     test_env2 = gym.make("AntMaze1Test2-v1")
    # else:
    test_env1 = test_env2 = None
    print("test_env", test_env1, test_env2)

    # set random seeds for reproducibility
    env.seed(args.seed)
    if args.env_name != "NChain-v1":
        env.env.env.wrapped_env.seed(args.seed)
        test_env.env.env.wrapped_env.seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device != 'cpu':
        torch.cuda.manual_seed(args.seed)
    gym.spaces.prng.seed(args.seed)

    # get the environment parameters
    if args.env_name[:3] in ["Ant", "Poi", "Swi"]:
        env.env.env.visualize_goal = args.animate
        test_env.env.env.visualize_goal = args.animate
    env_params = get_env_params(env)
    env_params['max_test_timesteps'] = test_env._max_episode_steps

    # create the hierarchical SAC agent to interact with the environment
    sac_trainer = hier_sac_agent(args, env, env_params, test_env, test_env1, test_env2)
    if args.eval:
        if not args.resume:
            print("random policy !!!")
        # sac_trainer._eval_hier_agent(test_env)
        # sac_trainer.vis_hier_policy()
        # sac_trainer.cal_slow()
        # sac_trainer.visualize_representation(100)
        # sac_trainer.vis_learning_process()
        # sac_trainer.picvideo('fig/final/', (1920, 1080))
    else:
        sac_trainer.learn()


# get the params
args = get_args_ant()
# args = get_args_chain()
# args = get_args_fetch()
# args = get_args_point()

if __name__ == '__main__':
    launch(args)
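For illustration only (not part of the commit): the env_params dictionary that launch() builds via get_env_params, exercised with a stub environment so the entries are concrete. The stub and its sizes are assumptions for the example; note that the goal-correct update in algos/sac/sac.py additionally reads env_params['real_goal_dim'], which is presumably filled in elsewhere by the agent and is not shown in this diff.

import numpy as np

class StubSpace:
    shape = (8,)
    high = np.array([30.0] * 8)

class StubEnv:
    action_space = StubSpace()
    _max_episode_steps = 500

    def reset(self):
        return {'observation': np.zeros(29), 'desired_goal': np.zeros(2)}

env = StubEnv()
obs = env.reset()
params = {'obs': obs['observation'].shape[0],       # 29
          'goal': obs['desired_goal'].shape[0],     # 2
          'action': env.action_space.shape[0],      # 8
          'action_max': env.action_space.high[0],   # 30.0
          'max_timesteps': env._max_episode_steps}  # 500
print(params)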