Skip to content

Commit

Permalink
Merge pull request #116 from NREL/origin/ray2_2
Browse files Browse the repository at this point in the history
Update to use ray v2.2.0, graph-env v0.2.0
  • Loading branch information
jlaw9 authored Jan 23, 2023
2 parents ba86c96 + f9fc994 commit e5a6dd8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 32 deletions.
2 changes: 1 addition & 1 deletion rlmolecule/molecule_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from graphenv import tf
from graphenv.graph_model import GraphModel
from nfp.preprocessing import MolPreprocessor
from ray.rllib.agents.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.algorithms.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models.tf.tf_modelv2 import TFModelV2

from rlmolecule.policy.model import policy_model
Expand Down
55 changes: 24 additions & 31 deletions tests/test_ppo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import rdkit
from graphenv.graph_env import GraphEnv
from ray.rllib.agents import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from rlmolecule.builder import MoleculeBuilder
Expand All @@ -13,23 +13,19 @@

@pytest.fixture
def ppo_config():
    """Build a minimal PPO configuration for fast smoke tests.

    Uses the fluent ``PPOConfig`` builder introduced in ray >= 2.x
    (the legacy ``ppo.DEFAULT_CONFIG`` dict API was removed upstream).

    Returns:
        A ``PPOConfig`` instance; callers may further customize it
        (e.g. attach an env/model) before calling ``.build()``.
    """
    # Deliberately tiny batch / minibatch / fragment sizes: the goal is a
    # quick end-to-end run, not learning quality.
    ppo_config = (
        PPOConfig()
        .training(
            lr=1e-3,
            train_batch_size=20,
            sgd_minibatch_size=2,
            shuffle_sequences=True,
            num_sgd_iter=1,
        )
        .resources(num_gpus=0)  # CPU-only so the test runs on any machine
        .framework("tf2")
        .rollouts(
            num_rollout_workers=1,  # parallelism
            rollout_fragment_length=5,
        )
        .debugging(log_level="DEBUG")
    )

    return ppo_config

Expand Down Expand Up @@ -57,17 +53,14 @@ def create_env(config):

register_env("QEDGraphEnv", lambda config: create_env(config))

config = {
"env": "QEDGraphEnv",
"model": {
"custom_model": "MoleculeModel",
"custom_model_config": {
"preprocessor": load_preprocessor(),
"features": 32,
"num_messages": 1,
},
},
}
ppo_config.update(config)
trainer = ppo.PPOTrainer(config=ppo_config)
ppo_config.environment(env='QEDGraphEnv')
ppo_config.training(model={"custom_model": "MoleculeModel",
"custom_model_config": {
"preprocessor": load_preprocessor(),
"features": 32,
"num_messages": 1,
},
},
)
trainer = ppo_config.build()
trainer.train()

0 comments on commit e5a6dd8

Please sign in to comment.