Turning backtrader into a backtesting engine for financial reinforcement learning

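In our AI quant platform, we broke traditional rule-based quant strategies down into "building blocks". The benefit of this decomposition is maximal reuse of code logic, so strategies are quick to develop and less error-prone.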

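We intend to do the same for the reinforcement learning environment. Some platforms we have seen build one environment for stocks, another for cryptocurrencies, and yet another for futures, and some even couple data-source handling into the environment. That is the wrong design: such environments are tedious to maintain and error-prone.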


01 Components of a reinforcement learning environment

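A complete financial reinforcement learning environment consists of four parts: data (usually OHLC bars plus feature data), trading actions (go long, close the position, or target weights), a backtesting system (which gives feedback on the trading actions), and reward metrics (return, Sharpe ratio, and so on).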
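As an illustration of the reward component, here is a minimal sketch assuming the backtesting system reports the portfolio value after every step. The names step_return and sharpe_ratio, and the 252-periods-per-year annualization, are our assumptions rather than the platform's actual reward code.

```python
import math
from typing import Sequence


def step_return(prev_value: float, value: float) -> float:
    """Simple return between two consecutive portfolio values (per-step reward)."""
    return value / prev_value - 1.0


def sharpe_ratio(returns: Sequence[float], periods_per_year: int = 252,
                 risk_free: float = 0.0) -> float:
    """Annualized Sharpe ratio of per-period returns (sample standard deviation).

    risk_free is a per-period rate. Returns 0.0 when the ratio is undefined
    (fewer than two returns, or zero volatility).
    """
    n = len(returns)
    if n < 2:
        return 0.0
    excess = [r - risk_free for r in returns]
    mean = sum(excess) / n
    std = math.sqrt(sum((x - mean) ** 2 for x in excess) / (n - 1))
    if std == 0.0:
        return 0.0
    return mean / std * math.sqrt(periods_per_year)


# Per-step rewards from a short equity curve, plus an episode-level Sharpe reward
equity = [100.0, 101.0, 100.5, 102.0]
step_rewards = [step_return(a, b) for a, b in zip(equity, equity[1:])]
episode_reward = sharpe_ratio(step_rewards)
```

A per-step return gives the agent dense feedback, while a Sharpe-style metric is usually computed over a window or a whole episode.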

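Many of the open-source packages that implement such an environment today have problems. They devote most of their code to data and feature processing, implement only a simple backtest, and offer a narrow choice of reward metrics. These environments are hard to extend, let alone connect to live trading.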

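In earlier articles we built a unified framework for traditional quant and machine-learning quant, covering data management and automatic labeling, and connected it to backtrader's backtesting system.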

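Likewise, we want to reuse as much of that work as possible and integrate the reinforcement learning environment into the same framework, so that the environment no longer has to deal with data processing, the backtesting system, and similar stages.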

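However, backtrader was designed for traditional quant: its run method walks through the entire time series in a single call, whereas in reinforcement learning the environment calls the backtesting engine and moves it forward one step at a time.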
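The difference in control flow can be illustrated with a toy engine that has nothing to do with backtrader internals (ToyEngine and its methods are hypothetical, and it only models a single long/flat position): run owns the loop and consumes the whole series, while step lets the caller apply one action, advance one bar, and receive a reward.

```python
from typing import Callable, List, Optional


class ToyEngine:
    """Toy long/flat backtester over one price series (illustration only)."""

    def __init__(self, prices: List[float]):
        self.prices = prices
        self.i = 0           # index of the current bar
        self.position = 0    # 1 = long, 0 = flat
        self.equity = 1.0

    def _advance(self) -> bool:
        """Move to the next bar and mark to market; False when no bar is left."""
        if self.i + 1 >= len(self.prices):
            return False
        ret = self.prices[self.i + 1] / self.prices[self.i] - 1.0
        self.equity *= 1.0 + self.position * ret
        self.i += 1
        return True

    def run(self, strategy: Callable[[float], int]) -> float:
        """Traditional style: the engine owns the loop and consumes every bar."""
        while True:
            self.position = strategy(self.prices[self.i])
            if not self._advance():
                return self.equity

    def step(self, action: int) -> Optional[float]:
        """RL style: the caller owns the loop; one action in, one bar forward.

        Returns the bar's return on equity as the reward, or None when done.
        """
        self.position = action
        before = self.equity
        if not self._advance():
            return None
        return self.equity / before - 1.0


# Traditional: one call runs the whole series
final_equity = ToyEngine([10.0, 11.0, 12.1]).run(lambda price: 1)

# RL: the environment/agent calls step() bar by bar
engine = ToyEngine([10.0, 11.0, 12.1])
rewards = []
while (r := engine.step(1)) is not None:  # the "agent" always goes long
    rewards.append(r)
```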

So backtrader needs a few modifications.

1. Override _runnext

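When backtrader executes run, the _runnext function iterates over the entire sequence in one loop. We override this function so that it only performs the first, initialization step. Note that cerebro only takes the _runnext path when runonce is disabled (for example cerebro.run(runonce=False)); with the default runonce=True it uses the vectorized _runonce path instead.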

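# Module-level imports needed by the methods below
import datetime

from backtrader.utils import date2num


# The following are methods of a backtrader.Cerebro subclass
def _runnext(self, runstrats):
    '''
    Override of cerebro's full "next" mode. Instead of looping over every
    bar, only remember the strategies and prepare the synchronization
    state; bars are then advanced one at a time by _step.
    '''
    self.runstrats_container = runstrats
    self._init_run()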

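def _init_run(self):
    # Same setup as the top of cerebro's original _runnext: sort the feeds
    # by (timeframe, compression) so the finest-grained feed comes first
    datas = sorted(self.datas,
                   key=lambda x: (x._timeframe, x._compression))
    datas1 = datas[1:]
    data0 = datas[0]
    d0ret = True  # loop flag of the original _runnext (True/None = keep going)
    # indices of feeds that are resampled but not replayed
    rsonly = [i for i, x in enumerate(datas)
              if x.resampling and not x.replaying]
    onlyresample = len(datas) == len(rsonly)
    noresample = not rsonly

    # clones are excluded when _step counts feeds that have gone live
    clonecount = sum(d._clone for d in datas)
    ldatas = len(datas)
    ldatas_noclones = ldatas - clonecount
    dt0 = date2num(datetime.datetime.max) - 2  # default at max

    # Save the loop variables; each _step call receives them as keyword
    # arguments and writes the updated values back
    self.bt_state_container = {"datas": datas,
                               "datas1": datas1,
                               "data0": data0,
                               "d0ret": d0ret,
                               "rsonly": rsonly,
                               "onlyresample": onlyresample,
                               "noresample": noresample,
                               "ldatas_noclones": ldatas_noclones,
                               "dt0": dt0,
                               }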
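To make the intended use of the two containers concrete, here is a sketch with a stand-in engine rather than backtrader (MockCerebro, its simplified _step, and the step_once helper are hypothetical). It reproduces only the state-passing pattern: _runnext stores the strategies and the initial state without looping, and each step unpacks the saved dict into _step, which writes the updated values back. With the real subclass the equivalent call would be cerebro._step(cerebro.runstrats_container, **cerebro.bt_state_container). The extra check on d0ret mirrors the "while d0ret or d0ret is None" condition of backtrader's original _runnext loop.

```python
from typing import Any, Dict, List


class MockCerebro:
    """Hypothetical stand-in that reproduces only the state-passing pattern."""

    def __init__(self, bars: List[float]):
        self.bars = bars
        self.pos = -1
        self.seen: List[float] = []   # bars the "strategy" has received

    def _runnext(self, runstrats: List[Any]) -> None:
        # Same idea as the override: remember strategies, prepare state, no loop
        self.runstrats_container = runstrats
        self._init_run()

    def _init_run(self) -> None:
        self.bt_state_container: Dict[str, Any] = {"d0ret": True}

    def _step(self, runstrats: List[Any], d0ret: Any) -> bool:
        # Simplified: deliver one bar; return True when the run is finished
        self.pos += 1
        d0ret = self.pos < len(self.bars)
        if not d0ret:
            return True
        self.seen.append(self.bars[self.pos])
        self.bt_state_container = {"d0ret": d0ret}   # write state back
        return False

    def step_once(self) -> bool:
        """Advance one bar; return True when the episode is over."""
        state = self.bt_state_container
        if not (state["d0ret"] or state["d0ret"] is None):
            return True
        return self._step(self.runstrats_container, **state)


cerebro = MockCerebro([1.0, 2.0, 3.0])
cerebro._runnext(runstrats=[])
steps = 0
while not cerebro.step_once():
    steps += 1
```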

2. Provide a single-step function, _step

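Most of this function's code can be found in cerebro: it is the body of the while loop from the original _runnext, pulled out into its own function so that each call advances the backtest by one step.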

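# Module-level imports needed by _step
import datetime

from backtrader.utils import num2date


def _step(self, runstrats, datas, datas1, data0, d0ret, rsonly,
          onlyresample, noresample, ldatas_noclones, dt0):
    '''
    Run one iteration of the original _runnext loop (one bar).

    The arguments are the values saved in self.bt_state_container. Returns
    True when the run should end (data exhausted or stop requested) and
    False otherwise; updated state is written back to bt_state_container.
    The caller should also stop once the saved d0ret is False, as the
    original "while d0ret or d0ret is None" loop does.
    '''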

    # if any has live data in the buffer, no data will wait anything
    newqcheck = not any(d.haslivedata() for d in datas)
    if not newqcheck:
        # If no data has reached the live status or all, wait for
        # the next incoming data
        livecount = sum(d._laststatus == d.LIVE for d in datas)
        newqcheck = not livecount or livecount == ldatas_noclones

    lastret = False
    # Notify anything from the store even before moving datas
    # because datas may not move due to an error reported by the store
    self._storenotify()
    if self._event_stop:  # stop if requested
        return True
    self._datanotify()
    if self._event_stop:  # stop if requested
        return True

    # record starting time and tell feeds to discount the elapsed time
    # from the qcheck value
    drets = []
    qstart = datetime.datetime.utcnow()
    for d in datas:
        qlapse = datetime.datetime.utcnow() - qstart
        d.do_qcheck(newqcheck, qlapse.total_seconds())
        drets.append(d.next(ticks=False))

    d0ret = any((dret for dret in drets))
    if not d0ret and any((dret is None for dret in drets)):
        d0ret = None

    if d0ret:
        dts = []
        for i, ret in enumerate(drets):
            dts.append(datas[i].datetime[0] if ret else None)

        # Get index to minimum datetime
        if onlyresample or noresample:
            dt0 = min((d for d in dts if d is not None))
        else:
            dt0 = min((d for i, d in enumerate(dts)
                       if d is not None and i not in rsonly))

        dmaster = datas[dts.index(dt0)]  # and timemaster
        self._dtmaster = dmaster.num2date(dt0)
        self._udtmaster = num2date(dt0)

        # slen = len(runstrats[0])
        # Try to get something for those that didn't return
        for i, ret in enumerate(drets):
            if ret:  # dts already contains a valid datetime for this i
                continue

            # try to get a data by checking with a master
            d = datas[i]
            d._check(forcedata=dmaster)  # check to force output
            if d.next(datamaster=dmaster, ticks=False):  # retry
                dts[i] = d.datetime[0]  # good -> store
                # self._plotfillers2[i].append(slen)  # mark as fill
            else:
                # self._plotfillers[i].append(slen)  # mark as empty
                pass

        # make sure only those at dmaster level end up delivering
        for i, dti in enumerate(dts):
            if dti is not None:
                di = datas[i]
                rpi = False and di.replaying  # to check behavior
                if dti > dt0:
                    if not rpi:  # must see all ticks ...
                        di.rewind()  # cannot deliver yet
                    # self._plotfillers[i].append(slen)
                elif not di.replaying:
                    # Replay forces tick fill, else force here
                    di._tick_fill(force=True)

                # self._plotfillers2[i].append(slen)  # mark as fill

    elif d0ret is None:
        # meant for things like live feeds which may not produce a bar
        # at the moment but need the loop to run for notifications and
        # getting resample and others to produce timely bars
        for data in datas:
            data._check()
    else:
        lastret = data0._last()
        for data in datas1:
            lastret += data._last(datamaster=data0)

        if not lastret:
            # Only go extra round if something was changed by "lasts"
            return True  # signal the end of the data

    # Datas may have generated a new notification after next
    self._datanotify()
    if self._event_stop:  # stop if requested
        return True

    if d0ret or lastret:  # if any bar, check timers before broker
        self._check_timers(runstrats, dt0, cheat=True)
        if self.p.cheat_on_open:
            for strat in runstrats:
                strat._next_open()
                if self._event_stop:  # stop if requested
                    return True

    self._brokernotify()
    if self._event_stop:  # stop if requested
        return True

    if d0ret or lastret:  # bars produced by data or filters
        self._check_timers(runstrats, dt0, cheat=False)
        for strat in runstrats:
            strat._next()
            if self._event_stop:  # stop if requested
                return True

            self._next_writers(runstrats)

    self.bt_state_container = {"datas": datas,
                               "datas1": datas1,
                               "data0": data0,
                               "d0ret": d0ret,
                               "rsonly": rsonly,
                               "onlyresample": onlyresample,
                               "noresample": noresample,
                               "ldatas_noclones": ldatas_noclones,
                               "dt0": dt0,
                               }

    return False
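The core of each _step call is the master-clock alignment: every feed tries to deliver its next bar, the earliest timestamp dt0 among the feeds that delivered becomes the master time, and feeds that are already ahead of dt0 are rewound so that they wait. The toy version below works on plain lists of timestamps and sketches that idea only (align_step and the cursor representation are hypothetical); it leaves out resampling and replay handling, the retry against the master feed, and the end-of-data _last round.

```python
from typing import List, Optional, Sequence


def align_step(feeds: Sequence[Sequence[float]],
               cursors: List[int]) -> Optional[List[Optional[float]]]:
    """One synchronization step over several timestamp-sorted feeds.

    cursors[i] is the index of the next undelivered bar of feeds[i] and is
    updated in place. Returns the timestamp each feed delivered (None means
    the feed waits this step), or None when every feed is exhausted.
    """
    # Each feed's candidate next bar (None if the feed has no bars left)
    candidates = [f[c] if c < len(f) else None for f, c in zip(feeds, cursors)]
    valid = [t for t in candidates if t is not None]
    if not valid:
        return None
    dt0 = min(valid)  # master time for this step
    delivered: List[Optional[float]] = []
    for i, t in enumerate(candidates):
        if t is not None and t <= dt0:
            cursors[i] += 1          # bar is at the master time: consume it
            delivered.append(t)
        else:
            delivered.append(None)   # ahead of the master clock ("rewind"): wait
    return delivered


feeds = [[1.0, 2.0, 3.0],   # a feed with a bar on every date
         [2.0, 3.0]]        # a feed that starts one bar later
cursors = [0, 0]
timeline = []
while (delivered := align_step(feeds, cursors)) is not None:
    timeline.append(delivered)
```

On the first step only the first feed delivers; from then on both feeds advance together, which is the behavior the rewind logic in _step is there to guarantee.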

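With these changes, backtrader can advance one step at a time and supply the backtesting capability for our reinforcement learning environment.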
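On top of the stepped cerebro, the environment itself can stay thin. The sketch below assumes a Gymnasium-style interface (reset returns (obs, info), step returns (obs, reward, terminated, truncated, info)) without depending on gymnasium, and talks to the backtester through a hypothetical SteppedBacktester adapter. In a real adapter, advance would call cerebro._step(cerebro.runstrats_container, **cerebro.bt_state_container) and value could read cerebro.broker.getvalue(); how apply_action turns an action into orders is left open. FakeBacktester is a toy stand-in so the example runs without backtrader.

```python
from typing import Any, Dict, List, Protocol, Tuple


class SteppedBacktester(Protocol):
    """Hypothetical adapter interface around the stepped cerebro."""

    def start(self) -> None: ...
    def apply_action(self, action: Any) -> None: ...
    def advance(self) -> bool: ...          # True when the data is exhausted
    def value(self) -> float: ...
    def observation(self) -> List[float]: ...


class TradingEnv:
    """Gymnasium-style wrapper: the env drives the backtester one bar per step."""

    def __init__(self, backtester: SteppedBacktester):
        self.bt = backtester
        self._last_value = 0.0

    def reset(self) -> Tuple[List[float], Dict[str, Any]]:
        self.bt.start()
        self._last_value = self.bt.value()
        return self.bt.observation(), {}

    def step(self, action: Any) -> Tuple[List[float], float, bool, bool, Dict[str, Any]]:
        self.bt.apply_action(action)              # e.g. a target weight
        done = self.bt.advance()
        value = self.bt.value()
        reward = value / self._last_value - 1.0   # simple per-step return
        self._last_value = value
        return self.bt.observation(), reward, done, False, {"value": value}


class FakeBacktester:
    """Toy stand-in used only to exercise TradingEnv (no backtrader needed)."""

    def __init__(self, prices: List[float]):
        self.prices = prices

    def start(self) -> None:
        self.i, self.portfolio_value, self.weight = 0, 100.0, 0.0

    def apply_action(self, action: float) -> None:
        self.weight = float(action)

    def advance(self) -> bool:
        ret = self.prices[self.i + 1] / self.prices[self.i] - 1.0
        self.portfolio_value *= 1.0 + self.weight * ret
        self.i += 1
        return self.i == len(self.prices) - 1

    def value(self) -> float:
        return self.portfolio_value

    def observation(self) -> List[float]:
        return [self.prices[self.i]]


env = TradingEnv(FakeBacktester([10.0, 10.5, 10.0]))
obs, info = env.reset()
done = False
while not done:
    obs, reward, done, truncated, info = env.step(1.0)  # fully invested each bar
```

Because data handling and order execution stay inside the backtester, the same environment class could in principle serve stocks, crypto, or futures by swapping the adapter, which is the decoupling argued for at the start.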

Published by 股市刺客. Please credit the source when reposting: https://www.95sca.cn/archives/104202