Commit 0de80caa by haoyifan

hh

parent d28758bd
......@@ -8,9 +8,8 @@ import sys
from metrics import update_metric, update_speaker_prob, update_listener_prob, update_R_and_MIS
# hyperparameters
MAX_EP = 10000000 # maximum count of training data
A_LR = 0.002 # learning rate for actor
C_LR = 0.002 # learning rate for critic
MAX_EP = 10000000 # maximum count of training data
LR = 0.002 # learning rate
BATCH_SIZE = 128
EPSILON = 0.2
CLIP = 0.2
......@@ -76,7 +75,7 @@ class Speaker(object):
surrogate = ratio * self.tf_reward
self.loss = -tf.reduce_mean(tf.minimum(surrogate, tf.clip_by_value(ratio, 1. - CLIP, 1. + CLIP) * self.tf_reward))
self.train = tf.train.AdamOptimizer(A_LR).minimize(self.loss, var_list = [speak_params])
self.train = tf.train.AdamOptimizer(LR).minimize(self.loss, var_list = [speak_params])
self.sess.run(tf.global_variables_initializer())
......@@ -215,7 +214,7 @@ class Listener(object):
surrogate = ratio * self.tf_reward
self.loss = -tf.reduce_mean(tf.minimum(surrogate, tf.clip_by_value(ratio, 1. - CLIP, 1. + CLIP) * self.tf_reward))
self.train = tf.train.AdamOptimizer(A_LR).minimize(self.loss, var_list = [listen_params])
self.train = tf.train.AdamOptimizer(LR).minimize(self.loss, var_list = [listen_params])
self.sess.run(tf.global_variables_initializer())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment