从零学习量化交易38股票数据热编码OneHot

(1)源代码

# -*- coding: utf-8 -*-
#股票数据热编码One-Hot

import sys, os
sys.path.append("topqt/")

import numpy as np
import pandas as pd
import tushare as ts
import plotly as py
import plotly.figure_factory as pyff

import math
import arrow
import ffn
import pypinyin
import pandas_datareader as pdr
import matplotlib.pyplot as plt

import zsys2025  # 20250213
import ztools as zt
import ztools_str as zstr
import ztools_data2025 as zdat
import ztools_draw2025 as zdr
import ztools_tq2025 as ztq  # 20250213


#1
print('\n#1,set.sys')
pd.set_option('display.width', 450)
pd.set_option('display.float_format', zt.xfloat3)


#2
print('\n#2,读取数据')
fss='data/inx_000001.csv'
#fss='data/300682.csv'
df=pd.read_csv(fss,index_col=0)
df=df.sort_index()

#3
print('\n#3,整理数据')
df['xopen']=df['open'].shift(-1)
df['xclose']=df['close'].shift(-1)
df['kclose']=df['xclose']/df['xopen']*100
df['ktype']=df['kclose'].apply(zt.iff3type,d0=99.5,d9=101.5,v3=3,v2=2,v1=1)
df['ktype']=df['ktype']-1
print(df.tail())

#4
print('\n#4,设置训练数据')
n9=len(df.index)
df1=df.head(2000)
df2=df.tail(n9-2000)
#
X=df1[zsys2025.ohlcLst].values

Y=df1['ktype'].values

#5
print('\n#5,One-Hot Encode')
y_onehot=pd.get_dummies(Y)
print('y_onehot.head(5)')
print(y_onehot.head(5))
y1s=pd.get_dummies(y_onehot)
print('y1s.head(5)')
print(y1s.head(5))
#
y1=y1s.values
print('y1')
print(y1)
print('type(y1),',type(y1))
print('y1.shape,',y1.shape)
#
#6
print('\n#6,One-Hot Decode')
a0,a1=y1[0],y1[1]
a0v,a1v=np.argmax(a0,axis=0),np.argmax(a1,axis=0)
print('\na0v,',a0v,a0)
print('a1v,',a1v,a1)

(2)程序运行输出

C:\Users\Administrator\Anaconda3\envs\tensorflow\python.exe C:\Users\Administrator\PycharmProjects\TOPQuant_Tensorflow-StockDataHotCoded_One-Hot.py
Using TensorFlow backend.
2025-02-15 19:07:59.585705: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'cudart64_100.dll'; dlerror: cudart64_100.dll not found
2025-02-15 19:07:59.585848: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

#1,set.sys

#2,读取数据

#3,整理数据
               open     high    close      low  ...    xopen   xclose  kclose  ktype
date                                            ...                                 
2017-04-26 3132.918 3152.953 3140.847 3131.418  ... 3131.350 3152.187 100.665      1
2017-04-27 3131.350 3155.003 3152.187 3097.333  ... 3144.022 3154.658 100.338      1
2017-04-28 3144.022 3154.727 3154.658 3136.578  ... 3147.228 3143.712  99.888      1
2017-05-02 3147.228 3154.781 3143.712 3136.539  ... 3138.307 3135.346  99.906      1
2017-05-03 3138.307 3148.286 3135.346 3123.751  ...      nan      nan     nan      1

[5 rows x 10 columns]

#4,设置训练数据

#5,One-Hot Encode
y_onehot.head(5)
   0  1  2
0  0  1  0
1  0  0  1
2  0  0  1
3  0  1  0
4  0  1  0
y1s.head(5)
   0  1  2
0  0  1  0
1  0  0  1
2  0  0  1
3  0  1  0
4  0  1  0
y1
[[0 1 0]
 [0 0 1]
 [0 0 1]
 ...
 [1 0 0]
 [0 0 1]
 [0 1 0]]
type(y1), <class 'numpy.ndarray'>
y1.shape, (2000, 3)

#6,One-Hot Decode

a0v, 1 [0 1 0]
a1v, 2 [0 0 1]

Process finished with exit code 0

