Python量化交易- mplfinance库 -画K线图

mplfinance库

  • 1. mplfinance 模块说明
  • 2. mplfinance安装
  • 3. mplfinance 模块 plot 基本用法
    • 参数
      • type
      • style
      • make_addplot
      • 设置图表颜色 make_marketcolors
      • 添加图表样式 make_mpf_style
  • 4. mplfinance 的基本K线图
    • 实现自定义风格和颜色
    • 图表尺寸调整、相关信息的显示
    • 添加完整移动平均线
    • 添加指标 MACD
    • 实现鼠标拖动平移交互功能
    • 实现鼠标滚轮缩放
    • 实现双击切换指标
    • 使用键盘方向键平移缩放K线图及切换指标
    • 完整代码实现

1. mplfinance 模块说明

mplfinance是专用于金融数据的可视化分析模块,是基于matplotlib的实用模块程序。

mplfinance 使用操作简单,绘制个均线什么的一个关键字参数解决,剔除停盘时间段的空白不用你想它已经自动做了,还有时间坐标都是自动完成的,比如显示的是当天k线就只显示时间,跨天就自动带上日期,跨年就自动带上年份。

2. mplfinance安装

pip install --upgrade mplfinance

3. mplfinance 模块 plot 基本用法

import mplfinance as mpf
 
mpf.plot(data)

这里需要强调的是参数 data 的类型,data 必须是 pandas.DataFrame 数据类型,对所包含的列也有需要求,必须包含 OpenHignLowClose 数据(注意:首字母是大写的),而且行索引必须是 pandas.DatatimeIndex,行索引的名称必须是 Data(同理注意首字母大写),此外还有一列是 Volume ,这一列不是必须的,可选项(前提是你想要绘制成交量的话)。

import os
import pandas as pd
import mplfinance as mpf

# 拼接路径
filename = os.path.join(os.path.join(os.getcwd(), "量化data"), "sh600000.csv")
print(filename)
df = pd.read_csv(
    # 数据路径
    filename,
    encoding='gbk',
    skiprows=1
)

df = df[['Date', 'Open', 'Close', 'High', 'Low', 'Volume']]
# df.columns = ['Date', 'Open', 'Close', 'High', 'Low', 'Volume', ]
# 将 Date 设置为时间类型
df['Date'] = pd.to_datetime(df['Date'])
# 将 日期 设置为index行索引
df.set_index('Date', inplace=True)
df = df[4300:]

mpf.plot(df)

在这里插入图片描述

参数

参数 描述
type 绘制图线的种类
ylabel y轴标签
style 风格样式
title 图表标题
mav 均线,格式为一个元组,如(5, 10)表示绘制5日均线和10日均线
volume 是否绘制量柱图,默认为False,表示不绘制。
figratio 图像横纵比,如(5,3)表示图像长比宽为5:3
ylabel_lower 表示底部图像的标签(一般是量柱图)
savefig 保存图片
xrotation x轴刻度旋转度
datetime_format 设置x轴刻度日期格式

type

type取值 描述
candle 蜡烛图
ohlc OHLC图,也称“美国线”。即用一根垂直的线段表示一天的行情,在开盘和收盘价格处划一笔刻度。
line 直线,即近绘制收盘价曲线
renko 砖形图
pnf pnf图,由圈和叉构成

candel
在这里插入图片描述
ohlc
在这里插入图片描述
line
在这里插入图片描述
renko
在这里插入图片描述
pnf
在这里插入图片描述

style

style 描述
binance 币安风格
blueskies 蓝天风格
brasil 巴西风格
charles 查理风格
checkers 跳棋风格
classic 古典风格
default 默认风格
mike 夜云风格
nightclouds
sas SAS风格
starsandstripes 星条旗风格
yahoo 雅虎风格

make_addplot

金融数据分析中,我们要通过数据可视化展示的不仅是OpenHighLowClose和最常见ma,还有一些其他分析数据,那么就要用到make_addplot()方法了,make_addplot可以接受一个pandasnumpyarray以及list格式的数据(tuple不可以),和**kwargs参数;需要注意的是:传递给 make_addplot 的数据参数必须与将来画图传递给plot的数据参数行数相同,**kwargs参数将全部传递到polt方法中

import os
import pandas as pd
import mplfinance as mpf
import talib
import matplotlib.pyplot as plt

filename = os.path.join(os.path.join(os.getcwd(), "datas/days"), "000012.SZ.csv")
df = pd.read_csv(filename)
df = df[['trade_date', 'open', 'close', 'high', 'low', 'vol']]
df.columns = ['Date', 'Open', 'Close', 'High', 'Low', 'Volume', ]
df['Date'] = pd.to_datetime(df['Date'])
df.set_index('Date', inplace=True)

df['upper'], df['middle'], df['lower'] = talib.BBANDS(df['Close'], timeperiod=20, nbdevup=2, nbdevdn=2, matype=0)
print(df.tail())

df = df[5760:]

add_plot = mpf.make_addplot(df[['lower']])
mpf.plot(df, addplot=add_plot, type='candle', mav=(2, 5, 10), volume=True)



设置图表颜色 make_marketcolors

"""
make_marketcolors() 设置k线颜色
:up 设置阳线柱填充颜色
:down 设置阴线柱填充颜色
:edge 设置蜡烛线边缘颜色 'i' 代表继承k线的颜色
:wick 设置蜡烛上下影线的颜色
:volume 设置成交量颜色
:inherit 是否继承, 如果设置了继承inherit=True,那么edge即便设了颜色也会无效
"""
my_color = mplfinance.make_marketcolors(up='cyan', down='red', edge='black', wick='black', volume='blue')

添加图表样式 make_mpf_style

参数 描述
base_mpf_style 要继承的 mplfinance 风格
base_mpl_style 要继承的 matplotlib 风格
marketcolors 用于设置K线的颜色。使用mpf.make_marketcolors()方法生成
mavcolors 移动平均线的颜色
facecolor 图像的填充颜色。指的是坐标系内侧的部分的颜色
edgecolor 坐标轴的颜色
figcolor 图像外周边填充色
gridcolor 网格线颜色
gridstyle 设置网格线样式,可以是’-', ‘–’, ‘-.’, ‘:’, ‘’, offset, on-off-seq
gridaxis 网格线的方向,可以是’vertical’, ‘horizontal’, 或 ‘both’
y_on_right 设置y轴的位置是否在右边
rc 设置字体相关。中文和负号的正常显示问题都需要操作该参数。以字典形式传入
legacy_rc 也是用于设置字体格式的,不过与rc不同的是,rc仅会将rc中传入的值更新进字典,并保留原有其他字体参数。而legacy_rc会将所有原字典删除,而仅仅使用legacy_rc
style_name 风格名字,可以在使用mpf.write_style_file(style,filename)方法写自定义风格样式文件时使用
import mplfinance as mpf
import pandas as pd


def read_data(filename):
    df = pd.read_csv(filename, encoding='gbk', skiprows=1)
    df['Date'] = pd.to_datetime(df['Date'])
    df.set_index(['Date'], inplace=True)
    return df


# 获取数据
df = read_data(r'F:----笔记记录----青灯爬虫17期课件爬虫十七期18 js解密(上)1 上课代码量化datash600000.csv')

# 设置线元素的颜色
my_color = mpf.make_marketcolors(
    up="red",  # 上涨K线的颜色
    down="green",  # 下跌K线的颜色
    edge="black",  # 蜡烛图箱体的颜色
    volume="purple",  # 成交量柱子的颜色
    wick="black"  # 蜡烛图影线的颜色
)

# 自定义风格
my_style = mpf.make_mpf_style(
    base_mpf_style='blueskies',
    # base_mpl_style='seaborn',  # 也可以试试matplotlib的seaborn等风格。
    marketcolors=my_color,
    figcolor='(0, 0.8, 0.85)',
    gridcolor='(0.9, 0.9, 0.9)',
    rc={'font.family': 'SimHei', 'axes.unicode_minus': 'False'}
)

