Commit 35ab999f by lvzhengyang

separate optimizers for actor and critic

parent 824b68b7
@@ -76,9 +76,12 @@ class Policy(nn.Module):
         critic,
         obs_space=None,
         action_space=None,
-        optimizer=None,
-        lr=2.5e-4,
-        weight_decay=5e-4,
+        actor_optimizer=None,
+        critic_optimizer=None,
+        lr_actor=2.5e-4,
+        lr_critic=2.5e-4,
+        weight_decay_actor=5e-4,
+        weight_decay_critic=5e-4,
         loss_fn_critic=torch.nn.MSELoss()
     ) -> None:
         super(Policy, self).__init__()
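
With the split, the actor and critic hyperparameters can be tuned independently at construction time. A minimal usage sketch, assuming the first two positional arguments are the actor and critic modules as the visible parameter list suggests; the `nn.Sequential` stand-ins and all values below are illustrative, not the repository's real networks:

```python
import torch.nn as nn

# Stand-in networks; the repository builds its own actor/critic modules.
actor = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
critic = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))

# Per-network learning rate and weight decay, e.g. a slower, regularized
# actor next to a faster, unregularized critic.
policy = Policy(
    actor,
    critic,
    lr_actor=1e-4,
    lr_critic=1e-3,
    weight_decay_actor=5e-4,
    weight_decay_critic=0.0,
)
```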
@@ -87,8 +90,10 @@ class Policy(nn.Module):
         self.obs_space = obs_space
         self.action_space = action_space
-        if optimizer == None:
-            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        if actor_optimizer == None:
+            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor, weight_decay=weight_decay_actor)
+        if critic_optimizer == None:
+            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr_critic, weight_decay=weight_decay_critic)
         self.loss_fn_critic = loss_fn_critic
         self.log_probs_buf = {
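
The old code handed `self.parameters()`, i.e. both sub-networks at once, to a single Adam instance; the new code partitions the parameters so each optimizer only ever sees its own network. A small self-contained sketch of that partitioning (the `Linear` modules are placeholders, not the project's models):

```python
import torch
import torch.nn as nn

# Placeholder modules; any pair of nn.Module instances behaves the same way.
actor = nn.Linear(8, 2)
critic = nn.Linear(8, 1)

# One optimizer per sub-network, mirroring the constructor above.
actor_opt = torch.optim.Adam(actor.parameters(), lr=2.5e-4, weight_decay=5e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=2.5e-4, weight_decay=5e-4)

# The parameter groups are disjoint, so stepping one network can never
# disturb the other's weights or its Adam moment estimates.
assert set(map(id, actor_opt.param_groups[0]['params'])).isdisjoint(
    set(map(id, critic_opt.param_groups[0]['params'])))
```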
@@ -181,7 +186,8 @@ class Policy(nn.Module):
         @param rewards: reward of an episode in a batch. [#num_batch]
         @note the env has returned the 'done' signal
         """
-        self.optimizer.zero_grad()
+        self.actor_optimizer.zero_grad()
+        self.critic_optimizer.zero_grad()
         # compute log(p_res) for each trajectory in the batch
         u0 = self.log_probs_buf['u0']  # [#num_batch]
@@ -194,14 +200,15 @@ class Policy(nn.Module):
         with torch.no_grad():
             baselines = self.critic(nodes)
         j = (baselines - rewards) * p_res
-        j = j.mean()
+        j = -1.0 * j.mean()  # negate: a gradient descent step on -j increases the original j
         baselines = self.critic(nodes)
         loss_critic = self.loss_fn_critic(baselines, rewards)
         j.backward()
         loss_critic.backward()
-        self.optimizer.step()
+        self.actor_optimizer.step()
+        self.critic_optimizer.step()
         # Finally, reset the buf
         self.log_probs_buf = {
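
Taken together, one training step now zeroes, backpropagates, and steps the two optimizers independently: the baseline is evaluated under `torch.no_grad()` so the actor's surrogate `j` never sends gradients into the critic, while the second `self.critic(nodes)` call builds the graph for the critic's regression loss. The sketch below reproduces that flow with placeholder networks, tensor shapes, and log-probabilities; none of these names or values come from the repository:

```python
import torch
import torch.nn as nn

# Placeholder networks and data; shapes are illustrative only.
actor = nn.Linear(8, 2)
critic = nn.Linear(8, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=2.5e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=2.5e-4)
loss_fn_critic = nn.MSELoss()

nodes = torch.randn(16, 8)                    # batch of 16 episode features
rewards = torch.randn(16, 1)                  # episode returns
log_probs = torch.log_softmax(actor(nodes), dim=-1)
p_res = log_probs.sum(dim=-1, keepdim=True)   # stand-in for the summed log-probs

actor_opt.zero_grad()
critic_opt.zero_grad()

# Actor surrogate: the baseline is detached (no_grad), so only the actor
# receives gradients from j; the negation matches the commit above.
with torch.no_grad():
    baselines = critic(nodes)
j = -1.0 * ((baselines - rewards) * p_res).mean()

# Critic regression: recompute the baseline with grad enabled.
loss_critic = loss_fn_critic(critic(nodes), rewards)

# The two losses live on disjoint graphs, so each backward/step pair
# updates exactly one network.
j.backward()
loss_critic.backward()
actor_opt.step()
critic_opt.step()
```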