Commit 135ed1f6 by yuxguo

fix

parent 84297604
......@@ -413,7 +413,7 @@ class SharedGroupMLP(nn.Module):
# mlps indicates different experts
if shared:
mlps = [MLPModel(group_input_dim, group_output_dim,
hidden_dims=hidden_dims, flatten=False) for i in range(nr_mlps)]
hidden_dims=hidden_dims) for i in range(nr_mlps)]
else:
mlps = [GroupMLP(group_input_dim * groups, group_output_dim * groups,
hidden_dims=hidden_dims) for i in range(nr_mlps)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment