stable-baseline3强化学习回测投资组合管理,年化8.5%(附代码下载)

很多开源代码里的示例代码本身也不复杂,但由于它们把回测系统给耦合在里边,所以看起来特别乱。

今天我们来看看深度强化学习框架。

原本打算使用清华本科生那个“天授”,原因是qlib里使用的框架竟然是它。但看了一下官方文档,竟然没有tutorial。然后知乎上读了一些文章,弃用了。试下elegantRL吧。实在不行,兜底stable-baseline3是肯定没有问题的。

一个框架好不好用,第一观感就是它的quick start和tutorial。

pip install stable-baselines3[extra]

从官网上看,还是stable-baselines3成熟,安装也简单。

stable-baselines3,要求action_space这个可以理解,因为动作空间长度是确定的;但要求observation_space这个比较奇怪,每次观察的状态的shape必须一致?在投资里做不到啊,比如A股/美股的投资日历不同,那么有些天就不是所以的标的都可以“观察”到数据,在机器学习里,我们的输入仅要求“列”相同即可,输入的行数是可变的。

from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
check_env(e) model = A2C("CnnPolicy", e).learn(total_timesteps=1000)

强化学习包不是太少,而是太多,乱花渐欲迷人眼。另外像tensortrade, finrl这样,就是把强化学习应用于金融,已经开发出了自己一个完整的框架,还带来数据的,也不在少数。

from abc import abstractmethod
import gym
import numpy as np
from gym import spaces
from numpy import infty

from engine.datafeed.csv_dataloader import CSVDataloader
from engine.env import Env


class EnvRL(gym.Env):
    def __init__(self, env: Env, feature_cols: list):
        super().__init__()
        self.env = env
        self.feature_cols = feature_cols

        self.symbols = env.df_data['symbol'].unique()
        self.symbols.sort()
        #print(self.symbols)
        len_symbols = len(self.symbols)
        self.action_space = spaces.Box(low=0, high=1, shape=(len_symbols,))  # 不分配权重就是0, 全仓就是1
        self.observation_space = spaces.Box(
            low=-np.infty, high=np.infty, shape=(len_symbols, len(feature_cols)), dtype=np.float64)

    def _softmax_normalization(self, actions):
        numerator = np.exp(actions)
        denominator = np.sum(np.exp(actions))
        softmax_output = numerator / denominator
        return softmax_output

    def step(self, action):
        bar_df = self.env.step()
        obs = bar_df[self.feature_cols]

        actions = self._softmax_normalization(action) * 0.99
        weights = {s: a for s, a in zip(self.symbols, actions)}


        self.env.portfolio.order_target_weights(weights)

        #reward = self.env.portfolio.get_total_mv() - self.env.portfolio.init_cash
        reward = self.env.calc_sharpe()
        is_done = self.env.curr_step == len(self.env.dates) - 1

        if is_done:
            import matplotlib.pyplot as plt
            self.env.show_results()
            #print('reward',reward)
        return obs.values, reward, is_done, {}

    def reset(self):
        bar_df = self.env.reset()
        obs = bar_df[self.feature_cols]
        return obs.values


if __name__ == '__main__':
    from engine.env import Env

    symbols = ['000300.SH', '399006.SZ']
    loader = CSVDataloader(symbols, start_date="20120101")
    df = loader.load(fields=['close/shift(close,20)-1'], names=['roc_20'])

    from engine.env import Env
    from engine.algo.algos import *
    from engine.algo.algo_weights import *

    trading_env = Env(df)
    e = EnvRL(trading_env,['roc_20'])
    o = e.reset()
    done = False

    from stable_baselines3 import A2C
    from stable_baselines3.common.env_checker import check_env

    #check_env(e)

    model = A2C("MlpPolicy", e).learn(total_timesteps=10000)
    #print(e.env.portfolio.get_total_mv())

    print('开始预测')

    vec_env = model.get_env()
    obs = vec_env.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)
        if done:
            print('预测完成。')
        #vec_env.render("human")
        # VecEnv resets automatically


    #e.env.show_results()

    '''
    
    while not done:
        action = e.action_space.sample()
        obs, reward, done, info = e.step(action)
        print(obs)
    '''

图片

可前往星球下载全部代码,数据。

发布者:股市刺客,转载请注明出处:https://www.95sca.cn/archives/104108
站内所有文章皆来自网络转载或读者投稿,请勿用于商业用途。如有侵权、不妥之处,请联系站长并出示版权证明以便删除。敬请谅解!

(0)
股市刺客的头像股市刺客
上一篇 2024 年 7 月 29 日
下一篇 2024 年 7 月 29 日

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注