[Bug]: Custom Sub-Hyperparameters during train.py -> Optimize #431

Open
5 tasks done
kingjin94 opened this issue Dec 18, 2023 · 1 comment
Labels
bug: Something isn't working
Maintainers on vacation: Maintainers are on vacation so they can recharge their batteries, we will be back soon ;)

kingjin94 commented Dec 18, 2023

🐛 Bug

I am developing a custom feature extractor type (based on DeepSets) for SB3 and want to train and optimize it with sb3_zoo. To do so, I add the following to a custom config.py file:

gym.register(
    "env-name",
    class,
    kwargs)

hyperparams = {
    "env-name": dict(
        policy="MultiInputPolicy",
        policy_kwargs={
            "features_extractor_class": FeatureExtractorSet,
            "features_extractor_kwargs": {
                "features_dim": 10
            }
        }
    )
}

This works well with the normal train.py (Arguments: '--algo', 'a2c', '--conf-file', 'path/to/config.py', '--gym-packages', 'path.to.config', '--n-timesteps', '100', '--device', 'cpu', '-P', '--env', 'env-name', ...)

When adding '-optimize', the training fails (the actions contain NaN, because I encode invalid observations as NaN and rely on the custom FeatureExtractorSet to discard them). Closer investigation shows that the objective function updates self._hyperparams, which contains the sub-dict {'policy_kwargs': {'features_extractor_class': FeatureExtractorSet}}, with the sampled hyperparameters, which also set policy_kwargs other than features_extractor_class. Because the update is shallow, the sampled policy_kwargs replace the whole sub-dict and the custom feature extractor is dropped.
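The overwrite can be reproduced with plain dictionaries, independent of the zoo. The following is a minimal sketch: the empty FeatureExtractorSet class is a hypothetical stand-in for the real extractor, and the contents of sampled_hyperparams (learning_rate, net_arch) are assumed values for illustration, not what the zoo's sampler actually returns.

```python
# Hypothetical stand-in for the custom extractor class from config.py.
class FeatureExtractorSet:
    pass


# Hyperparameters as defined in the config file.
kwargs = {
    "policy": "MultiInputPolicy",
    "policy_kwargs": {
        "features_extractor_class": FeatureExtractorSet,
        "features_extractor_kwargs": {"features_dim": 10},
    },
}

# Assumed shape of one sampled trial: it also carries a policy_kwargs entry.
sampled_hyperparams = {
    "learning_rate": 3e-4,
    "policy_kwargs": {"net_arch": [64, 64]},
}

# dict.update is shallow: the whole policy_kwargs value is replaced,
# not merged key by key.
kwargs.update(sampled_hyperparams)

extractor_survived = "features_extractor_class" in kwargs["policy_kwargs"]
print(extractor_survived)  # False
```

After the update, policy_kwargs holds only the sampled keys, so the policy would be built with the default feature extractor.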

I would suggest replacing

kwargs.update(sampled_hyperparams)
with a deep_update (e.g. from pydantic).
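For illustration, a recursive merge along these lines could look like the sketch below. It is hand-written rather than imported from pydantic, since the import path of pydantic's helper may differ between versions; the function body and the example values (a string in place of the extractor class, learning_rate, net_arch) are assumptions, not the zoo's or pydantic's code.

```python
from typing import Any, Dict


def deep_update(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with `updates` merged into `base` recursively.

    Nested dicts present on both sides are merged key by key; any other
    value from `updates` replaces the one in `base`. `base` is not mutated.
    """
    merged = dict(base)  # shallow copy of the top level
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = value
    return merged


# Example values; the extractor class is represented by a string here.
base = {
    "policy": "MultiInputPolicy",
    "policy_kwargs": {
        "features_extractor_class": "FeatureExtractorSet",
        "features_extractor_kwargs": {"features_dim": 10},
    },
}
sampled = {"learning_rate": 3e-4, "policy_kwargs": {"net_arch": [64, 64]}}

merged = deep_update(base, sampled)
print(sorted(merged["policy_kwargs"]))
```

With such a helper, the call site would become something like kwargs = deep_update(kwargs, sampled_hyperparams), so that sampled policy_kwargs extend the configured ones. On conflicting keys the sampled value still wins, which matches the intent of a hyperparameter search.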

To Reproduce

No response

Relevant log output / Error message

No response

System Info

  • OS: Linux-5.15.0-91-generic-x86_64-with-glibc2.31 # 101~20.04.1-Ubuntu SMP Thu Nov 16 14:22:28 UTC 2023
  • Python: 3.9.18
  • Stable-Baselines3: 2.2.1
  • PyTorch: 2.1.1+cu121
  • GPU Enabled: True
  • Numpy: 1.26.2
  • Cloudpickle: 3.0.0
  • Gymnasium: 0.29.1

Checklist

@kingjin94 kingjin94 added the bug Something isn't working label Dec 18, 2023
@araffin araffin added the Maintainers on vacation Maintainers are on vacation so they can recharge their batteries, we will be back soon ;) label Dec 18, 2023
2 participants