Features/openai hacks #35
base: master
@@ -1,3 +1,4 @@
data/
*/*/mjkey.txt
**/.DS_STORE
**/*.pyc
@@ -0,0 +1,149 @@
import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.exploration_strategies.base import (
    PolicyWrappedWithExplorationStrategy
)
from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
    GaussianAndEpsilonStrategy
)
from rlkit.torch.her.her import HerTd3
import rlkit.samplers.rollout_functions as rf


from rlkit.torch.networks import FlattenMlp, MlpPolicy, QNormalizedFlattenMlp, CompositeNormalizedMlpPolicy
from rlkit.torch.data_management.normalizer import CompositeNormalizer


def experiment(variant):
    try:
        import robotics_recorder
    except ImportError as e:
        print(e)

    env = gym.make(variant['env_id'])
    es = GaussianAndEpsilonStrategy(
        action_space=env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    obs_dim = env.observation_space.spaces['observation'].low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    action_dim = env.action_space.low.size

    shared_normalizer = CompositeNormalizer(obs_dim + goal_dim, action_dim, obs_clip_range=5)

    qf1 = QNormalizedFlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
        composite_normalizer=shared_normalizer
    )
    qf2 = QNormalizedFlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
        composite_normalizer=shared_normalizer
    )
    import torch
    policy = CompositeNormalizedMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
        composite_normalizer=shared_normalizer,
        output_activation=torch.tanh
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    algorithm = HerTd3(
        her_kwargs=dict(
            observation_key='observation',
            desired_goal_key='desired_goal'
        ),
        td3_kwargs=dict(
            env=env,
            qf1=qf1,
            qf2=qf2,
            policy=policy,
            exploration_policy=exploration_policy
        ),
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        # NOTE: get_video_save_func does not appear to be imported anywhere
        # in this file as shown in the diff.
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()


if __name__ == "__main__":
    variant = dict(
        algo_kwargs=dict(
            num_epochs=5000,
            num_steps_per_epoch=1000,
            num_steps_per_eval=500,
            max_path_length=50,
            batch_size=128,
            discount=0.98,
            save_algorithm=True,
        ),
        replay_buffer_kwargs=dict(
            max_size=100000,
            fraction_goals_rollout_goals=0.2,  # equal to k = 4 in HER paper (1 / (k + 1) = 0.2)
            fraction_goals_env_goals=0.0,
        ),
        render=False,
        env_id="FetchPickAndPlace-v1",
        doodad_docker_image="",  # Set
        gpu_doodad_docker_image="",  # Set
        save_video=False,
        save_video_period=50,
    )

    from rlkit.launchers.launcher_util import run_experiment

    run_experiment(
        experiment,
        exp_prefix="her_td3_gym_fetch_pnp_test",  # Make sure no spaces...
        region="us-east-2",
        mode='here_no_doodad',
        variant=variant,
        use_gpu=True,  # Note: online normalization is very slow without GPU.
        spot_price=.5,
        snapshot_mode='gap_and_last',
        snapshot_gap=100,
        num_exps_per_instance=2
    )
@@ -10,7 +10,7 @@
from rlkit.policies.base import Policy
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
-from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
+from rlkit.torch.data_management.normalizer import TorchFixedNormalizer, TorchNormalizer, CompositeNormalizer
from rlkit.torch.modules import LayerNorm

@@ -89,6 +89,49 @@ def forward(self, *inputs, **kwargs):
        return super().forward(flat_inputs, **kwargs)


class QNormalizedFlattenMlp(FlattenMlp):
    def __init__(
            self,
            *args,
            composite_normalizer: CompositeNormalizer = None,
            **kwargs
    ):
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        assert composite_normalizer is not None
        self.composite_normalizer = composite_normalizer

    def forward(
            self,
            observations,
            actions,
            return_preactivations=False):
        obs, _ = self.composite_normalizer.normalize_all(observations, None)
        flat_input = torch.cat((obs, actions), dim=1)
        return super().forward(flat_input, return_preactivations=return_preactivations)


class VNormalizedFlattenMlp(FlattenMlp):
    def __init__(
            self,
            *args,
            composite_normalizer: CompositeNormalizer = None,
            **kwargs
    ):
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        assert composite_normalizer is not None
        self.composite_normalizer = composite_normalizer

    def forward(
            self,
            observations,
            return_preactivations=False):
        obs, _ = self.composite_normalizer.normalize_all(observations, None)
        flat_input = obs
        return super().forward(flat_input, return_preactivations=return_preactivations)


class MlpPolicy(Mlp, Policy):
    """
    A simpler interface for creating policies.

@@ -117,10 +160,29 @@ def get_actions(self, obs):
        return self.eval_np(obs)

class CompositeNormalizedMlpPolicy(MlpPolicy):
    def __init__(
            self,
            *args,
            composite_normalizer: CompositeNormalizer = None,
            **kwargs
    ):
        assert composite_normalizer is not None
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        self.composite_normalizer = composite_normalizer

Review comment (on the composite_normalizer kwarg): Seems like we can just make this a required argument rather than a kwarg.

    def forward(self, obs, **kwargs):
        if self.composite_normalizer:
            obs, _ = self.composite_normalizer.normalize_all(obs, None)
        return super().forward(obs, **kwargs)

Review comment (on the `if self.composite_normalizer:` check): This check seems a bit redundant given the assert statement in `__init__`.
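Both comments point at the same cleanup. A minimal sketch of what that could look like (illustrative only, not part of this diff; it assumes the surrounding MlpPolicy and CompositeNormalizer interfaces stay as they are):

class CompositeNormalizedMlpPolicy(MlpPolicy):
    def __init__(
            self,
            *args,
            composite_normalizer: CompositeNormalizer,  # required keyword-only argument, no default
            **kwargs
    ):
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        self.composite_normalizer = composite_normalizer

    def forward(self, obs, **kwargs):
        # The signature above guarantees a normalizer was passed, so the
        # `if self.composite_normalizer:` guard (and the assert) can go away.
        obs, _ = self.composite_normalizer.normalize_all(obs, None)
        return super().forward(obs, **kwargs)

Callers such as the experiment script above already pass composite_normalizer explicitly, so this would not change the call sites.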


class TanhMlpPolicy(MlpPolicy):
    """
    A helper class since most policies have a tanh output activation.
    """
    def __init__(self, *args, **kwargs):
        self.save_init_params(locals())
        super().__init__(*args, output_activation=torch.tanh, **kwargs)
@@ -34,7 +34,9 @@ def __init__(
            tau=0.005,
            qf_criterion=None,
            optimizer_class=optim.Adam,
            policy_preactivation_loss=True,
            policy_preactivation_coefficient=1.0,
            clip_q=True,
            **kwargs
    ):
        super().__init__(

@@ -71,6 +73,9 @@ def __init__(
            self.policy.parameters(),
            lr=policy_learning_rate,
        )
        self.clip_q = clip_q
        self.policy_preactivation_penalty = policy_preactivation_loss
        self.policy_preactivation_coefficient = policy_preactivation_coefficient

    def _do_training(self):
        batch = self.get_batch()

@@ -99,6 +104,14 @@ def _do_training(self):
        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)

        if self.clip_q:
            target_q_values = torch.clamp(
                target_q_values,
                -1/(1-self.discount),
                0
            )

Review comment (on the hard-coded -1/(1-self.discount) bound): Can you make this a parameter rather than hard-coding it? It could be something like: […]
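A hedged sketch of the kind of change the reviewer seems to be asking for, using hypothetical min_q_value / max_q_value parameters (these names are not in the diff):

import torch


def clip_target_q(target_q_values, min_q_value, max_q_value):
    """Clamp bootstrapped target Q-values to a configurable range.

    With per-step rewards in {-1, 0} (as in the sparse-reward Fetch tasks),
    a natural choice is min_q_value = -1 / (1 - discount) and max_q_value = 0,
    which is exactly the range hard-coded above.
    """
    return torch.clamp(target_q_values, min_q_value, max_q_value)

Inside _do_training() the call would then read something like target_q_values = clip_target_q(target_q_values, self.min_q_value, self.max_q_value), with the two bounds taken in through __init__ instead of being computed inline.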

        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

@@ -123,9 +136,12 @@ def _do_training(self):
        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
-            policy_actions = self.policy(obs)
+            policy_actions, policy_preactivations = self.policy(obs, return_preactivations=True)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()
            if self.policy_preactivation_penalty:
                policy_loss += self.policy_preactivation_coefficient * (policy_preactivations ** 2).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
Review comment: Thanks for this fix!