# 选择平安银行2022年8月的数据进行绘图
mpf.plot(df.loc['2019-3':'2019-4'],
         type='candle',
         ylabel="price",
         style=my_style,
         title='sh600000 日线行情',
         mav=(5, 10),
         volume=True,
         figratio=(5, 3),
         ylabel_lower="Volume", savefig='sh600000.jpg')

在这里插入图片描述

4. mplfinance 的基本K线图

在这里插入图片描述

  • 颜色风格不符合习惯:众所周知中国股市K线图的颜色代码跟世界惯例是恰好相反的,其他国家都是红跌绿涨而我国是红涨绿跌,上面这幅图中的颜色信息并不符合中国股市的惯例,让人看着别扭。
  • 信息显示不够完整:一般的K线图都会显示出图上最后一个交易日的OHLC也就是开/高/低/收价格、使用红色和绿色表示本交易日相对过去的涨跌情况,同时还能显示其他的相关数据如涨幅、交易额、交易量等等信息,另外,股票代码、名称等信息也应该显示出来。
  • 均线系统不够完善:首先是均线不完整,最初的多个交易日内没有均线,这是因为仅仅使用图表内的数据计算均线导致产生了空缺值,其次,除了均线之外,没有其他的相关价格指标,如布林带线等与价格一同显示,也没有区间内最高价、最低价的指示。
  • 无动态交互功能:无法拖动平移K线以显示更早或更晚的交易日K线数据,无法通过鼠标滚轮增大或缩小显示的K线的范围,也无法通过点选某一根K线以显示当天交易日的详细信息。

以上这些功能都只能说是绝大部分股票软件所提供的最基本的交互式K线图交互功能,然而前面实现的K线图仅仅是个静态图像,因此离哪怕是最基本的实用性也还差得很远。好在mplfinance是基于matplotlib开发的,因此前面所描述的基本交互K线图功能,都可以实现。
在这里插入图片描述

目标
在开始实际工作之前,需要确定我们需要达到的目标,以便一步步实现:

  • 符合中国习惯的配色风格——红涨绿跌自然是必须实现的第一步
  • 图表上要能显示股票代码和股票名称、以及价格信息
  • 图表上要显示完整的移动平均线
  • 在交易量的下方显示第三张图表,同步显示相关指标如MACD等
  • 在图表上用鼠标单击拖动,可以平移K线图以显示更早或更晚的K线
  • 在图表上是用鼠标滚轮缩放,可以实现放大或缩小所显示的K线的范围
  • 在图表上双击,可以循环切换移动平均线和布林带线
  • 在指标图上双击,可以循环切换不同的指标类型如MACD/DEMA/RSI等等

实现自定义风格和颜色

import mplfinance as mpf
# 设置mplfinance的蜡烛颜色,up为阳线颜色,down为阴线颜色
my_color = mpf.make_marketcolors(up='r',
                                 down='g',
                                 edge='inherit',
                                 wick='inherit',
                                 volume='inherit')
# 设置图表的背景色
my_style = mpf.make_mpf_style(marketcolors=my_color,
                              figcolor='(0.82, 0.83, 0.85)',
                              gridcolor='(0.82, 0.83, 0.85)')

make_marketcolors 函数中,几个不同的参数主要用于设置K线的颜色,updown 都很明显,用于分别指定上涨K线和下跌K线的颜色。因此根据国内习惯自然应该设置 up=‘r' 也就是red红色,down 自然就是 ’g' 也就是 ‘green’ 绿色。不过需要注意的是这里仅仅设置K线的柱子的内部填充色,如果不指定边框、上下影线的颜色,他们都会是黑色,显示的效果就是黑色的边框、黑色的上下影线,挺难看的,因此还需要设置边框 "edge" 的颜色。此处设置为 “in”“inherit” 代表“使用主配色“。也就是说,阳线(上涨)的柱子外框线跟阳线的内部填充色一致,那么如果阳线的颜色为红色,边框的颜色也是红色,如果阳线是绿色,则边框也是绿色。阴线也一样。
wick设置的就是上下影线的颜色,这里为了显眼,同样设置为”in“。
类似的,volume设置的是交易量柱子的颜色,也设置为”in“就可以了。

有朋友可能想问,如果我不喜欢标准的红绿色配色,觉得太鲜艳了,想改成自定义的RGB配色可不可以,当然可以,不过需要注意的是,在标准的 matplotlib 中,可以传入一个元组表示RGB配色,例如 (0.5, 0.8, 0.6) 然而mpf不能直接传递元组作为颜色代码,但可以接受一个表示元组的字符串,如上面代码中的 figcolor='(0.82, 0.83, 0.85)'

make_mpf_style() 函数接受上面的参数,将所有的配置都存储在一个字典中,然后使用mpf的基本绘图方法,就可以生成一张符合中国股市习惯的K线图了:

import pandas as pd
# 读取测试数据
data = pd.read_csv('test_data.csv', index_col=0)
# 读取的测试数据索引为字符串类型,需要转化为时间日期类型
data.index = pd.to_datetime(data.index)
mpf.plot(data.iloc[100:200], style=my_style, type='candle', volume=True)

在这里插入图片描述
上面的代码从本地读取测试数据后,将其中一部分显示在K线图中,如上图所示,红涨绿跌。这就是符合国内习惯的K线图了。

图表尺寸调整、相关信息的显示

有了符合中国股市习惯配色风格的K线图,接下来需要调整图表尺寸,同时显示价格信息。
mplfinance的默认设置下,K线图会显示两张图表,K线图在上,交易量柱状图在下。实际上在大多数情况下,还需要第三张图表以显示一些相关的指标如KDJ,MACD等等,另外,图表的顶部应该预留出一些区域用于显示价格。

因此我们必须对图表的尺寸和位置进行精确控制,然而mplfinance的基础用法是不允许我们控制每一个图表的位置的,因此就必须使用mplfinance提供的另一种用法“External Axes Mode“,在这种模式下,我们可以像使用matplotlib一样直接控制画布上的每一个图标元素和文字元素,获得更大的操作自由。

为了实现自由控制,需要获取图表的 figure 对象,然后手动在 figure 上放置图表 Axes 和文字Text,文字和图表的位置、大小、格式完全自定义,下面是代码