(3)源代码注释

# -*- coding: utf-8 -*-
# 声明文件的编码格式为UTF-8,确保文件中包含的中文等非ASCII字符能正确被解析

#股票数据热编码One-Hot
# 程序的功能描述,即对股票数据进行One-Hot编码操作

import sys, os
# 导入sys和os模块
# sys模块提供了一些变量和函数,用于与Python解释器进行交互,如访问命令行参数、修改模块搜索路径等
# os模块提供了一种方便的使用操作系统相关功能的方式,如文件和目录操作等

sys.path.append("topqt/")
# 将 "topqt/" 目录添加到Python模块搜索路径中,这样Python解释器就能在该目录下查找要导入的模块

import numpy as np
# 导入numpy库,并将其重命名为np
# numpy是Python中用于科学计算的基础库,提供了多维数组对象和各种数学函数

import pandas as pd
# 导入pandas库,并将其重命名为pd
# pandas是一个强大的数据处理和分析库,提供了DataFrame和Series等数据结构

import tushare as ts
# 导入tushare库,用于获取金融数据
# tushare是一个免费、开源的python财经数据接口包

import plotly as py
# 导入plotly库,并将其重命名为py
# plotly是一个交互式可视化库,可用于创建各种类型的图表

import plotly.figure_factory as pyff
# 导入plotly的figure_factory模块,并将其重命名为pyff
# figure_factory模块提供了一些创建复杂图表的函数

import math
# 导入math模块,提供了许多数学函数和常量

import arrow
# 导入arrow库,用于处理日期和时间
# arrow是一个简洁、易用的Python日期时间处理库

import ffn
# 导入ffn库,用于金融数据分析
# ffn是一个金融分析库,提供了一些金融指标计算和数据分析工具

import pypinyin
# 导入pypinyin库,用于将中文转换为拼音

import pandas_datareader as pdr
# 导入pandas_datareader库,并将其重命名为pdr
# pandas_datareader是一个用于从各种在线数据源获取金融数据的库

import matplotlib.pyplot as plt
# 导入matplotlib的pyplot模块,并将其重命名为plt
# matplotlib是一个用于绘制图表的库,pyplot模块提供了类似于MATLAB的绘图接口

import zsys2025  # 20250213
# 导入自定义模块zsys2025,注释中给出了导入时间
# 自定义模块通常包含一些自定义的函数和变量,用于特定的业务逻辑

import ztools as zt
# 导入自定义模块ztools,并将其重命名为zt
# 自定义工具模块,可能包含一些常用的工具函数

import ztools_str as zstr
# 导入自定义模块ztools_str,并将其重命名为zstr
# 自定义字符串处理工具模块,可能包含一些处理字符串的函数

import ztools_data2025 as zdat
# 导入自定义模块ztools_data2025,并将其重命名为zdat
# 自定义数据处理工具模块,可能包含一些处理数据的函数

import ztools_draw2025 as zdr
# 导入自定义模块ztools_draw2025,并将其重命名为zdr
# 自定义绘图工具模块,可能包含一些绘制图表的函数

import ztools_tq2025 as ztq  # 20250213
# 导入自定义模块ztools_tq2025,并将其重命名为ztq,注释中给出了导入时间
# 自定义工具模块,可能包含与特定任务相关的工具函数


#1
print('\n#1,set.sys')
# 打印提示信息,表明接下来是设置系统相关参数
pd.set_option('display.width', 450)
# 设置pandas DataFrame在控制台输出时的宽度为450个字符,使输出更美观
pd.set_option('display.float_format', zt.xfloat3)
# 设置pandas DataFrame中浮点数的显示格式,使用zt模块中的xfloat3函数进行格式化


#2
print('\n#2,读取数据')
# 打印提示信息,表明接下来是读取数据的操作
fss='data/300682.csv'
# 定义要读取的CSV文件的路径
df=pd.read_csv(fss,index_col=0)
# 使用pandas的read_csv函数读取CSV文件,并将第一列作为索引列,将读取的数据存储在DataFrame对象df中
df=df.sort_index()
# 对DataFrame对象df按照索引进行排序


