(1)源代码
# -*- coding: utf-8 -*-
#股票数据热编码One-Hot
import sys, os
sys.path.append("topqt/")
import numpy as np
import pandas as pd
import tushare as ts
import plotly as py
import plotly.figure_factory as pyff
import math
import arrow
import ffn
import pypinyin
import pandas_datareader as pdr
import matplotlib.pyplot as plt
import zsys2025 # 20250213
import ztools as zt
import ztools_str as zstr
import ztools_data2025 as zdat
import ztools_draw2025 as zdr
import ztools_tq2025 as ztq # 20250213
#1
print('\n#1,set.sys')
pd.set_option('display.width', 450)
pd.set_option('display.float_format', zt.xfloat3)
#2
print('\n#2,读取数据')
fss='data/inx_000001.csv'
#fss='data/300682.csv'
df=pd.read_csv(fss,index_col=0)
df=df.sort_index()
#3
print('\n#3,整理数据')
df['xopen']=df['open'].shift(-1)
df['xclose']=df['close'].shift(-1)
df['kclose']=df['xclose']/df['xopen']*100
df['ktype']=df['kclose'].apply(zt.iff3type,d0=99.5,d9=101.5,v3=3,v2=2,v1=1)
df['ktype']=df['ktype']-1
print(df.tail())
#4
print('\n#4,设置训练数据')
n9=len(df.index)
df1=df.head(2000)
df2=df.tail(n9-2000)
#
X=df1[zsys2025.ohlcLst].values
Y=df1['ktype'].values
#5
print('\n#5,One-Hot Encode')
y_onehot=pd.get_dummies(Y)
print('y_onehot.head(5)')
print(y_onehot.head(5))
y1s=pd.get_dummies(y_onehot)
print('y1s.head(5)')
print(y1s.head(5))
#
y1=y1s.values
print('y1')
print(y1)
print('type(y1),',type(y1))
print('y1.shape,',y1.shape)
#
#6
print('\n#6,One-Hot Decode')
a0,a1=y1[0],y1[1]
a0v,a1v=np.argmax(a0,axis=0),np.argmax(a1,axis=0)
print('\na0v,',a0v,a0)
print('a1v,',a1v,a1)
(2)程序运行输出
C:\Users\Administrator\Anaconda3\envs\tensorflow\python.exe C:\Users\Administrator\PycharmProjects\TOPQuant_Tensorflow-StockDataHotCoded_One-Hot.py
Using TensorFlow backend.
2025-02-15 19:07:59.585705: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'cudart64_100.dll'; dlerror: cudart64_100.dll not found
2025-02-15 19:07:59.585848: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
#1,set.sys
#2,读取数据
#3,整理数据
open high close low ... xopen xclose kclose ktype
date ...
2017-04-26 3132.918 3152.953 3140.847 3131.418 ... 3131.350 3152.187 100.665 1
2017-04-27 3131.350 3155.003 3152.187 3097.333 ... 3144.022 3154.658 100.338 1
2017-04-28 3144.022 3154.727 3154.658 3136.578 ... 3147.228 3143.712 99.888 1
2017-05-02 3147.228 3154.781 3143.712 3136.539 ... 3138.307 3135.346 99.906 1
2017-05-03 3138.307 3148.286 3135.346 3123.751 ... nan nan nan 1
[5 rows x 10 columns]
#4,设置训练数据
#5,One-Hot Encode
y_onehot.head(5)
0 1 2
0 0 1 0
1 0 0 1
2 0 0 1
3 0 1 0
4 0 1 0
y1s.head(5)
0 1 2
0 0 1 0
1 0 0 1
2 0 0 1
3 0 1 0
4 0 1 0
y1
[[0 1 0]
[0 0 1]
[0 0 1]
...
[1 0 0]
[0 0 1]
[0 1 0]]
type(y1), <class 'numpy.ndarray'>
y1.shape, (2000, 3)
#6,One-Hot Decode
a0v, 1 [0 1 0]
a1v, 2 [0 0 1]
Process finished with exit code 0
(3)源代码注释
# -*- coding: utf-8 -*-
# 声明文件的编码格式为UTF-8,确保文件中包含的中文等非ASCII字符能正确被解析
#股票数据热编码One-Hot
# 程序的功能描述,即对股票数据进行One-Hot编码操作
import sys, os
# 导入sys和os模块
# sys模块提供了一些变量和函数,用于与Python解释器进行交互,如访问命令行参数、修改模块搜索路径等
# os模块提供了一种方便的使用操作系统相关功能的方式,如文件和目录操作等
sys.path.append("topqt/")
# 将 "topqt/" 目录添加到Python模块搜索路径中,这样Python解释器就能在该目录下查找要导入的模块
import numpy as np
# 导入numpy库,并将其重命名为np
# numpy是Python中用于科学计算的基础库,提供了多维数组对象和各种数学函数
import pandas as pd
# 导入pandas库,并将其重命名为pd
# pandas是一个强大的数据处理和分析库,提供了DataFrame和Series等数据结构
import tushare as ts
# 导入tushare库,用于获取金融数据
# tushare是一个免费、开源的python财经数据接口包
import plotly as py
# 导入plotly库,并将其重命名为py
# plotly是一个交互式可视化库,可用于创建各种类型的图表
import plotly.figure_factory as pyff
# 导入plotly的figure_factory模块,并将其重命名为pyff
# figure_factory模块提供了一些创建复杂图表的函数
import math
# 导入math模块,提供了许多数学函数和常量
import arrow
# 导入arrow库,用于处理日期和时间
# arrow是一个简洁、易用的Python日期时间处理库
import ffn
# 导入ffn库,用于金融数据分析
# ffn是一个金融分析库,提供了一些金融指标计算和数据分析工具
import pypinyin
# 导入pypinyin库,用于将中文转换为拼音
import pandas_datareader as pdr
# 导入pandas_datareader库,并将其重命名为pdr
# pandas_datareader是一个用于从各种在线数据源获取金融数据的库
import matplotlib.pyplot as plt
# 导入matplotlib的pyplot模块,并将其重命名为plt
# matplotlib是一个用于绘制图表的库,pyplot模块提供了类似于MATLAB的绘图接口
import zsys2025 # 20250213
# 导入自定义模块zsys2025,注释中给出了导入时间
# 自定义模块通常包含一些自定义的函数和变量,用于特定的业务逻辑
import ztools as zt
# 导入自定义模块ztools,并将其重命名为zt
# 自定义工具模块,可能包含一些常用的工具函数
import ztools_str as zstr
# 导入自定义模块ztools_str,并将其重命名为zstr
# 自定义字符串处理工具模块,可能包含一些处理字符串的函数
import ztools_data2025 as zdat
# 导入自定义模块ztools_data2025,并将其重命名为zdat
# 自定义数据处理工具模块,可能包含一些处理数据的函数
import ztools_draw2025 as zdr
# 导入自定义模块ztools_draw2025,并将其重命名为zdr
# 自定义绘图工具模块,可能包含一些绘制图表的函数
import ztools_tq2025 as ztq # 20250213
# 导入自定义模块ztools_tq2025,并将其重命名为ztq,注释中给出了导入时间
# 自定义工具模块,可能包含与特定任务相关的工具函数
#1
print('\n#1,set.sys')
# 打印提示信息,表明接下来是设置系统相关参数
pd.set_option('display.width', 450)
# 设置pandas DataFrame在控制台输出时的宽度为450个字符,使输出更美观
pd.set_option('display.float_format', zt.xfloat3)
# 设置pandas DataFrame中浮点数的显示格式,使用zt模块中的xfloat3函数进行格式化
#2
print('\n#2,读取数据')
# 打印提示信息,表明接下来是读取数据的操作
fss='data/300682.csv'
# 定义要读取的CSV文件的路径
df=pd.read_csv(fss,index_col=0)
# 使用pandas的read_csv函数读取CSV文件,并将第一列作为索引列,将读取的数据存储在DataFrame对象df中
df=df.sort_index()
# 对DataFrame对象df按照索引进行排序
#3
print('\n#3,整理数据')
# 打印提示信息,表明接下来是整理数据的操作
df['xopen']=df['open'].shift(-1)
# 在DataFrame对象df中新增一列 'xopen',其值为 'open' 列向下移动一行后的值
df['xclose']=df['close'].shift(-1)
# 在DataFrame对象df中新增一列 'xclose',其值为 'close' 列向下移动一行后的值
df['kclose']=df['xclose']/df['xopen']*100
# 在DataFrame对象df中新增一列 'kclose',其值为 'xclose' 列除以 'xopen' 列再乘以100的结果
df['ktype']=df['kclose'].apply(zt.iff3type,d0=99.5,d9=101.5,v3=3,v2=2,v1=1)
# 在DataFrame对象df中新增一列 'ktype',其值是对 'kclose' 列的每个元素应用zt模块中的iff3type函数得到的结果
# 参数d0=99.5, d9=101.5, v3=3, v2=2, v1=1 是传递给iff3type函数的参数
df['ktype']=df['ktype']-1
# 将 'ktype' 列的每个元素减1
print(df.tail())
# 打印DataFrame对象df的最后几行数据
#4
print('\n#4,设置训练数据')
# 打印提示信息,表明接下来是设置训练数据的操作
n9=len(df.index)
# 获取DataFrame对象df的索引长度,即数据的行数
df1=df.head(2000)
# 从DataFrame对象df中选取前2000行数据,存储在DataFrame对象df1中
df2=df.tail(n9-2000)
# 从DataFrame对象df中选取除前2000行之外的剩余数据,存储在DataFrame对象df2中
#
X=df1[zsys2025.ohlcLst].values
# 从DataFrame对象df1中选取zsys2025模块中ohlcLst列表指定的列的数据,并将其转换为numpy数组,存储在变量X中
Y=df1['ktype'].values
# 从DataFrame对象df1中选取 'ktype' 列的数据,并将其转换为numpy数组,存储在变量Y中
#5
print('\n#5,One-Hot Encode')
# 打印提示信息,表明接下来是进行One-Hot编码的操作
y_onehot=pd.get_dummies(Y)
# 使用pandas的get_dummies函数对变量Y进行One-Hot编码,将编码结果存储在DataFrame对象y_onehot中
print('y_onehot.head(5)')
# 打印提示信息,表明接下来要打印y_onehot的前5行数据
print(y_onehot.head(5))
# 打印y_onehot的前5行数据
y1s=pd.get_dummies(y_onehot)
# 对y_onehot再次进行One-Hot编码,将编码结果存储在DataFrame对象y1s中
print('y1s.head(5)')
# 打印提示信息,表明接下来要打印y1s的前5行数据
print(y1s.head(5))
# 打印y1s的前5行数据
#
y1=y1s.values
# 将DataFrame对象y1s转换为numpy数组,存储在变量y1中
print('y1')
# 打印提示信息,表明接下来要打印y1的数据
print(y1)
# 打印y1的数据
print('type(y1),',type(y1))
# 打印y1的数据类型
print('y1.shape,',y1.shape)
# 打印y1的形状
#6
print('\n#6,One-Hot Decode')
# 打印提示信息,表明接下来是进行One-Hot解码的操作
a0,a1=y1[0],y1[1]
# 从y1中选取第0行和第1行的数据,分别存储在变量a0和a1中
a0v,a1v=np.argmax(a0,axis=0),np.argmax(a1,axis=0)
# 使用numpy的argmax函数分别找出a0和a1中最大值的索引,存储在变量a0v和a1v中
# axis=0 表示在第0个维度(行方向)上查找最大值的索引
print('\na0v,',a0v,a0)
# 打印a0v的值和a0的数据
print('a1v,',a1v,a1)
# 打印a1v的值和a1的数据
发布者:股市刺客,转载请注明出处:https://www.95sca.cn/archives/907123
站内所有文章皆来自网络转载或读者投稿,请勿用于商业用途。如有侵权、不妥之处,请联系站长并出示版权证明以便删除。敬请谅解!