how to correctly add state representation methods such as AE? #2080

Open · angel-ayala opened this issue Feb 5, 2025 · 3 comments
Labels: question (Further information is requested)

@angel-ayala

❓ Question

Hi, I would like to thank you for the effort to keep this repo updated.
I have already implemented a few things and become somewhat familiar with the ecosystem, mainly the NN models, policies, and algorithms.
I intend to extend three existing algorithms (SAC, TD3, and PPO) with representation learning techniques such as SPR, but I would like to first try a vanilla autoencoder, specifically a VAE.

I saw that I must extend each algorithm class to include the VAE model and perform joint optimization of the critic and the encoder/decoder stage in the train method; however, I was wondering whether this is enough or whether I need to consider other aspects.

The AE architecture comprises three components: an online encoder, a decoder, and a target encoder.

My main concerns are:

  • Critic inference: whether to use the online encoder or the target encoder to process the current and next observations.
  • Action inference: using a detached version of the encoder to prevent gradient propagation (or making it configurable if I want to).
  • Logging the reconstruction loss value.
  • AE model saving and loading.

I have already started and was able to create a custom model and policy following the documentation. I would really appreciate any guidance on these aspects so I don't make a huge mess and end up with unexpected outcomes.

@angel-ayala (Author)

I wrote an example script that seems to work, but now I am wondering about the policy_noise and noise_clip attributes of TD3 for my problem.

My environment has a continuous action space with four elements: three linear velocities and one angular velocity, so the action values scale differently. Should I handle each noise value differently, e.g., pass policy_noise as a list? Would that make any difference?
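
(As a side note, in SB3 the target_policy_noise and target_noise_clip attributes are scalar floats applied in the normalised [-1, 1] action range, while per-dimension exploration noise can be passed via the action_noise argument. A minimal sketch assuming the 4-dimensional action space above; the sigma values are illustrative only:)

import numpy as np
from stable_baselines3.common.noise import NormalActionNoise

# Hypothetical per-dimension exploration noise: a smaller sigma for the
# angular-velocity component (values are illustrative, not tuned).
n_actions = 4
action_noise = NormalActionNoise(mean=np.zeros(n_actions),
                                 sigma=np.array([0.1, 0.1, 0.1, 0.05]))

# model = SRLTD3("MlpPolicy", env, action_noise=action_noise)  # SRLTD3 is defined below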

Here is my code; if you have any suggestions about anything I am probably missing, I would really appreciate them.

from typing import Optional

import copy
import numpy as np
import torch as th
from torch.nn import functional as F

from stable_baselines3.common.policies import ContinuousCritic
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import PyTorchObs
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

from stable_baselines3.td3.policies import Actor
from stable_baselines3.td3.policies import TD3Policy

from sb3_srl.td3 import TD3
from sb3_srl.autoencoders import instance_autoencoder


class SRLTD3Policy(TD3Policy):
    def __init__(self, *args,
                 ae_type: str = 'Vector', ae_params: Optional[dict] = None,
                 **kwargs):
        # Avoid a mutable default argument for the autoencoder parameters
        ae_params = {} if ae_params is None else ae_params
        self.features_dim = ae_params['latent_dim']
        super().__init__(*args, **kwargs)

        self.make_autoencoder(ae_type, ae_params)
        self.encoder_target = copy.deepcopy(self.ae_model.encoder)
        self.encoder_target.train(False)

    def make_autoencoder(self, ae_type, ae_params):
        self.ae_model = instance_autoencoder(ae_type, ae_params)
        self.ae_model.adam_optimizer(ae_params['encoder_lr'],
                                     ae_params['decoder_lr'])

    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        actor_kwargs["features_dim"] = self.features_dim
        return Actor(**actor_kwargs).to(self.device)

    def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
        critic_kwargs["features_dim"] = self.features_dim
        return ContinuousCritic(**critic_kwargs).to(self.device)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        # Note: the deterministic parameter is ignored in the case of TD3.
        #   Predictions are always deterministic.
        with th.no_grad():
            obs_z = self.ae_model.encoder(observation)
        return self.actor(obs_z)

    def set_training_mode(self, mode: bool) -> None:
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.ae_model.set_training_mode(mode)
        self.training = mode


class SRLTD3(TD3):
    def __init__(self, *args, **kwargs):
        super(SRLTD3, self).__init__(*args, **kwargs)

    def _create_aliases(self) -> None:
        super()._create_aliases()
        self.encoder = self.policy.ae_model.encoder
        self.decoder = self.policy.ae_model.decoder
        self.encoder_target = self.policy.encoder_target

    def _setup_model(self) -> None:
        super()._setup_model()
        self.policy.ae_model.to(self.device)
        # Running mean and running var
        self.encoder_batch_norm_stats = get_parameters_by_name(self.encoder, ["running_"])
        self.encoder_batch_norm_stats_target = get_parameters_by_name(self.encoder_target, ["running_"])

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        actor_losses, critic_losses = [], []
        ae_losses, l2_losses = [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
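                # NOTE: encoding happens under no_grad here, so the critic loss
                # below does not propagate gradients into the online encoder;
                # the encoder is only updated via the representation loss further down.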
                obs = self.encoder(replay_data.observations)
                next_obs = self.encoder_target(replay_data.next_observations)
                # Select action according to policy and add clipped noise
                # noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
                noise = th.randn_like(replay_data.actions) * self.target_policy_noise
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(next_obs) + noise).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(self.critic_target(next_obs, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            current_q_values = self.critic(obs, replay_data.actions)

            # Compute critic loss
            critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, th.Tensor)
            critic_losses.append(critic_loss.item())

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute reconstruction loss
            rep_loss, [rec_loss, latent_loss] = self.policy.ae_model.compute_representation_loss(
                replay_data.observations, replay_data.actions, replay_data.next_observations)
            ae_losses.append(rec_loss.item())
            l2_losses.append(latent_loss.item())
            self.policy.ae_model.update_representation(rep_loss)

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
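                # Detach the latent features so the actor loss does not update
                # the online encoder (the "detached encoder" concern from above)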
                obs = self.encoder(replay_data.observations).detach()
                # Compute actor loss
                actor_loss = -self.critic.q1_forward(obs, self.actor(obs)).mean()
                actor_losses.append(actor_loss.item())

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
                polyak_update(self.encoder.parameters(), self.encoder_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)
                polyak_update(self.encoder_batch_norm_stats, self.encoder_batch_norm_stats_target, 1.0)

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if len(actor_losses) > 0:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        self.logger.record("train/ae_loss", np.mean(ae_losses))
        self.logger.record("train/l2_loss", np.mean(l2_losses))

@araffin (Member) commented Feb 12, 2025

Hello,
for integrating SRL with SB3, you can have a look at https://github.com/araffin/aae-train-donkeycar/blob/master/ae/wrapper.py (a VecEnv wrapper would be better when using multiple envs).
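
A minimal sketch of such a VecEnv wrapper, assuming a frozen, pretrained encoder with a flat latent space (the wrapper name and encoder interface are assumptions, not taken from the linked repo):

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.vec_env import VecEnvWrapper


class EncodedObsVecWrapper(VecEnvWrapper):  # hypothetical name
    """Replace raw observations with latent features from a frozen encoder."""

    def __init__(self, venv, encoder, latent_dim, device="cpu"):
        latent_space = spaces.Box(low=-np.inf, high=np.inf,
                                  shape=(latent_dim,), dtype=np.float32)
        super().__init__(venv, observation_space=latent_space)
        self.device = device
        self.encoder = encoder.to(device).eval()

    def _encode(self, obs):
        with th.no_grad():
            z = self.encoder(th.as_tensor(obs, dtype=th.float32, device=self.device))
        return z.cpu().numpy()

    def reset(self):
        return self._encode(self.venv.reset())

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return self._encode(obs), rewards, dones, infos
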
For the noise, please have a look at the RL tips in the docs (and the videos that are linked there); you should normalise your action space so you do not have to worry about scale later.
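
For example, with Gymnasium's built-in RescaleAction wrapper (a sketch assuming a single gymnasium env; the env id is a placeholder):

import gymnasium as gym
from gymnasium.wrappers import RescaleAction

env = gym.make("YourDroneEnv-v0")  # placeholder env id
# Rescale all four action dimensions to [-1, 1] so a single scalar
# policy_noise / noise_clip value is meaningful for every dimension.
env = RescaleAction(env, min_action=-1.0, max_action=1.0)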

@angel-ayala (Author)

I just reviewed the link; it seems to pretrain the model first and use it later, which is the simplest implementation for SRL. I was looking to do a joint optimization of both the policy and the representation model.
I just found a method called self-predictive representation that works that way (published here: https://github.com/twni2016/self-predictive-rl/tree/main), and I would like to run some experiments with other algorithms.
Does the code in the second comment seem OK to you for this purpose? I know it is a lot of copy and paste from the original code, but it was the most straightforward approach for me.
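
For reference, the core of a latent self-prediction objective in that spirit boils down to predicting the target-encoded next latent from the current latent and action; a minimal sketch (the dynamics head and names are assumptions, not the API of the linked repo):

import torch as th
from torch.nn import functional as F


def self_predictive_loss(encoder, encoder_target, dynamics, obs, actions, next_obs):
    """Predict the next latent state and regress it onto the target encoder output."""
    z = encoder(obs)
    z_pred = dynamics(th.cat([z, actions], dim=1))  # assumed latent-dynamics head
    with th.no_grad():  # stop gradient through the target branch
        z_next = encoder_target(next_obs)
    return F.mse_loss(z_pred, z_next)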