#3
print('\n#3,整理数据')
# 打印提示信息,表明接下来是整理数据的操作
df['xopen']=df['open'].shift(-1)
# 在DataFrame对象df中新增一列 'xopen',其值为 'open' 列向下移动一行后的值
df['xclose']=df['close'].shift(-1)
# 在DataFrame对象df中新增一列 'xclose',其值为 'close' 列向下移动一行后的值
df['kclose']=df['xclose']/df['xopen']*100
# 在DataFrame对象df中新增一列 'kclose',其值为 'xclose' 列除以 'xopen' 列再乘以100的结果
df['ktype']=df['kclose'].apply(zt.iff3type,d0=99.5,d9=101.5,v3=3,v2=2,v1=1)
# 在DataFrame对象df中新增一列 'ktype',其值是对 'kclose' 列的每个元素应用zt模块中的iff3type函数得到的结果
# 参数d0=99.5, d9=101.5, v3=3, v2=2, v1=1 是传递给iff3type函数的参数
df['ktype']=df['ktype']-1
# 将 'ktype' 列的每个元素减1
print(df.tail())
# 打印DataFrame对象df的最后几行数据


#4
print('\n#4,设置训练数据')
# 打印提示信息,表明接下来是设置训练数据的操作
n9=len(df.index)
# 获取DataFrame对象df的索引长度,即数据的行数
df1=df.head(2000)
# 从DataFrame对象df中选取前2000行数据,存储在DataFrame对象df1中
df2=df.tail(n9-2000)
# 从DataFrame对象df中选取除前2000行之外的剩余数据,存储在DataFrame对象df2中
#
X=df1[zsys2025.ohlcLst].values
# 从DataFrame对象df1中选取zsys2025模块中ohlcLst列表指定的列的数据,并将其转换为numpy数组,存储在变量X中
Y=df1['ktype'].values
# 从DataFrame对象df1中选取 'ktype' 列的数据,并将其转换为numpy数组,存储在变量Y中

#5
print('\n#5,One-Hot Encode')
# 打印提示信息,表明接下来是进行One-Hot编码的操作
y_onehot=pd.get_dummies(Y)
# 使用pandas的get_dummies函数对变量Y进行One-Hot编码,将编码结果存储在DataFrame对象y_onehot中
print('y_onehot.head(5)')
# 打印提示信息,表明接下来要打印y_onehot的前5行数据
print(y_onehot.head(5))
# 打印y_onehot的前5行数据
y1s=pd.get_dummies(y_onehot)
# 对y_onehot再次进行One-Hot编码,将编码结果存储在DataFrame对象y1s中
print('y1s.head(5)')
# 打印提示信息,表明接下来要打印y1s的前5行数据
print(y1s.head(5))
# 打印y1s的前5行数据
#
y1=y1s.values
# 将DataFrame对象y1s转换为numpy数组,存储在变量y1中
print('y1')
# 打印提示信息,表明接下来要打印y1的数据
print(y1)
# 打印y1的数据
print('type(y1),',type(y1))
# 打印y1的数据类型
print('y1.shape,',y1.shape)
# 打印y1的形状

#6
print('\n#6,One-Hot Decode')
# 打印提示信息,表明接下来是进行One-Hot解码的操作
a0,a1=y1[0],y1[1]
# 从y1中选取第0行和第1行的数据,分别存储在变量a0和a1中
a0v,a1v=np.argmax(a0,axis=0),np.argmax(a1,axis=0)
# 使用numpy的argmax函数分别找出a0和a1中最大值的索引,存储在变量a0v和a1v中
# axis=0 表示在第0个维度(行方向)上查找最大值的索引
print('\na0v,',a0v,a0)
# 打印a0v的值和a0的数据
print('a1v,',a1v,a1)
# 打印a1v的值和a1的数据

发布者:股市刺客,转载请注明出处:https://www.95sca.cn/archives/907123
站内所有文章皆来自网络转载或读者投稿,请勿用于商业用途。如有侵权、不妥之处,请联系站长并出示版权证明以便删除。敬请谅解!

(0)
股市刺客的头像股市刺客
上一篇 5小时前
下一篇 5小时前

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注