Description
🐛 Bug
code:
import torch as th
import torch.nn as nn
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.distributions import DiagGaussianDistribution
import backtrader as bt
import pandas as pd
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gymnasium as gym  # gymnasium is used for the custom environment
from gymnasium import spaces  # spaces is used to define observation/action spaces
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
"""
环境维度格式的规范: 根据obs结构适配lstm
obs的格式维度:window x Features (window为时间步长,Features为特征数,
例如:取时间窗口为5的,ohlcv(5)、macd、boll特征的数据,即5x7),
形状 = 时间步长维度(个数)x 特征维度(个数)
返回的数据类型为:np.array,将列表转为np.array
模型输入维度:obs=obs.unsqueeze(0),加入 batch_size,
# obs的形状为([1, 4]),需要扩展成(batch_size, seq_len, hidden_size)形状,添加sequence_length纬度,
seq_len x hidden_size = window x Features(seq_len为window,hidden_size为Features)
"""
# === Custom LSTM feature extractor ===
class LSTMFeatureExtractor(BaseFeaturesExtractor):
    # def __init__(self, observation_space, features_dim=128):
    #     super().__init__(observation_space, features_dim)
    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # input_dim = observation_space.shape[1]
        shape = observation_space.shape
        if len(shape) == 1:
            # 1D, e.g. shape=(101,)
            print('obs space is 1D:', shape)
            input_dim = observation_space.shape[0]
        elif len(shape) == 2:
            # 2D, e.g. shape=(1, 101)
            print('obs space is 2D:', shape)
            input_dim = observation_space.shape[1]
        else:
            raise ValueError(f"obs shape not supported by the feature extractor: {shape}")
        print('observation_space:', type(observation_space), observation_space.shape)
        print('feature input_dim:', input_dim)
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=features_dim,
            num_layers=1,
            batch_first=True,
            # bidirectional=True
        )
    def forward(self, obs):
        # print('input obs shape:', obs.shape)
        # print('input obs type:', type(obs), obs)
        # obs = obs.unsqueeze(1)  # an obs of shape ([1, 4]) must be expanded to (batch_size, seq_len, hidden_size) by adding a sequence_length dimension
        # Shape adaptation
        # print('model obs ndim:', obs.ndim)
        if obs.ndim == 2:
            # 2D batch, e.g. shape=(batch, 101)
            print('obs batch is 2D:', obs.shape)
            obs = obs.unsqueeze(1)  # expand to (batch_size, seq_len, hidden_size) by adding a sequence_length dimension
        elif obs.ndim == 3:
            # already 3D, nothing to do
            print('obs batch is 3D:', obs.shape)
            pass
        else:
            raise ValueError(f"obs shape not supported by the custom model: {obs.shape}")
        print('adapted obs shape:', obs.ndim, obs.shape)
        # features, _ = self.lstm(obs)
        features, (h_n, c_n) = self.lstm(obs)
        # print('obs shape:', obs.shape)
        # print("features shape:", features.shape, h_n.shape, c_n.shape)
        return features[:, -1, :]
"""
首先,features, _ = self.lstm(obs) 这一行将输入的观测数据 obs 传入 LSTM 层。
LSTM 层通常会返回两个值:所有时间步的输出(features),以及最后一个时间步的隐藏状态(这里用 _ 忽略了)。
features 的形状通常是 (batch_size, seq_len, hidden_size),即每个样本、每个时间步、每个隐藏单元的输出。
接下来,return features[:, -1, :] 表示返回每个样本在序列最后一个时间步的 LSTM 输出。
这里 : 表示所有样本,-1 表示序列的最后一个时间步,: 表示所有隐藏单元。
这样做的目的是只保留每个序列最后时刻的特征,通常用于后续的决策或分类任务。
features[:, -1, :] 取出每个样本最后一个时间步的 LSTM 输出,形状为 (batch_size, hidden_size)。
output[:, -1, :] 直接取出 最后时间步 的隐藏状态。
对于单向 LSTM,这与 h_n[-1] 等价;双向或多层时需注意维度与拼接顺序。
"""
policy_kwargs = dict(
features_extractor_class=LSTMFeatureExtractor,
# features_extractor_kwargs=dict(features_dim=128),
features_extractor_kwargs=dict(features_dim=1024),
)
env = gym.make("CartPole-v1", render_mode="rgb_array")
# model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1,batch_size=8)
# model.learn(10000,progress_bar=True)
# model.learn(10000)
model.learn(1)
# quit()
# === Custom policy network (Actor + Critic) ===
class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        # last_layer_dim_pi: int = 64,
        last_layer_dim_pi: int = 128,
        last_layer_dim_vf: int = 128,
        # last_layer_dim_pi: int = 1024,
        # last_layer_dim_vf: int = 1024,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        # self.policy_net = nn.Sequential(
        #     nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        # )
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, last_layer_dim_pi),
            nn.ReLU()
        )
        # LSTM variant (unused)
        # self.policy_net = nn.LSTM(
        #     input_size=feature_dim,
        #     hidden_size=last_layer_dim_pi,
        #     num_layers=1,
        #     batch_first=True
        # )

        # Value network
        # self.value_net = nn.Sequential(
        #     nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        # )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, last_layer_dim_vf),  # was last_layer_dim_pi; the value head should use last_layer_dim_vf
            nn.ReLU()
        )
        # LSTM variant (unused)
        # self.value_net = nn.LSTM(
        #     input_size=feature_dim,
        #     hidden_size=last_layer_dim_vf,
        #     num_layers=1,
        #     batch_first=True
        # )
    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:  # MLP feed-forward head
        return self.policy_net(features)
        # LSTM variant (unused):
        # print('policy_net input features shape:', features.shape)
        # features, (h_n, c_n) = self.policy_net(features.unsqueeze(1))  # add a sequence_length dimension
        # print('policy_net LSTM output shape:', features.shape, h_n.shape, c_n.shape)
        # return features[:, -1, :]

    def forward_critic(self, features: th.Tensor) -> th.Tensor:  # MLP feed-forward head
        return self.value_net(features)
        # LSTM variant (unused):
        # print('value_net input features shape:', features.shape)
        # features, (h_n, c_n) = self.value_net(features.unsqueeze(1))  # add a sequence_length dimension
        # return features[:, -1, :]
# class CustomActorCriticPolicy(ActorCriticPolicy):  # alternative base class, see the second traceback below
class CustomActorCriticPolicy(RecurrentActorCriticPolicy):  # RecurrentActorCriticPolicy
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)
from stable_baselines3.common.policies import ActorCriticPolicy
from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_recurrent import MlpLstmPolicy
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
if __name__ == "__main__":
    policy_kwargs = dict(
        features_extractor_class=LSTMFeatureExtractor,
        # features_extractor_kwargs=dict(features_dim=128),
        features_extractor_kwargs=dict(features_dim=512),
    )

    # === 1. Test model construction and training ===
    # model = PPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
    #             verbose=1, batch_size=2, n_steps=2)
    model = RecurrentPPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
                         verbose=1, batch_size=2, n_steps=2)
    # model = RecurrentPPO(RecurrentActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
    #                      verbose=1, batch_size=2, n_steps=2)
    model.learn(10)
    print("*" * 50)
    print(model.policy)
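As a reference for the shape handling described in the docstrings above, here is a small standalone sketch (illustrative sizes only, not part of the bug reproduction) of a batch of windowed observations going through an LSTM while keeping only the last time step:

import torch as th
import torch.nn as nn

batch_size, window, n_features, hidden = 8, 5, 7, 256        # illustrative sizes
lstm = nn.LSTM(input_size=n_features, hidden_size=hidden, batch_first=True)
obs = th.zeros(batch_size, window, n_features)               # (batch_size, seq_len, input_size)
out, (h_n, c_n) = lstm(obs)                                   # out: (batch_size, seq_len, hidden_size)
last = out[:, -1, :]                                          # (batch_size, hidden_size), last time step
assert th.allclose(last, h_n[-1])                             # equal for a single-layer, unidirectional LSTM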
With class CustomActorCriticPolicy(RecurrentActorCriticPolicy): (when using this base class):

Traceback (most recent call last):
  File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 291, in <module>
    model.learn(10)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 450, in learn
    return super().learn(
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 324, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 242, in collect_rollouts
    actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/common/recurrent/policies.py", line 249, in forward
    latent_pi = self.mlp_extractor.forward_actor(latent_pi)
  File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 217, in forward_actor
    return self.policy_net(features)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/container.py", line 244, in forward
    input = module(input)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x256 and 512x512)

With class CustomActorCriticPolicy(ActorCriticPolicy): (when using this base class):

Traceback (most recent call last):
  File "/Users/umr/MyData/vspro/SB3_12tunning4/多股/police_model/lstm_mlp_rec_test.py", line 285, in <module>
    model = RecurrentPPO(CustomActorCriticPolicy, "CartPole-v1", policy_kwargs=policy_kwargs,
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 137, in __init__
    self._setup_model()
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 156, in _setup_model
    lstm = self.policy.lstm_actor
  File "/Users/umr/MyInstall/miniconda3/envs/pytorch10_1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1962, in __getattr__
    raise AttributeError(
AttributeError: 'CustomActorCriticPolicy' object has no attribute 'lstm_actor'
To Reproduce
from stable_baselines3 import ...
from sb3_contrib import ...

Relevant log output / Error message
System Info
How should a custom model inherit from RecurrentActorCriticPolicy?
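For context on the two tracebacks: RecurrentPPO reads self.policy.lstm_actor in _setup_model, so a custom policy passed to RecurrentPPO has to derive from RecurrentActorCriticPolicy (a plain ActorCriticPolicy subclass has no lstm_actor, hence the AttributeError). When it does derive from RecurrentActorCriticPolicy, forward_actor/forward_critic receive the output of the policy's own LSTM (lstm_hidden_size, 256 by default), not self.features_dim, which is why the 512-wide CustomNetwork fails with the 1x256 vs 512x512 shape error. A minimal sketch, assuming the lstm_output_dim attribute exposed by RecurrentActorCriticPolicy and reusing the LSTMFeatureExtractor and CustomNetwork classes from the script above (CustomRecurrentPolicy is a hypothetical name):

from sb3_contrib import RecurrentPPO
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy

class CustomRecurrentPolicy(RecurrentActorCriticPolicy):
    def _build_mlp_extractor(self) -> None:
        # Size the custom heads from the LSTM output, not from features_dim,
        # because the recurrent policy applies its LSTM before the mlp_extractor.
        self.mlp_extractor = CustomNetwork(self.lstm_output_dim)

model = RecurrentPPO(
    CustomRecurrentPolicy,
    "CartPole-v1",
    policy_kwargs=dict(
        features_extractor_class=LSTMFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=512),
    ),
    verbose=1,
    batch_size=2,
    n_steps=2,
)
model.learn(10)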
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I've used the markdown code blocks for both code and stack traces.