去掉了原来依赖的底层回测引擎，改为完全自研，并新增超参数优化、因子自动挖掘、机器学习模型交易等功能。这些内容之前分享过，但过于分散，现在整合成一体。
重写后的代码更简洁:
import pandas as pd

from .task import Task
from .strategy import StrategyAlgo, ExecContext
from .portfolio import Portfolio, PortfolioBar


class Engine:
    """Backtest engine that replays bars date by date through a strategy.

    On construction it loads the task's data, creates a ``Portfolio`` seeded
    with ``task.init_cash``, builds a ``StrategyAlgo`` from
    ``task.get_algos()``, and prepares the shared ``ExecContext`` that the
    strategy receives on every bar.
    """

    def __init__(self, task: Task):
        self.task = task
        # NOTE(review): assumes the frame is indexed by date and has a
        # 'symbol' column plus the price columns Portfolio reads ('close').
        self.df_datas = self.task.load_datas()
        # Unique dates in ascending order; run() walks them in this order.
        self.dates = sorted(self.df_datas.index.unique())
        self.portfolio = Portfolio(init_cash=task.init_cash)
        self.strategy = StrategyAlgo(self.task.get_algos(), self.portfolio)
        self.symbols = list(self.df_datas['symbol'].unique())

        self.exec_context = ExecContext()
        self.exec_context.dates = self.dates
        self.exec_context.strategy = self.strategy
        self.exec_context.df_datas = self.df_datas

    def _get_bar(self, date) -> pd.DataFrame:
        """Return every row for ``date`` as a DataFrame indexed by symbol."""
        df_bar = self.df_datas.loc[date, :].copy()
        # When only one row exists for the date, .loc yields a Series;
        # convert it back to a one-row frame so callers always get a DataFrame.
        if isinstance(df_bar, pd.Series):
            df_bar = df_bar.to_frame().T
        df_bar.index = df_bar['symbol']
        return df_bar

    def run(self, **kwargs):
        """Run the backtest over all dates.

        ``kwargs`` is reserved for hyperparameter optimization and is
        currently unused.
        """
        for i, date in enumerate(self.dates):
            df_bar = self._get_bar(date)

            self.exec_context.temp = {}
            self.exec_context.now = date
            self.exec_context.index = i
            self.exec_context.df_bar = df_bar

            # Mark the portfolio to this bar first ...
            self.portfolio.on_bar(date, df_bar)
            # ... then let the strategy trade after the close.
            self.strategy.on_bar(self.exec_context)
            self.portfolio.process_orders()

    def optimize(self):
        """Placeholder for hyperparameter optimization (not implemented)."""
        pass

    def stats(self):
        """Plot the portfolio's market value over time.

        Returns:
            The per-bar portfolio history as a DataFrame indexed by date,
            with the columns of ``PortfolioBar``.
        """
        import matplotlib.pyplot as plt
        from matplotlib import rcParams

        portfolio_df = pd.DataFrame.from_records(
            self.portfolio.bars, columns=PortfolioBar._fields, index="date"
        )
        # Configure a font that can render Chinese labels *before* plotting.
        # The original plotted once before this and again after, drawing the
        # same line twice.
        rcParams['font.family'] = 'SimHei'
        portfolio_df['market_value'].plot()
        plt.show()
        return portfolio_df
核心类是 Portfolio：
from dataclasses import dataclass, field
from typing import NamedTuple

import numpy as np
import pandas as pd

try:
    from loguru import logger
except ImportError:  # loguru is optional; fall back to the stdlib logger
    import logging

    logger = logging.getLogger(__name__)


class ScheOrder(NamedTuple):
    """An order queued via ``Portfolio.new_order``, filled by ``process_orders``."""

    symbol: str
    # Cash amount: > 0 buys, < 0 sells; converted to whole shares at the close.
    amount: float


@dataclass
class Position:
    """An open holding in a single symbol."""

    symbol: str
    shares: float
    close: float  # most recent close price
    equity: float  # shares * close as of the last update
    bars: int = 0  # NOTE(review): never updated in this module; confirm intent


class PortfolioBar(NamedTuple):
    """Portfolio snapshot recorded once per bar by ``Portfolio.on_bar``."""

    date: np.datetime64
    cash: float
    equity: float  # total value of open positions
    market_value: float  # equity + cash
    fees: float  # cumulative fees paid so far


