How to inherit the RecurrentActorCriticPolicy custom model #308

@fishingcatgo

Description

🐛 Bug

code:

import torch as th
import torch.nn as nn
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.distributions import DiagGaussianDistribution
import backtrader as bt
import pandas as pd
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym  # gymnasium, used for the custom environment
from gymnasium import spaces  # spaces, used to define observation/action spaces

from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy

"""
环境维度格式的规范: 根据obs结构适配lstm
obs的格式维度:window x Features (window为时间步长,Features为特征数,
例如:取时间窗口为5的,ohlcv(5)、macd、boll特征的数据,即5x7),
形状 = 时间步长维度(个数)x 特征维度(个数)
返回的数据类型为:np.array,将列表转为np.array
模型输入维度:obs=obs.unsqueeze(0),加入 batch_size,
# obs的形状为([1, 4]),需要扩展成(batch_size, seq_len, hidden_size)形状,添加sequence_length纬度
seq_len x hidden_size = window x Features(seq_len为window,hidden_size为Features)
"""

=== Custom LSTM feature extractor ===

class LSTMFeatureExtractor(BaseFeaturesExtractor):
    # def __init__(self, observation_space, features_dim=128):
    #     super().__init__(observation_space, features_dim)

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # input_dim = observation_space.shape[1]

        shape = observation_space.shape
        if len(shape) == 1:
            # 1D, e.g. shape=(101,)
            print('obs space is 1D:', shape)
            input_dim = observation_space.shape[0]
        elif len(shape) == 2:
            # 2D, e.g. shape=(1, 101)
            print('obs space is 2D:', shape)
            input_dim = observation_space.shape[1]
        else:
            raise ValueError(f"observation space shape not supported by the feature extractor: {shape}")

        print('observation_space:', type(observation_space), observation_space.shape)
        print('feature input_dim:', input_dim)

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=features_dim,
            num_layers=1,
            batch_first=True,
            # bidirectional=True
        )

    def forward(self, obs):
        # print('input obs shape:', obs.shape)
        # print('input obs type:', type(obs), obs)
        # obs = obs.unsqueeze(1)  # an obs of shape ([1, 4]) must be expanded to (batch_size, seq_len, hidden_size) by adding a sequence_length dimension

        # Dimension adaptation
        # obs = np.array(obs.shape)
        # print('model obs ndim:', obs.ndim)
        if obs.ndim == 2:
            # 2D batch, e.g. shape=(batch_size, 101): add the missing sequence dimension
            print('obs batch is 2D:', obs.shape)
            obs = obs.unsqueeze(1)  # expand to (batch_size, seq_len, hidden_size)
        elif obs.ndim == 3:
            # 3D batch, e.g. shape=(batch_size, 1, 101): already has a sequence dimension, leave as is
            print('obs batch is 3D:', obs.shape)
            pass
        else:
            raise ValueError(f"obs shape not supported by the custom model: {obs.shape}")

        print('adapted obs:', obs.ndim, obs.shape)

        # features, _ = self.lstm(obs)
        features, (h_n, c_n) = self.lstm(obs)
        # print('obs shape:', obs.shape)
        # print('features shapes:', features.shape, h_n.shape, c_n.shape)

        return features[:, -1, :]
"""
首先,features, _ = self.lstm(obs) 这一行将输入的观测数据 obs 传入 LSTM 层。
LSTM 层通常会返回两个值:所有时间步的输出(features),以及最后一个时间步的隐藏状态(这里用 _ 忽略了)。
features 的形状通常是 (batch_size, seq_len, hidden_size),即每个样本、每个时间步、每个隐藏单元的输出。

接下来,return features[:, -1, :] 表示返回每个样本在序列最后一个时间步的 LSTM 输出。
这里 : 表示所有样本,-1 表示序列的最后一个时间步,: 表示所有隐藏单元。
这样做的目的是只保留每个序列最后时刻的特征,通常用于后续的决策或分类任务。

features[:, -1, :] 取出每个样本最后一个时间步的 LSTM 输出,形状为 (batch_size, hidden_size)。
output[:, -1, :] 直接取出 最后时间步 的隐藏状态。
对于单向 LSTM,这与 h_n[-1] 等价;双向或多层时需注意维度与拼接顺序。
"""

policy_kwargs = dict(
    features_extractor_class=LSTMFeatureExtractor,
    # features_extractor_kwargs=dict(features_dim=128),
    features_extractor_kwargs=dict(features_dim=1024),
)

env = gym.make("CartPole-v1", render_mode="rgb_array")

# model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)

# model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1, batch_size=8)

# model.learn(10000,progress_bar=True)

# model.learn(10000)

model.learn(1)

# quit()
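
# Quick standalone check of the extractor on CartPole observations (illustrative only; reuses env defined above).
# CartPole-v1 observations have shape (4,), so a batch of shape (1, 4) is unsqueezed to (1, 1, 4)
# inside forward() and the extractor returns (1, features_dim).
check_extractor = LSTMFeatureExtractor(env.observation_space, features_dim=1024)
check_obs = th.as_tensor(env.observation_space.sample(), dtype=th.float32).unsqueeze(0)  # (1, 4)
print(check_extractor(check_obs).shape)  # expected: torch.Size([1, 1024])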

=== Custom policy network (Actor + Critic) ===

class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        # last_layer_dim_pi: int = 64,
        last_layer_dim_pi: int = 128,
        last_layer_dim_vf: int = 128,
        # last_layer_dim_pi: int = 1024,
        # last_layer_dim_vf: int = 1024,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        # self.policy_net = nn.Sequential(
        #     nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        # )

        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, last_layer_dim_pi),
            nn.ReLU()
        )

        # LSTM variant of the policy network
        # self.policy_net = nn.LSTM(
        #     input_size=feature_dim,
        #     hidden_size=last_layer_dim_pi,
        #     num_layers=1,
        #     batch_first=True
        # )

        # Value network
        # self.value_net = nn.Sequential(
        #     nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        # )

        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, last_layer_dim_vf),
            nn.ReLU()
        )

        # LSTM variant of the value network
        # self.value_net = nn.LSTM(
        #     input_size=feature_dim,
        #     hidden_size=last_layer_dim_vf,
        #     num_layers=1,
        #     batch_first=True
        # )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:  # MLP feed-forward variant
        return self.policy_net(features)
        # print('input features shape (policy_net):', features.shape)
        # features, (h_n, c_n) = self.policy_net(features.unsqueeze(1))  # add a sequence_length dimension
        # print('LSTM policy_net output shapes:', features.shape, h_n.shape, c_n.shape)
        # print('LSTM policy_net return shape:', features[:, -1, :].shape)

        # LSTM variant
        # features, (h_n, c_n) = self.policy_net(features.unsqueeze(1))  # add a sequence_length dimension
        # return features[:, -1, :]

    def forward_critic(self, features: th.Tensor) -> th.Tensor:  # MLP feed-forward variant
        return self.value_net(features)
        # print('input features shape (value_net):', features.shape)
        # output = self.value_net(features)
        # print('value_net output shape:', output.shape)
        # return output

        # LSTM variant
        # features, (h_n, c_n) = self.value_net(features.unsqueeze(1))  # add a sequence_length dimension
        # return features[:, -1, :]
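
# Quick standalone shape check for CustomNetwork (illustrative only): with feature_dim=512 and the
# default last_layer_dim_pi/last_layer_dim_vf of 128, both latents come out as (batch_size, 128).
check_net = CustomNetwork(feature_dim=512)
check_latent_pi, check_latent_vf = check_net(th.randn(2, 512))
print(check_latent_pi.shape, check_latent_vf.shape)  # torch.Size([2, 128]) torch.Size([2, 128])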

# class CustomActorCriticPolicy(ActorCriticPolicy):  # plain ActorCriticPolicy variant (also tried, see traceback below)
class CustomActorCriticPolicy(RecurrentActorCriticPolicy):  # RecurrentActorCriticPolicy variant

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)
from stable_baselines3.common.policies import ActorCriticPolicy
from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_recurrent import MlpLstmPolicy
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
if __name__ == "__main__":
    policy_kwargs = dict(
        features_extractor_class=LSTMFeatureExtractor,
        # features_extractor_kwargs=dict(features_dim=128),
        features_extractor_kwargs=dict(features_dim=512),
    )

    # === 1. Test model construction and training ===
    # model = PPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
    #             verbose=1, batch_size=2, n_steps=2)

    model = RecurrentPPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
                         verbose=1, batch_size=2, n_steps=2)

    # model = RecurrentPPO(RecurrentActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
    #                      verbose=1, batch_size=2, n_steps=2)

    model.learn(10)

    print("*" * 50)
    print(model.policy)

class CustomActorCriticPolicy(RecurrentActorCriticPolicy):  # when using this base class:
Traceback (most recent call last):
File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 291, in
model.learn(10)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 450, in learn
return super().learn(
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 324, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 242, in collect_rollouts
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/common/recurrent/policies.py", line 249, in forward
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 217, in forward_actor
return self.policy_net(features)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/container.py", line 244, in forward
input = module(input)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x256 and 512x512)

class CustomActorCriticPolicy(ActorCriticPolicy):  # when using this base class:
Traceback (most recent call last):
File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 285, in
model = RecurrentPPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 137, in __init__
self._setup_model()
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 156, in _setup_model
lstm = self.policy.lstm_actor
File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1962, in __getattr__
raise AttributeError(
AttributeError: 'CustomActorCriticPolicy' object has no attribute 'lstm_actor'
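
The two tracebacks point at two separate constraints: RecurrentPPO expects the policy to expose lstm_actor, which only RecurrentActorCriticPolicy builds (hence the AttributeError with the ActorCriticPolicy base), and the recurrent policy feeds the mlp_extractor with the output of its internal LSTM (hidden size 256 by default) rather than with the feature extractor's features_dim of 512 (hence the 1x256 vs 512x512 shape error). A minimal sketch of an inheritance that keeps the dimensions consistent, assuming the installed sb3_contrib sets self.lstm_output_dim from lstm_hidden_size (an assumption based on recent sources, not verified here):

# Sketch only: reuses CustomNetwork, policy_kwargs and the imports from the script above.
class CustomRecurrentPolicy(RecurrentActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, *args, **kwargs):
        # Disable orthogonal initialization, as in the original policy
        kwargs["ortho_init"] = False
        super().__init__(observation_space, action_space, lr_schedule, *args, **kwargs)

    def _build_mlp_extractor(self) -> None:
        # Size the custom extractor with the LSTM output dimension, not self.features_dim,
        # because RecurrentActorCriticPolicy passes latent_pi/latent_vf from its LSTM into it.
        self.mlp_extractor = CustomNetwork(self.lstm_output_dim)

model = RecurrentPPO(CustomRecurrentPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
                     verbose=1, batch_size=2, n_steps=2)
model.learn(10)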

To Reproduce

from stable_baselines3 import ...
from sb3_contrib import ...

Relevant log output / Error message

System Info

How to inherit the RecurrentActorCriticPolicy custom model

Checklist
