Skip to content

Commit

Permalink
Merge pull request #83 from mrmitzh/master
Browse files Browse the repository at this point in the history
fix index error when using higher pytorch version
  • Loading branch information
junxiaosong authored Mar 5, 2019
2 parents 66292c5 + 961baf4 commit 10f256e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ and uncomment the line
or
# from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
```
and then execute: ``python train.py`` (To use GPU in PyTorch, set ``use_gpu=True``)
and then execute: ``python train.py`` (To use GPU in PyTorch, set ``use_gpu=True`` and use ``return loss.item(), entropy.item()`` in function train_step in policy_value_net_pytorch.py if your pytorch version is greater than 0.5)

The models (best_policy.model and current_policy.model) will be saved every a few updates (default 50).

Expand Down
2 changes: 2 additions & 0 deletions policy_value_net_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def train_step(self, state_batch, mcts_probs, winner_batch, lr):
torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
)
return loss.data[0], entropy.data[0]
#for pytorch version >= 0.5 please use the following line instead.
#return loss.item(), entropy.item()

def get_policy_param(self):
net_params = self.policy_value_net.state_dict()
Expand Down

0 comments on commit 10f256e

Please sign in to comment.