class Portfolio:
    """Cash and position bookkeeping for a long-only, close-price backtest.

    Orders are queued with ``new_order`` and filled at the current bar's
    close by ``process_orders``, with sells executed before buys.
    """

    def __init__(self, init_cash, fee_rate=0.000):
        self.positions: dict[str, Position] = {}
        self.cash = init_cash
        self.fees = 0.0  # cumulative fees paid on all fills
        self.bars: list[PortfolioBar] = []
        self.sche_orders: list[ScheOrder] = []
        self.curr_bar = None  # latest bar DataFrame, indexed by symbol
        self.fee_rate = fee_rate
        self.total_market_value = init_cash

    def on_bar(self, date: np.datetime64, df_bar: pd.DataFrame):
        """Mark positions to ``df_bar``'s close prices and record a snapshot.

        A held symbol missing from ``df_bar`` (e.g. suspended that day) keeps
        its previous close instead of raising ``KeyError``.
        """
        self.curr_bar = df_bar
        total_equity = 0.0
        for symbol, pos in self.positions.items():
            if symbol in df_bar.index:
                pos.close = df_bar.loc[symbol, 'close']
            pos.equity = pos.shares * pos.close
            total_equity += pos.equity

        self.total_market_value = total_equity + self.cash
        self.bars.append(PortfolioBar(
            date=date,
            cash=self.cash,
            equity=total_equity,
            market_value=self.total_market_value,
            fees=self.fees,
        ))

    def new_order(self, symbol, amount):
        """Queue an order for ``amount`` cash (negative = sell).

        Sells go to the front of the queue so they settle, and free cash,
        before any buys.
        """
        order = ScheOrder(symbol=symbol, amount=amount)
        if order.amount < 0:
            self.sche_orders.insert(0, order)
        else:
            self.sche_orders.append(order)

    def process_orders(self):
        """Fill all queued orders at the current bar's close, then clear the queue."""
        for o in self.sche_orders:
            if o.symbol not in self.curr_bar.index:
                logger.error('当天{}没有数据'.format(o.symbol))
                continue
            price = self.curr_bar.loc[o.symbol, 'close']
            # Sells arrive with a negative amount; convert the magnitude to
            # whole shares so _sell always receives a positive share count.
            shares = int(abs(o.amount) / price)
            if o.amount > 0:
                self._buy(o.symbol, shares)
            elif o.amount < 0:
                self._sell(o.symbol, shares)
        self.sche_orders.clear()

    def _buy(self, symbol, shares):
        """Buy ``shares`` of ``symbol`` at the current close if cash allows."""
        if shares < 1:
            return
        if symbol not in self.curr_bar.index:
            logger.error('当天{}没有数据'.format(symbol))
            return

        price = self.curr_bar.loc[symbol, 'close']
        amount = price * shares
        fee = amount * self.fee_rate
        total_amount = amount + fee
        if self.cash < total_amount:
            logger.error('现金不够,无法下单:{}'.format(symbol))
            return

        pos = self.positions.get(symbol)
        if pos is None:
            self.positions[symbol] = Position(
                symbol=symbol, shares=shares, close=price, equity=shares * price
            )
        else:
            pos.shares += shares
            pos.close = price
            pos.equity = pos.shares * price
        self.cash -= total_amount
        self.fees += fee

    def _sell(self, symbol, shares):
        """Sell ``shares`` (a positive count) of an existing ``symbol`` holding."""
        if shares < 1:
            return
        if symbol not in self.curr_bar.index:
            logger.error('当天{}没有数据'.format(symbol))
            return

        pos = self.positions.get(symbol)
        if pos is None:
            logger.error('当前未持仓,无法卖出:{}'.format(symbol))
            return
        if pos.shares < shares:
            logger.error('当前持仓股数:{}小于{},交易无法进行'.format(pos.shares, shares))
            return

        price = self.curr_bar.loc[symbol, 'close']
        amount = price * shares
        fee = amount * self.fee_rate
        pos.shares -= shares
        if pos.shares == 0:
            del self.positions[symbol]
        else:
            pos.close = price
            pos.equity = pos.shares * price
        self.cash += amount - fee
        self.fees += fee
Quantlab3.0值得期待一下:
发布者:股市刺客,转载请注明出处:https://www.95sca.cn/archives/103574
站内所有文章皆来自网络转载或读者投稿,请勿用于商业用途。如有侵权、不妥之处,请联系站长并出示版权证明以便删除。敬请谅解!