Commit 35ab999f by lvzhengyang

separate optimizers for actor and critic

parent 824b68b7
@@ -76,9 +76,12 @@ class Policy(nn.Module):
         critic,
         obs_space=None,
         action_space=None,
-        optimizer=None,
-        lr=2.5e-4,
-        weight_decay=5e-4,
+        actor_optimizer=None,
+        critic_optimizer=None,
+        lr_actor=2.5e-4,
+        lr_critic=2.5e-4,
+        weight_decay_actor=5e-4,
+        weight_decay_critic=5e-4,
         loss_fn_critic=torch.nn.MSELoss()
     ) -> None:
         super(Policy, self).__init__()
@@ -87,8 +90,10 @@ class Policy(nn.Module):
         self.obs_space = obs_space
         self.action_space = action_space
-        if optimizer == None:
-            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        if actor_optimizer == None:
+            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor, weight_decay=weight_decay_actor)
+        if critic_optimizer == None:
+            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr_critic, weight_decay=weight_decay_critic)
         self.loss_fn_critic = loss_fn_critic
         self.log_probs_buf = {
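The two hunks above change the constructor so that each sub-network gets its own Adam optimizer built over only that network's parameters, which lets the actor and critic use different learning rates and weight decays. A minimal sketch of that pattern in plain PyTorch, with placeholder networks standing in for the repository's actual actor and critic modules:

import torch
import torch.nn as nn

# Placeholder sub-networks; the real project supplies its own actor/critic modules.
actor = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4))
critic = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))

# One optimizer per sub-network: each sees only its own parameters, so the
# actor and critic hyperparameters can be tuned independently.
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=2.5e-4, weight_decay=5e-4)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=2.5e-4, weight_decay=5e-4)

Compared with a single optimizer over self.parameters(), this keeps the two parameter groups fully independent.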
@@ -181,7 +186,8 @@ class Policy(nn.Module):
         @param rewards: reward of an episode in a batch. [#num_batch]
         @note the env has returned 'done' signal
         """
-        self.optimizer.zero_grad()
+        self.actor_optimizer.zero_grad()
+        self.critic_optimizer.zero_grad()
         # cal log(p_res) for each trajectory in the batch
         u0 = self.log_probs_buf['u0'] # [#num_batch]
@@ -194,14 +200,15 @@ class Policy(nn.Module):
         with torch.no_grad():
             baselines = self.critic(nodes)
         j = (baselines - rewards) * p_res
-        j = j.mean()
+        j = -1.0 * j.mean()  # negate j so it gets larger after the optimizer step (gradient ascent)
         baselines = self.critic(nodes)
         loss_critic = self.loss_fn_critic(baselines, rewards)
         j.backward()
         loss_critic.backward()
-        self.optimizer.step()
+        self.actor_optimizer.step()
+        self.critic_optimizer.step()
         # Finally, reset the buf
         self.log_probs_buf = {
...
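The last two hunks change the update step itself: both optimizers are zeroed, the actor's surrogate objective j is negated so that a minimizing optimizer performs gradient ascent on the original quantity, the baseline is computed under torch.no_grad() so no critic gradient flows through the actor term, and each optimizer then steps its own parameters. Below is a self-contained sketch of that update with placeholder linear networks, a categorical policy, and random tensors standing in for the repository's nodes, rewards, and p_res; all names here are illustrative, not the project's API:

import torch
import torch.nn as nn

obs_dim, act_dim, batch = 8, 4, 16
actor = nn.Linear(obs_dim, act_dim)   # hypothetical policy head
critic = nn.Linear(obs_dim, 1)        # hypothetical value head
actor_opt = torch.optim.Adam(actor.parameters(), lr=2.5e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=2.5e-4)

obs = torch.randn(batch, obs_dim)
rewards = torch.randn(batch)

# Log-probability of sampled actions under the current policy (plays the role of p_res).
dist = torch.distributions.Categorical(logits=actor(obs))
log_probs = dist.log_prob(dist.sample())

actor_opt.zero_grad()
critic_opt.zero_grad()

# Actor objective: the baseline is detached, so only actor parameters receive
# gradients; the mean is negated so a minimizing step increases the objective.
with torch.no_grad():
    baselines = critic(obs).squeeze(-1)
j = -((baselines - rewards) * log_probs).mean()

# Critic loss: regress the value estimate onto the observed reward.
loss_critic = nn.functional.mse_loss(critic(obs).squeeze(-1), rewards)

j.backward()
loss_critic.backward()
actor_opt.step()      # updates only actor parameters
critic_opt.step()     # updates only critic parameters

Because the baseline is detached for the actor term and recomputed with gradients only for the critic loss, each backward call touches a single network, and the two step() calls remain independent.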