# data是测试数据,可以直接下载后读取,在下例中只显示其中100个交易日的数据
plot_data = data.iloc[100: 200]
# 读取显示区间最后一个交易日的数据
last_data = plot_data.iloc[-1]
# 使用mpf.figure()函数可以返回一个figure对象,从而进入External Axes Mode,从而实现对Axes对象和figure对象的自由控制
fig = mpf.figure(style=my_style, figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
# 添加三个图表,四个数字分别代表图表左下角在figure中的坐标,以及图表的宽(0.88)、高(0.60)
ax1 = fig.add_axes([0.06, 0.25, 0.88, 0.60])
# 添加第二、三张图表时,使用sharex关键字指明与ax1在x轴上对齐,且共用x轴
ax2 = fig.add_axes([0.06, 0.15, 0.88, 0.10], sharex=ax1)
ax3 = fig.add_axes([0.06, 0.05, 0.88, 0.10], sharex=ax1)
# 设置三张图表的Y轴标签
ax1.set_ylabel('price')
ax2.set_ylabel('volume')
ax3.set_ylabel('macd')
# 在figure对象上添加文本对象,用于显示各种价格和标题
fig.text(0.50, 0.94, '513100.SH - 纳指ETF:')
fig.text(0.12, 0.90, '开/收: ')
fig.text(0.14, 0.89, f'{np.round(last_data["open"], 3)} / {np.round(last_data["close"], 3)}')
fig.text(0.14, 0.86, f'{last_data["change"]}')
fig.text(0.22, 0.86, f'[{np.round(last_data["pct_change"], 2)}%]')
fig.text(0.12, 0.86, f'{last_data.name.date()}')
fig.text(0.40, 0.90, '高: ')
fig.text(0.40, 0.90, f'{last_data["high"]}')
fig.text(0.40, 0.86, '低: ')
fig.text(0.40, 0.86, f'{last_data["low"]}')
fig.text(0.55, 0.90, '量(万手): ')
fig.text(0.55, 0.90, f'{np.round(last_data["volume"] / 10000, 3)}')
fig.text(0.55, 0.86, '额(亿元): ')
fig.text(0.55, 0.86, f'{last_data["value"]}')
fig.text(0.70, 0.90, '涨停: ')
fig.text(0.70, 0.90, f'{last_data["upper_lim"]}')
fig.text(0.70, 0.86, '跌停: ')
fig.text(0.70, 0.86, f'{last_data["lower_lim"]}')
fig.text(0.85, 0.90, '均价: ')
fig.text(0.85, 0.90, f'{np.round(last_data["average"], 3)}')
fig.text(0.85, 0.86, '昨收: ')
fig.text(0.85, 0.86, f'{last_data["last_close"]}')
# 调用mpf.plot()函数,注意调用的方式跟上一节不同,这里需要指定ax=ax1,volume=ax2,将K线图显示在ax1中,交易量显示在ax2中
mpf.plot(plot_data,
		 ax=ax1,
		 volume=ax2,
		 type='candle',
		 style=my_style)
fig.show()		

External Axes Mode 模式下,由于我们手动创建了几个Axes对象(这也就是External Axes Mode的由来),调用mpf.plot()函数,注意调用的方式跟上一节不同,这里需要指定ax=ax1, volume=ax2,将K线图显示在ax1中,交易量显示在ax2中。
在这里插入图片描述
可以看到,图表的格式和数量都正确了,三个 Axes 分别用于显示K线图、交易量以及指标(暂时还未显示),最后一个交易日的价格显示在顶部区域,但是有两个问题:

  • 数字的格式和颜色不对应该用红绿色区分不同的价格,字体大小也需要设置正确的格式
  • 中文显示为乱码需要设法使mplfinance支持utf-8编码格式的字符串

为了解决第一个问题,我们可以预设几种不同的格式备用,而第二个问题的原因在于使用的字体不支持中文,只要使用支持中文的字体就可以了。因此,可以分别定义下面几种字体,分别用于标题(黑色大字),开盘价/收盘价(大字体数字)和普通字体,每种字体都有红色和绿色两个版本:

# 标题格式,字体为中文字体,颜色为黑色,粗体,水平中心对齐
title_font = {'fontname': 'pingfang HK', 
              'size':     '16',
              'color':    'black',
              'weight':   'bold',
              'va':       'bottom',
              'ha':       'center'}
# 红色数字格式(显示开盘收盘价)粗体红色24号字
large_red_font = {'fontname': 'Arial',
                  'size':     '24',
                  'color':    'red',
                  'weight':   'bold',
                  'va':       'bottom'}
# 绿色数字格式(显示开盘收盘价)粗体绿色24号字
large_green_font = {'fontname': 'Arial',
                    'size':     '24',
                    'color':    'green',
                    'weight':   'bold',
                    'va':       'bottom'}
# 小数字格式(显示其他价格信息)粗体红色12号字
small_red_font = {'fontname': 'Arial',
                  'size':     '12',
                  'color':    'red',
                  'weight':   'bold',
                  'va':       'bottom'}
# 小数字格式(显示其他价格信息)粗体绿色12号字
small_green_font = {'fontname': 'Arial',
                    'size':     '12',
                    'color':    'green',
                    'weight':   'bold',
                    'va':       'bottom'}
# 标签格式,可以显示中文,普通黑色12号字
normal_label_font = {'fontname': 'pingfang HK',
                     'size':     '12',
                     'color':    'black',
                     'va':       'bottom',
                     'ha':       'right'}
# 普通文本格式,普通黑色12号字
normal_font = {'fontname': 'Arial',
               'size':     '12',
               'color':    'black',
               'va':       'bottom',
               'ha':       'left'}

然后修改一下前面的代码,将格式分别应用到各个文本中去:

# 这里的代码与上一段完全相同
fig = mpf.figure(style=my_style, figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
ax1 = fig.add_axes([0.06, 0.25, 0.88, 0.60])
ax2 = fig.add_axes([0.06, 0.15, 0.88, 0.10], sharex=ax1)
ax3 = fig.add_axes([0.06, 0.05, 0.88, 0.10], sharex=ax1)
ax1.set_ylabel('price')
ax2.set_ylabel('volume')
ax3.set_ylabel('macd')
# 设置显示文本的时候,返回文本对象
# 对不同的文本采用不同的格式
t1 = fig.text(0.50, 0.94, '513100.SH - 纳指ETF:', **title_font)
t2 = fig.text(0.12, 0.90, '开/收: ', **normal_label_font)
t3 = fig.text(0.14, 0.89, f'{np.round(last_data["open"], 3)} / {np.round(last_data["close"], 3)}', **large_red_font)
t4 = fig.text(0.14, 0.86, f'{last_data["change"]}', **small_red_font)
t5 = fig.text(0.22, 0.86, f'[{np.round(last_data["pct_change"], 2)}%]', **small_red_font)
t6 = fig.text(0.12, 0.86, f'{last_data.name.date()}', **normal_label_font)
t7 = fig.text(0.40, 0.90, '高: ', **normal_label_font)
t8 = fig.text(0.40, 0.90, f'{last_data["high"]}', **small_red_font)
t9 = fig.text(0.40, 0.86, '低: ', **normal_label_font)
t10 = fig.text(0.40, 0.86, f'{last_data["low"]}', **small_green_font)
t11 = fig.text(0.55, 0.90, '量(万手): ', **normal_label_font)
t12 = fig.text(0.55, 0.90, f'{np.round(last_data["volume"] / 10000, 3)}', **normal_font)
t13 = fig.text(0.55, 0.86, '额(亿元): ', **normal_label_font)
t14 = fig.text(0.55, 0.86, f'{last_data["value"]}', **normal_font)
t15 = fig.text(0.70, 0.90, '涨停: ', **normal_label_font)
t16 = fig.text(0.70, 0.90, f'{last_data["upper_lim"]}', **small_red_font)
t17 = fig.text(0.70, 0.86, '跌停: ', **normal_label_font)
t18 = fig.text(0.70, 0.86, f'{last_data["lower_lim"]}', **small_green_font)
t19 = fig.text(0.85, 0.90, '均价: ', **normal_label_font)
t20 = fig.text(0.85, 0.90, f'{np.round(last_data["average"], 3)}', **normal_font)
t21 = fig.text(0.85, 0.86, '昨收: ', **normal_label_font)
t22 = fig.text(0.85, 0.86, f'{last_data["last_close"]}', **normal_font)

mpf.plot(plot_data,
		 ax=ax1,
		 volume=ax2,
		 type='candle',
		 style=my_style)
fig.show()

在这里插入图片描述
有的朋友在运行上述代码时,可能会遇到错误说使用的中文字体不存在,因而还是显示乱码,这里给出一个解决方案供大家参考:

为了显示系统中有哪些中文字体,可以先导入 matplotlibFontManager 类,调用这个类的 ttflist 属性,就可以看到系统中已经存在的所有可以被 matplotlib 使用的字体了,选择其中的中文字体即可(中文字体名称中一般都带有拼音,或者含有 TC、SC 之类的关键字:

>>> from matplotlib.font_manager import FontManager
>>> fm = FontManager()
>>> fm.ttflist
Out:
[<Font 'STIXSizeOneSym' (STIXSizOneSymBol.ttf) normal normal 700 normal>,
 <Font 'STIXSizeOneSym' (STIXSizOneSymReg.ttf) normal normal 400 normal>,
...
 <Font 'PingFang HK' (PingFang.ttc) normal normal 400 normal>,
... 
 <Font 'STIXIntegralsUpD' (STIXIntUpDReg.otf) normal normal 400 normal>,
 <Font 'Apple Braille' (Apple Braille Pinpoint 6 Dot.ttf) normal normal 400 normal>] 

清单中的字体可能会比较多,也有多种中文字体,比如上面例子中的 PingFang HK 就是中文字体,将字体名称PingFang HK用于font就可以了: font_name='PingFang HK'

添加完整移动平均线

mplfinance 的标准 plot() 方法中,有一个mav参数,接受一个整数元组或列表,代表不同的移动平均线的天数,如[5, 20, 60]代表绘制三条均线,分别为5日、20日和60日均线。

然而,直接传入mav参数后绘制的移动平均线是不完整的,在图表的最初一段时间内没有均线,因为 mplfinance 没有足够的数据计算完整均线。如果要显示完整的均线,就必须提前计算好均线数据,然后再添加到K线图中。

mplfinance 中,添加更多的数据和均线到K线图中,不管是移动平均线,指标、买卖点、还是布林带线等等信息,都需要用到 addplot

在我们准备好的test_data中,已经计算好了四条均线的值,通过 data[['MA5', 'MA10', 'MA20', 'MA60']] 就可以访问了,我们现在通过 make_addplot() 方法把这几条均线添加到ax1中:

# 通过ax=ax1参数指定把新的线条添加到ax1中,与K线图重叠
ap = mpf.make_addplot(plot_data[['MA5', 'MA10', 'MA20', 'MA60']], ax=ax1)

接下来,还是调用 mplfinance.plot() 方法,同时指定 addplot=ap 即可。为节省篇幅,下面的代码省略了Axes对象和text对象的创建部分,这部分代码与前面相同:

# 调用plot()方法,注意传递addplot=ap参数,以添加均线
mpf.plot(plot_data,
         ax=ax1,
         volume=ax2,
         addplot=ap,
         type='candle',
         style=my_style)
fig.show()

在这里插入图片描述

添加指标 MACD

至此,一个静态的实用K线图已经初具雏形了,不过, ax3 还是空的,这里本来应该显示MACD指标的,那么如何实现呢?自然还是需要 addplot()
不过,在开始之前,我们要知道,MACD指标包含两条线,还有一组柱状图,而且柱状图还分红绿两色,不是一个简单的图形,需要分别绘制。在 mplfinance 中, addplot 可以是一个字典,也可以是一个列表,列表中包含多组不同的addplot,这样在调用mpf.plot()的时候,可以传入任意多个addplot,从而实现复杂的图表形式。

在本文的测试数据中,已经计算好了用于MACD的数据,分别存放在 data[['macd-m', 'macd-h', 'macd-s']] 中,addplot应该按照以下方法设置:

# 生成一个空列表用于存储多个addplot
ap = []
# 在ax3图表中绘制 MACD指标中的快线和慢线
ap.append(mpf.make_addplot(plot_data[['macd-m', 'macd-s']], ax=ax3))
# 使用柱状图绘制快线和慢线的差值,根据差值的数值大小,分别用红色和绿色填充
# 红色和绿色部分需要分别填充,因此先生成两组数据,分别包含大于零和小于等于零的数据
bar_r = np.where(plot_data['macd-h'] > 0, plot_data['macd-h'], 0)
bar_g = np.where(plot_data['macd-h'] <= 0, plot_data['macd-h'], 0)
# 使用柱状图填充(type='bar'),设置颜色分别为红色和绿色
ap.append(mpf.make_addplot(bar_r, type='bar', color='red', ax=ax3))
ap.append(mpf.make_addplot(bar_g, type='bar', color='green', ax=ax3))

将上面的代码添加到前一节的示例代码中,并运行后得到结果,可以看到MACD指标已经可以显示出来了:
在这里插入图片描述

实现鼠标拖动平移交互功能

到现在为止,我们的K线图已经实现了基本的信息显示功能,不过,这个K线图还是静态的,下面我们将开始重头戏:实现鼠标拖动平移功能,使K线图动态可交互。

为了实现鼠标对图表的控制,我们需要理解一些关于事件响应和面向对象的程序设计方法,对初学者来说可能不是那么友好,不过我会尽量把过程讲清楚,便于大家的理解。

首先,我们需要理解数据的平移在数据层面上是如何实现的。查看示例数据的信息,可以看到实际上示例数据包含488个交易日,从2018-01-02 到 2020-01-03,然而,我们在绘制K线图的时候,仅仅截取了其中的100个交易日的数据,这100个交易日的数据从第100个交易日开始,到第199个交易日结束。更宽泛地讲,我们可以说截取从第N个交易日开始,到第N+99个交易日结束的所有数据,显示在K线图上。
在这里插入图片描述
那么实际上,如果我们希望在K线图上将画面向右平移,以看到更早交易日的K线图,只需要将N减少,视野向左移动,反之则将N增大,视野向右移动,K线图中的画面向左平移,因此,所谓K线图的平移,实际上是通过N的大小来控制的,减少N则画面向右平移,增加N则画面向左平移。

同时,在K线图平移的过程中,图像是需要不断刷新的,最简单的做法就是当N值改变时,不断地刷新画面,清除原来的内容并重新绘制新的K线图,为了方便地实现画面刷新的功能,需要定义一个刷新画面的函数,同时,还需要保存N值。所以,我们最好是定义一个K线图类,将画面刷新函数和属性N都封装起来,便于使用。

class InterCandle: # 定义一个交互K线图类
    def __init__(self, data, my_style):
        # 初始化交互式K线图对象,历史数据作为唯一的参数用于初始化对象
        self.data = data
        self.style = my_style
        # 设置初始化的K线图显示区间起点为0,即显示第0到第99个交易日的数据(前100个数据)
        self.idx_start = 0
        
        # 初始化figure对象,在figure上建立三个Axes对象并分别设置好它们的位置和基本属性
        self.fig = mpf.figure(style=my_style, figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
        fig = self.fig
        self.ax1 = fig.add_axes([0.08, 0.25, 0.88, 0.60])
        self.ax2 = fig.add_axes([0.08, 0.15, 0.88, 0.10], sharex=self.ax1)
        self.ax2.set_ylabel('volume')
        self.ax3 = fig.add_axes([0.08, 0.05, 0.88, 0.10], sharex=self.ax1)
        self.ax3.set_ylabel('macd')
        # 初始化figure对象,在figure上预先放置文本并设置格式,文本内容根据需要显示的数据实时更新
        # 初始化时,所有的价格数据都显示为空字符串
        self.t1 = fig.text(0.50, 0.94, 'TITLE', **title_font)
        self.t2 = fig.text(0.12, 0.90, '开/收: ', **normal_label_font)
        self.t3 = fig.text(0.14, 0.89, '', **large_red_font)
        self.t4 = fig.text(0.14, 0.86, '', **small_red_font)
        self.t5 = fig.text(0.22, 0.86, '', **small_red_font)
        self.t6 = fig.text(0.12, 0.86, '', **normal_label_font)
        self.t7 = fig.text(0.40, 0.90, '高: ', **normal_label_font)
        self.t8 = fig.text(0.40, 0.90, '', **small_red_font)
        self.t9 = fig.text(0.40, 0.86, '低: ', **normal_label_font)
        self.t10 = fig.text(0.40, 0.86, '', **small_green_font)
        self.t11 = fig.text(0.55, 0.90, '量(万手): ', **normal_label_font)
        self.t12 = fig.text(0.55, 0.90, '', **normal_font)
        self.t13 = fig.text(0.55, 0.86, '额(亿元): ', **normal_label_font)
        self.t14 = fig.text(0.55, 0.86, '', **normal_font)
        self.t15 = fig.text(0.70, 0.90, '涨停: ', **normal_label_font)
        self.t16 = fig.text(0.70, 0.90, '', **small_red_font)
        self.t17 = fig.text(0.70, 0.86, '跌停: ', **normal_label_font)
        self.t18 = fig.text(0.70, 0.86, '', **small_green_font)
        self.t19 = fig.text(0.85, 0.90, '均价: ', **normal_label_font)
        self.t20 = fig.text(0.85, 0.90, '', **normal_font)
        self.t21 = fig.text(0.85, 0.86, '昨收: ', **normal_label_font)
        self.t22 = fig.text(0.85, 0.86, '', **normal_font)

    def refresh_plot(self, idx_start):
        """ 根据最新的参数,重新绘制整个图表
        """
        all_data = self.data
        plot_data = all_data.iloc[idx_start: idx_start + 100]

        ap = []
        # 添加K线图重叠均线
        ap.append(mpf.make_addplot(plot_data[['MA5', 'MA10', 'MA20', 'MA60']], ax=self.ax1))
        # 添加指标MACD
        ap.append(mpf.make_addplot(plot_data[['macd-m', 'macd-s']], ax=self.ax3))
        bar_r = np.where(plot_data['macd-h'] > 0, plot_data['macd-h'], 0)
        bar_g = np.where(plot_data['macd-h'] <= 0, plot_data['macd-h'], 0)
        ap.append(mpf.make_addplot(bar_r, type='bar', color='red', ax=self.ax3))
        ap.append(mpf.make_addplot(bar_g, type='bar', color='green', ax=self.ax3))
        # 绘制图表
        mpf.plot(plot_data,
                 ax=self.ax1,
                 volume=self.ax2,
                 addplot=ap,
                 type='candle',
                 style=self.style,
                 datetime_format='%Y-%m',
                 xrotation=0)
        self.fig.show()

    def refresh_texts(self, display_data):
        """ 更新K线图上的价格文本
        """
        # display_data是一个交易日内的所有数据,将这些数据分别填入figure对象上的文本中
        self.t3.set_text(f'{np.round(display_data["open"], 3)} / {np.round(display_data["close"], 3)}')
        self.t4.set_text(f'{display_data["change"]}')
        self.t5.set_text(f'[{np.round(display_data["pct_change"], 2)}%]')
        self.t6.set_text(f'{display_data.name.date()}')
        self.t8.set_text(f'{display_data["high"]}')
        self.t10.set_text(f'{display_data["low"]}')
        self.t12.set_text(f'{np.round(display_data["volume"] / 10000, 3)}')
        self.t14.set_text(f'{display_data["value"]}')
        self.t16.set_text(f'{display_data["upper_lim"]}')
        self.t18.set_text(f'{display_data["lower_lim"]}')
        self.t20.set_text(f'{np.round(display_data["average"], 3)}')
        self.t22.set_text(f'{display_data["last_close"]}')
        # 根据本交易日的价格变动值确定开盘价、收盘价的显示颜色
        if display_data['change'] > 0:  # 如果今日变动额大于0,即今天价格高于昨天,今天价格显示为红色
            close_number_color = 'red'
        elif display_data['change'] < 0:  # 如果今日变动额小于0,即今天价格低于昨天,今天价格显示为绿色
            close_number_color = 'green'
        else:
            close_number_color = 'black'
        self.t1.set_color(close_number_color)
        self.t2.set_color(close_number_color)
        self.t3.set_color(close_number_color)

上面的代码其实跟我们在前几节中的代码没有什么不同,唯一的区别是他们都被放在了一个名为 InterCandle 的类中。在这个类中我们定义了三个方法:

  • __init__() 方法,初始化对象。在这里创建一个图表figure对象,生成整个K线图的布局、放置文本,同时定义一个关键的参数:idx_start,这个参数是一个整形变量,代表我们需要显示的K线区间的开始日起,如0表示从第1个交易日,100表示从第101个交易日开始,以此类推
  • refresh_plot() 方法:接受参数idx_start,显示从idx_start开始的100个交易日的K线
  • refresh_text() 方法:更新图表上的价格文本,接受某一天的全部价格数据,并把这一天的数据显示在K线图上

由于采用了面向对象的方法写程序,因此我们要显示第N天开始的100日K线的代码就变的非常简单:

# 创建一个InterCandle对象,显示风格为前面定义好的my_style风格(即中国股市惯例风格)
candle = InterCandle(data, my_style)
# 更新图表上的文本,显示第100个交易日的K线数据
candle.refresh_texts(data.iloc[100])
# 更新显示第100天开始的K线图
candle.refresh_plot(100)

在这里插入图片描述
如果要显示第1天开始的K线图,或者任意交易日开始的K线图,只需要重复上面的三条指令,并把其中的100改为0即可,不过,这种方法实现平移太笨了,不是说好的用鼠标拖动来平移吗?是的,这就是我们接下来要完成的工作。

我们如果想用鼠标实现与图表的交互,就必须借助“事件”,例如鼠标单击、移动、键盘输入,这些都是事件。 matplotlib 库允许我们建立“回调函数”来响应不同的事件,那么当我们移动鼠标或者单击鼠标的时候,这些回调函数就会自动运行,对我们的操作产生回应。在每个事件发生的同时,回调函数会接收到响应的信息,以确定哪一个按键被按下了,或者鼠标目前所处的位置,根据这些信息,我们就能控制K线图的平移了。具体怎么做呢?

设想一下,当我们使用鼠标在K线图上拖拽时,实际上包含三个过程(会发生三种事件):

  • 鼠标按键按下的事件,这时我们知道拖拽动作开始,通过事件信息,我们可以知道哪一个鼠标按键被按下,同时,最关键的,我们能知道当按键按下时鼠标的坐标(x,y),这就是拖拽的起点
  • 移动鼠标事件保持按键按下,移动鼠标,这时会持续产生“移动鼠标”事件,并且通过事件信息我们可以获知鼠标的新坐标,随着鼠标的移动,这个坐标不断更新,描绘出鼠标拖拽的轨迹
  • 鼠标按键释放事件,这时我们知道拖拽动作已经结束。
    以上三个事件定义了一个完整的鼠标拖拽过程,通过这三个事件中鼠标坐标的关系,我们就能确定新的N值(这个N值定义了K线图的起点),随着鼠标的移动,不断用新的N值更新图表,我们就得到了一个可以拖动平移的交互式动态K线图:原理如下图所示
    在这里插入图片描述
    为了实现上述功能,我们需要在本节定义的 InterCandle 类中定义三个回调函数,分别对应鼠标键按下、鼠标移动、鼠标键释放三个事件的处理,同时增加几个中间属性,用于识别鼠标键是否按下:
	# 在InterCandle类的初始化过程中,设置几个中间变量,用于存储鼠标状态
    def __init__(self, data, my_style):
    	# 鼠标按键状态,False为按键未按下,True为按键按下
        self.press = False
        # 鼠标按下时的x坐标
        self.xpress = None

        # 以下部分代码与本节开头的完全相同,为节省篇幅,略去不表
		...
		# 下面的代码在__init__()中,告诉matplotlib哪些回调函数用于响应哪些事件
		# 鼠标按下事件与self.on_press回调函数绑定
        fig.canvas.mpl_connect('button_press_event', self.on_press)
        # 鼠标按键释放事件与self.on_release回调函数绑定
        fig.canvas.mpl_connect('button_release_event', self.on_release)
        # 鼠标移动事件与self.on_motion回调函数绑定
        fig.canvas.mpl_connect('motion_notify_event', self.on_motion)

    def on_press(self, event):
    	# 当鼠标按键按下时,调用该函数,event为事件信息,是一个dict对象,包含事件相关的信息
    	# 如坐标、按键类型、是否在某个Axes对象内等等
    	# event.inaxes可用于判断事件发生时,鼠标是否在某个Axes内,在这里我们指定,只有鼠
    	# 标在ax1内时,才能平移K线图,否则就退出事件处理函数
        if not event.inaxes == self.ax1:
            return
        # 检查是否按下了鼠标左键,如果不是左键,同样退出事件处理函数
        if event.button != 1:
            return
        # 如果鼠标在ax1范围内,且按下了左键,条件满足,设置鼠标状态为pressed
        self.pressed = True
        # 同时记录鼠标按下时的x坐标,退出函数,等待鼠标移动事件发生
        self.xpress = event.xdata
	
	# 鼠标移动事件处理
    def on_motion(self, event):
        # 如果鼠标按键没有按下pressed == False,则什么都不做,退出处理函数
        if not self.pressed:
            return
        # 如果移动出了ax1的范围,也退出处理函数
        if not event.inaxes == self.ax1:
            return
        # 如果鼠标在ax1范围内,且左键按下,则开始计算dx,并根据dx计算新的K线图起点
        dx = int(event.xdata - self.xpress)
        # 前面介绍过了,新的起点N(new) = N - dx
        new_start = self.idx_start - dx
        # 设定平移的左右界限,如果平移后超出界限,则不再平移
        if new_start <= 0:
            new_start = 0
        if new_start >= len(self.data) - 100:
            new_start = len(self.data) - 100
		# 清除各个图表Axes中的内容,准备以新的起点重新绘制
		self.ax1.clear()
		self.ax2.clear()
		self.ax3.clear()
		# 更新图表上的文字、以新的起点开始绘制K线图
        self.refresh_texts(self.data.iloc[new_start])
        self.refresh_plot(new_start)
	
	# 鼠标按键释放
    def on_release(self, event):
    	# 按键释放后,设置鼠标的pressed为False
        self.pressed = False
        # 此时别忘了最后一次更新K线图的起点,否则下次拖拽的时候就不会从这次的起点开始移动了
        dx = int(event.xdata - self.xpress)
        self.idx_start -= dx
        if self.idx_start <= 0:
            self.idx_start = 0
        if self.idx_start >= len(self.data) - 100:
            self.idx_start = len(self.data) - 100

通过上面代码,我们对 InterCandle 类进行了一番改造(为了节省篇幅,省略了未改动的部分):

  • __init__() 初始化方法中增加了几个状态变量,同时,非常关键的一步,绑定了事件处理函数
  • 定义了三个事件处理函数:
    • on_press() 函数记录鼠标的初始位置、设置鼠标按键状态
    • on_motions() 函数计算鼠标距离原点的水平距离,并将K线图的起点相应调整,重新绘制K线图
    • on_release()释放鼠标状态,并最后一次更新K线图的起点

怎么样?并不难吧?更新 InterCandle 类后,测试一下吧:

candle = InterCandle(data, my_style)
candle.refresh_texts(data.iloc[150])
candle.refresh_plot(150)

在这里插入图片描述
虽然FPS很低,但是K线图已经动起来了

因为我们用了最简单的 axes.clear() + 重新绘制的方式实现交互式动态图,这种方法虽然简单,但是性能比较差。其实, matplotlib 还提供了另一种高性能的动态图表实现方法 blitting ,未来如果朋友们有需求,可以尝试改进一下图表的性能

实现鼠标滚轮缩放

可以使用鼠标拖拽来平移的动态K线图,了解了 matplotlib 的事件控制机制和回调函数的工作方式,有了这些武器,下面就可以添加更多的交互功能了。

在前面一节,我们使用一个固定的参数N(代码中是idx_start参数)来控制所显示的K线的起点,固定显示从N开始以后100个交易日的K线柱,那么如何实现鼠标滚轮缩放呢?想必朋友们都很聪明,早已想到了,我们可以把100个交易日变成R个交易日,动态控制R的值就可以了。

因此,我们需要首先在 InterCandle 类中增加一个参数, idx_range() ,用来控制可显示的交易日的数量,另外,创建一个新的回调函数 on_scroll()

class InterCandle():
    def __init__(self, data, my_style):
        ... # 其他代码未发生变化,因而省略
        self.idx_range = 100  # 控制K线图的显示范围大小
        ...
        # 将新增的回调函数on_scroll与鼠标滚轮事件绑定起来
        self.fig.canvas.mpl_connect('scroll_event', on_scroll)
    ...
    def on_scroll(self, event):
        # 仅当鼠标滚轮在axes1范围内滚动时起作用
        if event.inaxes != self.ax1:
            return
        if event.button == 'down':
            # 缩小20%显示范围
            scale_factor = 0.8
        if event.button == 'up':
            # 放大20%显示范围
            scale_factor = 1.2
		# 设置K线的显示范围大小
        self.idx_range = int(self.idx_range * scale_factor)
        # 限定可以显示的K线图的范围,最少不能少于30个交易日,最大不能超过当前位置与
        # K线数据总长度的差
        data_length = len(self.data)
        if self.idx_range >= data_length - self.idx_start:
            self.idx_range = data_length - self.idx_start
        if self.idx_range <= 30:
            self.idx_range = 30 
		# 更新图表(注意因为多了一个参数idx_range,refresh_plot函数也有所改动)
        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()
        self.refresh_texts(self.data.iloc[self.idx_start])
        self.refresh_plot(self.idx_start, self.idx_range)

在上面的代码中我们在 InterCandle 类中增加了一个新的属性,并创建了一个回调函数。

  • on_scroll() 函数:当我们在图表上滚动鼠标滚轮时,回调函数会判断滚轮的滚动方向,并根据滚动方向将当前的 self.idx_range 参数放大20%或者缩小20%,除非这个参数已经越界了。调整好参数后,同样调用 refresh_plot() 方法更新图表。

需要注意的是, refresh_plot() 方法需要稍作调整,接受idx_range 参数控制显示K线的范围,限于篇幅,这里就不赘述了,文末的源码里有。

好了,看看效果吧!
在这里插入图片描述

实现双击切换指标

最后再增加一个功能,在K线图区域双击,将移动平均线切换为布林带线反复双击则在移动平均线布林带线之间循环切换

了解了前面的基础,双击切换的思路也很容易:在 on_press 函数中增加一个判断,当鼠标双击时(此时event.dbl_click属性为1,单击时为0),修改ax1的addplot,并刷新图表即可。因此on_press方法应做如下修改:

    def on_press(self, event):
        if not event.inaxes == self.ax1:
            return
        if event.button != 1:
            return
        self.pressed = True
        self.xpress = event.xdata

        # 切换当前ma类型, 在ma、bb、none之间循环
        if event.inaxes == self.ax1 and event.dblclick == 1:
            if self.avg_type == 'ma':
                self.avg_type = 'bb'
            elif self.avg_type == 'bb':
                self.avg_type = 'none'
            else:
                self.avg_type = 'ma'
        # 切换当前indicator类型,在macd/dma/rsi/kdj之间循环
        if event.inaxes == self.ax3 and event.dblclick == 1:
            if self.indicator == 'macd':
                self.indicator = 'dma'
            elif self.indicator == 'dma':
                self.indicator = 'rsi'
            elif 
                self.indicator = 'macd'

我做的 测试数据 中要包含了布林带线、dema、rsi等指标的值,所以可以直接使用 mpf.make_addplot() 方法将这些数据加入到K线图中:

    def refresh_plot(self, idx_start, idx_range):
        ...  # 此处省略
        # 添加K线图重叠均线,根据均线类型添加移动均线或布林带线
        if self.avg_type == 'ma':
            ap.append(mpf.make_addplot(plot_data[['MA5', 'MA10', 'MA20', 'MA60']], ax=self.ax1))
        elif self.avg_type == 'bb':
            ap.append(mpf.make_addplot(plot_data[['bb-u', 'bb-m', 'bb-l']], ax=self.ax1))
        # 添加指标,根据指标类型添加MACD或RSI或DEMA
        if self.indicator == 'macd':
            ap.append(mpf.make_addplot(plot_data[['macd-m', 'macd-s']], ylabel='macd', ax=self.ax3))
            bar_r = np.where(plot_data['macd-h'] > 0, plot_data['macd-h'], 0)
            bar_g = np.where(plot_data['macd-h'] <= 0, plot_data['macd-h'], 0)
            ap.append(mpf.make_addplot(bar_r, type='bar', color='red', ax=self.ax3))
            ap.append(mpf.make_addplot(bar_g, type='bar', color='green', ax=self.ax3))
        elif self.indicator == 'rsi':
            ap.append(mpf.make_addplot([75] * len(plot_data), color=(0.75, 0.6, 0.6), ax=self.ax3))
            ap.append(mpf.make_addplot([30] * len(plot_data), color=(0.6, 0.75, 0.6), ax=self.ax3))
            ap.append(mpf.make_addplot(plot_data['rsi'], ylabel='rsi', ax=self.ax3))
        else:  # indicator == 'dema'
            ap.append(mpf.make_addplot(plot_data['dema'], ylabel='dema', ax=self.ax3))
        # 绘制图表
        ...  # 此处省略

使用键盘方向键平移缩放K线图及切换指标

    # 键盘按下处理
    def on_key_press(self, event):
        data_length = len(self.data)
        if event.key == 'a':  # avg_type, 在ma,bb,none之间循环
            if self.avg_type == 'ma':
                self.avg_type = 'bb'
            elif self.avg_type == 'bb':
                self.avg_type = 'none'
            elif self.avg_type == 'none':
                self.avg_type = 'ma'
        elif event.key == 'up':  # 向上,看仔细1倍
            if self.idx_range > 60:
                self.idx_range = int(self.idx_range / 2)
        elif event.key == 'down':  # 向下,看多1倍标的
            if self.idx_range <= 480:
                self.idx_range = self.idx_range * 2
        elif event.key == 'left':  
            if self.idx_start > self.idx_range:
                self.idx_start = self.idx_start - self.idx_range
        elif event.key == 'right':
            if self.idx_start < data_length - self.idx_range:
                self.idx_start = self.idx_start + self.idx_range
        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()
        self.refresh_texts(self.data.iloc[self.idx_start])
        self.refresh_plot(self.idx_start, self.idx_range)

上面的代码实现了按键切换的功能,别忘了添加事件的回调函数:
fig.canvas.mpl_connect('key_press_event', self.on_key_press)

完整代码实现

# coding=utf-8
# inter_candle.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mplfinance as mpf

# 读取示例数据
data = pd.read_csv('test_data.csv', index_col=0)
data.index = pd.to_datetime(data.index)

my_color = mpf.make_marketcolors(up='r',
                                 down='g',
                                 edge='inherit',
                                 wick='inherit',
                                 volume='inherit')
my_style = mpf.make_mpf_style(marketcolors=my_color,
                                  figcolor='(0.82, 0.83, 0.85)',
                                  gridcolor='(0.82, 0.83, 0.85)')
# 定义各种字体
title_font = {'fontname': 'pingfang HK',
              'size':     '16',
              'color':    'black',
              'weight':   'bold',
              'va':       'bottom',
              'ha':       'center'}
large_red_font = {'fontname': 'Arial',
                  'size':     '24',
                  'color':    'red',
                  'weight':   'bold',
                  'va':       'bottom'}
large_green_font = {'fontname': 'Arial',
                    'size':     '24',
                    'color':    'green',
                    'weight':   'bold',
                    'va':       'bottom'}
small_red_font = {'fontname': 'Arial',
                  'size':     '12',
                  'color':    'red',
                  'weight':   'bold',
                  'va':       'bottom'}
small_green_font = {'fontname': 'Arial',
                    'size':     '12',
                    'color':    'green',
                    'weight':   'bold',
                    'va':       'bottom'}
normal_label_font = {'fontname': 'pingfang HK',
                     'size':     '12',
                     'color':    'black',
                     'weight':   'normal',
                     'va':       'bottom',
                     'ha':       'right'}
normal_font = {'fontname': 'Arial',
               'size':     '12',
               'color':    'black',
               'weight':   'normal',
               'va':       'bottom',
               'ha':       'left'}

class InterCandle:
    def __init__(self, data, my_style):
        self.pressed = False
        self.xpress = None

        # 初始化交互式K线图对象,历史数据作为唯一的参数用于初始化对象
        self.data = data
        self.style = my_style
        # 设置初始化的K线图显示区间起点为0,即显示第0到第99个交易日的数据(前100个数据)
        self.idx_start = 0
        self.idx_range = 100
        # 设置ax1图表中显示的均线类型
        self.avg_type = 'ma'
        self.indicator = 'macd'

        # 初始化figure对象,在figure上建立三个Axes对象并分别设置好它们的位置和基本属性
        self.fig = mpf.figure(style=my_style, figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
        fig = self.fig
        self.ax1 = fig.add_axes([0.08, 0.25, 0.88, 0.60])
        self.ax2 = fig.add_axes([0.08, 0.15, 0.88, 0.10], sharex=self.ax1)
        self.ax2.set_ylabel('volume')
        self.ax3 = fig.add_axes([0.08, 0.05, 0.88, 0.10], sharex=self.ax1)
        self.ax3.set_ylabel('macd')
        # 初始化figure对象,在figure上预先放置文本并设置格式,文本内容根据需要显示的数据实时更新
        self.t1 = fig.text(0.50, 0.94, '513100.SH - 纳斯达克指数ETF基金', **title_font)
        self.t2 = fig.text(0.12, 0.90, '开/收: ', **normal_label_font)
        self.t3 = fig.text(0.14, 0.89, f'', **large_red_font)
        self.t4 = fig.text(0.14, 0.86, f'', **small_red_font)
        self.t5 = fig.text(0.22, 0.86, f'', **small_red_font)
        self.t6 = fig.text(0.12, 0.86, f'', **normal_label_font)
        self.t7 = fig.text(0.40, 0.90, '高: ', **normal_label_font)
        self.t8 = fig.text(0.40, 0.90, f'', **small_red_font)
        self.t9 = fig.text(0.40, 0.86, '低: ', **normal_label_font)
        self.t10 = fig.text(0.40, 0.86, f'', **small_green_font)
        self.t11 = fig.text(0.55, 0.90, '量(万手): ', **normal_label_font)
        self.t12 = fig.text(0.55, 0.90, f'', **normal_font)
        self.t13 = fig.text(0.55, 0.86, '额(亿元): ', **normal_label_font)
        self.t14 = fig.text(0.55, 0.86, f'', **normal_font)
        self.t15 = fig.text(0.70, 0.90, '涨停: ', **normal_label_font)
        self.t16 = fig.text(0.70, 0.90, f'', **small_red_font)
        self.t17 = fig.text(0.70, 0.86, '跌停: ', **normal_label_font)
        self.t18 = fig.text(0.70, 0.86, f'', **small_green_font)
        self.t19 = fig.text(0.85, 0.90, '均价: ', **normal_label_font)
        self.t20 = fig.text(0.85, 0.90, f'', **normal_font)
        self.t21 = fig.text(0.85, 0.86, '昨收: ', **normal_label_font)
        self.t22 = fig.text(0.85, 0.86, f'', **normal_font)

        fig.canvas.mpl_connect('button_press_event', self.on_press)
        fig.canvas.mpl_connect('button_release_event', self.on_release)
        fig.canvas.mpl_connect('motion_notify_event', self.on_motion)
        fig.canvas.mpl_connect('key_press_event', self.on_key_press)
        fig.canvas.mpl_connect('scroll_event', self.on_scroll)

    def refresh_plot(self, idx_start, idx_range):
        """ 根据最新的参数,重新绘制整个图表
        """
        all_data = self.data
        plot_data = all_data.iloc[idx_start: idx_start + idx_range]

        ap = []
        # 添加K线图重叠均线,根据均线类型添加移动均线或布林带线
        if self.avg_type == 'ma':
            ap.append(mpf.make_addplot(plot_data[['MA5', 'MA10', 'MA20', 'MA60']], ax=self.ax1))
        elif self.avg_type == 'bb':
            ap.append(mpf.make_addplot(plot_data[['bb-u', 'bb-m', 'bb-l']], ax=self.ax1))
        # 添加指标,根据指标类型添加MACD或RSI或DEMA
        if self.indicator == 'macd':
            ap.append(mpf.make_addplot(plot_data[['macd-m', 'macd-s']], ylabel='macd', ax=self.ax3))
            bar_r = np.where(plot_data['macd-h'] > 0, plot_data['macd-h'], 0)
            bar_g = np.where(plot_data['macd-h'] <= 0, plot_data['macd-h'], 0)
            ap.append(mpf.make_addplot(bar_r, type='bar', color='red', ax=self.ax3))
            ap.append(mpf.make_addplot(bar_g, type='bar', color='green', ax=self.ax3))
        elif self.indicator == 'rsi':
            ap.append(mpf.make_addplot([75] * len(plot_data), color=(0.75, 0.6, 0.6), ax=self.ax3))
            ap.append(mpf.make_addplot([30] * len(plot_data), color=(0.6, 0.75, 0.6), ax=self.ax3))
            ap.append(mpf.make_addplot(plot_data['rsi'], ylabel='rsi', ax=self.ax3))
        else:  # indicator == 'dema'
            ap.append(mpf.make_addplot(plot_data['dema'], ylabel='dema', ax=self.ax3))

        # 绘制图表
        mpf.plot(plot_data,
                 ax=self.ax1,
                 volume=self.ax2,
                 addplot=ap,
                 type='candle',
                 style=self.style,
                 datetime_format='%Y-%m',
                 xrotation=0)

        plt.show()

    def refresh_texts(self, display_data):
        """ 更新K线图上的价格文本
        """
        # display_data是一个交易日内的所有数据,将这些数据分别填入figure对象上的文本中
        self.t3.set_text(f'{np.round(display_data["open"], 3)} / {np.round(display_data["close"], 3)}')
        self.t4.set_text(f'{np.round(display_data["change"], 3)}')
        self.t5.set_text(f'[{np.round(display_data["pct_change"], 3)}%]')
        self.t6.set_text(f'{display_data.name.date()}')
        self.t8.set_text(f'{np.round(display_data["high"], 3)}')
        self.t10.set_text(f'{np.round(display_data["low"], 3)}')
        self.t12.set_text(f'{np.round(display_data["volume"] / 10000, 3)}')
        self.t14.set_text(f'{display_data["value"]}')
        self.t16.set_text(f'{np.round(display_data["upper_lim"], 3)}')
        self.t18.set_text(f'{np.round(display_data["lower_lim"], 3)}')
        self.t20.set_text(f'{np.round(display_data["average"], 3)}')
        self.t22.set_text(f'{np.round(display_data["last_close"], 3)}')
        # 根据本交易日的价格变动值确定开盘价、收盘价的显示颜色
        if display_data['change'] > 0:  # 如果今日变动额大于0,即今天价格高于昨天,今天价格显示为红色
            close_number_color = 'red'
        elif display_data['change'] < 0:  # 如果今日变动额小于0,即今天价格低于昨天,今天价格显示为绿色
            close_number_color = 'green'
        else:
            close_number_color = 'black'
        self.t3.set_color(close_number_color)
        self.t4.set_color(close_number_color)
        self.t5.set_color(close_number_color)

    def on_press(self, event):
        if not (event.inaxes == self.ax1) and (not event.inaxes == self.ax3):
            return
        if event.button != 1:
            return
        self.pressed = True
        self.xpress = event.xdata

        # 切换当前ma类型, 在ma、bb、none之间循环
        if event.inaxes == self.ax1 and event.dblclick == 1:
            if self.avg_type == 'ma':
                self.avg_type = 'bb'
            elif self.avg_type == 'bb':
                self.avg_type = 'none'
            else:
                self.avg_type = 'ma'
        # 切换当前indicator类型,在macd/dma/rsi/kdj之间循环
        if event.inaxes == self.ax3 and event.dblclick == 1:
            if self.indicator == 'macd':
                self.indicator = 'dma'
            elif self.indicator == 'dma':
                self.indicator = 'rsi'
            elif self.indicator == 'rsi':
                self.indicator = 'kdj'
            else:
                self.indicator = 'macd'

        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()
        self.refresh_plot(self.idx_start, self.idx_range)

    def on_release(self, event):
        self.pressed = False
        dx = int(event.xdata - self.xpress)
        self.idx_start -= dx
        if self.idx_start <= 0:
            self.idx_start = 0
        if self.idx_start >= len(self.data) - 100:
            self.idx_start = len(self.data) - 100

    def on_motion(self, event):
        if not self.pressed:
            return
        if not event.inaxes == self.ax1:
            return
        dx = int(event.xdata - self.xpress)
        new_start = self.idx_start - dx
        # 设定平移的左右界限,如果平移后超出界限,则不再平移
        if new_start <= 0:
            new_start = 0
        if new_start >= len(self.data) - 100:
            new_start = len(self.data) - 100
        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()

        self.refresh_texts(self.data.iloc[new_start])
        self.refresh_plot(new_start, self.idx_range)

    def on_scroll(self, event):
        # 仅当鼠标滚轮在axes1范围内滚动时起作用
        scale_factor = 1.0
        if event.inaxes != self.ax1:
            return
        if event.button == 'down':
            # 缩小20%显示范围
            scale_factor = 0.8
        if event.button == 'up':
            # 放大20%显示范围
            scale_factor = 1.2
        # 设置K线的显示范围大小
        self.idx_range = int(self.idx_range * scale_factor)
        # 限定可以显示的K线图的范围,最少不能少于30个交易日,最大不能超过当前位置与
        # K线数据总长度的差
        data_length = len(self.data)
        if self.idx_range >= data_length - self.idx_start:
            self.idx_range = data_length - self.idx_start
        if self.idx_range <= 30:
            self.idx_range = 30
            # 更新图表(注意因为多了一个参数idx_range,refresh_plot函数也有所改动)
        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()
        self.refresh_texts(self.data.iloc[self.idx_start])
        self.refresh_plot(self.idx_start, self.idx_range)

    # 键盘按下处理
    def on_key_press(self, event):
        data_length = len(self.data)
        if event.key == 'a':  # avg_type, 在ma,bb,none之间循环
            if self.avg_type == 'ma':
                self.avg_type = 'bb'
            elif self.avg_type == 'bb':
                self.avg_type = 'none'
            elif self.avg_type == 'none':
                self.avg_type = 'ma'
        elif event.key == 'up':  # 向上,看仔细1倍
            if self.idx_range > 30:
                self.idx_range = self.idx_range // 2
        elif event.key == 'down':  # 向下,看多1倍标的
            if self.idx_range <= data_length - self.idx_start:
                self.idx_range = self.idx_range * 2
        elif event.key == 'left':
            if self.idx_start > self.idx_range:
                self.idx_start = self.idx_start - self.idx_range // 2
        elif event.key == 'right':
            if self.idx_start < data_length - self.idx_range:
                self.idx_start = self.idx_start + self.idx_range //2
        self.ax1.clear()
        self.ax2.clear()
        self.ax3.clear()
        self.refresh_texts(self.data.iloc[self.idx_start])
        self.refresh_plot(self.idx_start, self.idx_range)

if __name__ == '__main__':
    candle = InterCandle(data, my_style)
    candle.idx_start = 150
    candle.idx_range = 100
    candle.refresh_texts(data.iloc[249])
    candle.refresh_plot(150, 100)