# python绘制动态k线图 (python k线图) -- drawing a dynamic candlestick (K-line) chart in Python

"""
Python  K线图,均线,MACD,高低点图

"""
# -*- coding:utf-8 -*-
import time
import pandas as pd
from pyecharts.charts import Kline, Line, Bar, Grid
from pyecharts import options as opts
from pyecharts.commons.utils import JsCode
from typing import List, Sequence, Union

"""建立mysql,postgresql等数据库链接"""
engine_mysql = create_engine('mysql+pymysql://用户名:密码@端口:/数据库名?charset=utf8')
engine_pg = create_engine("postgresql+psycopg2://用户名:,密码@端口:/数据库名", client_encoding='utf8')


def make_highest_lowest_date(df, engine_wm, ma_1=5, ma_2=26):
    """Return the dates of the swing highs/lows between moving-average crossovers.

    Steps:
    1. Align *df* to the open days of the ``trade_cal`` table between its
       first and last ``trade_date`` (rows without positive volume get
       ``vol``/``amount`` of 0, every other column is forward-filled).
    2. Compute the ``ma_1``- and ``ma_2``-day rolling means of ``close``.
    3. Split the series at the first day, every crossover of the two means
       and the last day.
    4. For each segment keep the date of the highest ``high`` when the fast
       mean is >= the slow one on the segment's second row, otherwise the
       date of the lowest ``low``.

    :param df: quotes ordered ascending by ``trade_date`` (strings such as
        ``'20220104'``) with at least ``close``, ``high``, ``low``, ``vol`` and
        ``amount`` columns.  Extra columns -- including one named ``index`` --
        are tolerated.
    :param engine_wm: connection accepted by ``pandas.read_sql`` that holds a
        ``trade_cal`` table with ``cal_date`` and ``is_open`` columns.
    :param ma_1: window of the fast moving average.
    :param ma_2: window of the slow moving average.
    :return: list of ``'YYYY-MM-DD'`` date strings, one per segment; ``[]``
        when *df* is empty.
    """
    if df.empty:
        return []
    df_nv = _align_to_trade_calendar(df, engine_wm)
    df_nv['ma_1'] = df_nv.close.rolling(ma_1).mean()  # fast moving average of close
    df_nv['ma_2'] = df_nv.close.rolling(ma_2).mean()  # slow moving average of close
    point_wm = _ma_cross_points(df_nv)
    return _segment_extremes(df_nv, point_wm)


def _align_to_trade_calendar(df, engine_wm):
    """Reindex *df* onto the open trading days spanned by its trade_date column.

    Returns a new frame with a RangeIndex and a ``trade_date`` column holding
    ``'YYYY-MM-DD'`` strings; *df* itself is not modified.
    """
    start_dt = df.trade_date.iloc[0]
    end_dt = df.trade_date.iloc[-1]
    db_nm = 'trade_cal'  # trading-calendar table
    # NOTE(review): values are interpolated into the SQL text; here they come
    # from df, so they are only as trustworthy as the caller's data.
    query = "select cal_date from {0} where cal_date >= '{1}' and  cal_date <= '{2}' " \
            "and is_open = '1';".format(db_nm, start_dt, end_dt)
    df_cal = pd.read_sql(query, con=engine_wm)
    cal_list = df_cal.cal_date.drop_duplicates().sort_values().tolist()

    # Move trade_date into the index (the original passed the truthy string
    # "False" as `drop`, which also dropped the column).
    df_wm = df.set_index('trade_date', drop=True)
    df_wm.index = pd.to_datetime(df_wm.index)
    df_wm = df_wm.reindex(pd.to_datetime(cal_list))

    # Rows without positive volume (including calendar days absent from df)
    # get vol/amount 0 instead of NaN ...
    zvalues = df_wm.loc[~(df_wm.vol > 0), ['vol', 'amount']]
    df_wm.update(zvalues.fillna(0))
    # ... and every remaining gap carries the previous row's values forward.
    df_wm = df_wm.ffill()

    # Name the index before resetting it.  With an unnamed index,
    # reset_index() calls the new column 'level_0' whenever df already has a
    # column named 'index' (e.g. from an upstream bare reset_index()), and a
    # later rename 'index' -> 'trade_date' would relabel the wrong column.
    df_nv = df_wm.rename_axis('trade_date').reset_index()
    df_nv['trade_date'] = df_nv['trade_date'].astype(str)
    return df_nv


def _ma_cross_points(df_nv):
    """Return the first date, each ma_1/ma_2 crossover date and the last date."""
    dates = df_nv.trade_date.tolist()
    if not dates:
        return []
    fast = df_nv.ma_1.tolist()
    slow = df_nv.ma_2.tolist()
    last = len(dates) - 1
    points = [dates[0]]
    for i in range(1, last):
        # A crossover at i is detected by comparing the neighbours i-1 and
        # i+1; NaN averages (warm-up period) compare False and never match.
        crossed_down = fast[i - 1] > slow[i - 1] and fast[i + 1] < slow[i + 1]
        crossed_up = fast[i - 1] < slow[i - 1] and fast[i + 1] > slow[i + 1]
        # Skip i when i-1 was just recorded, so one crossing is not reported
        # on two adjacent trading days.
        if (crossed_down or crossed_up) and points[-1] != dates[i - 1]:
            points.append(dates[i])
    if last > 0:
        points.append(dates[last])
    return points


def _segment_extremes(df_nv, points):
    """Return the high/low date of each segment between consecutive points."""
    h_l_dt = []
    for tday, nday in zip(points, points[1:]):
        try:
            in_range = (df_nv['trade_date'] >= tday) & (df_nv['trade_date'] <= nday)
            seg = df_nv[in_range].reset_index(drop=True)  # positional 0..k index
            # The second row decides the segment's direction: fast MA on top
            # means an up-leg (take the high), otherwise a down-leg (the low).
            if seg.iloc[1].ma_1 >= seg.iloc[1].ma_2:
                pos = seg.high.idxmax()
            else:
                pos = seg.low.idxmin()
            h_l_dt.append(seg.iloc[pos].trade_date)
        except (IndexError, KeyError, TypeError, ValueError) as err:
            # Best effort: report the bad segment and keep going.
            print('pgdtime.make_highest_lowest_date is err.', err)
    return h_l_dt

def data_astype(df, item_list, item5='float'):
    """Cast each column named in *item_list* to dtype *item5*, in place.

    Conversion is best effort and per column: a column that is missing or
    cannot be cast is reported with ``print`` and left unchanged, while the
    remaining columns are still converted.  The same DataFrame object is
    always returned.  (Previously the first failure aborted the loop and the
    function returned ``None``, which broke every caller that reassigns it.)

    :param df: DataFrame to modify.
    :param item_list: names of the columns to convert.
    :param item5: target dtype passed to ``Series.astype``.
    :return: *df*.
    """
    for item_nv in item_list:
        try:
            df[item_nv] = df[item_nv].astype(item5)
        except (KeyError, ValueError, TypeError) as err:
            print(err)
    return df


def draw_k(db_nm, ts_code, start_dt, end_dt, engine_wm, html_wm):
    """Read daily quotes for one code and render a K-line chart to HTML.

    Rows of ``ts_code`` with ``start_dt <= trade_date <= end_dt`` are read from
    table ``db_nm`` through ``engine_wm``.  The pyecharts ``Grid`` written to
    ``html_wm`` contains, top to bottom: candlesticks with MA5/MA10/MA20/MA60,
    volume bars, and the MACD histogram with DIF/DEA lines.

    :param db_nm: table with columns ts_code, trade_date, open, close, low,
        high, vol, amount.
    :param ts_code: code to chart (also used in the chart title).
    :param start_dt: first trade date, in the table's trade_date string format
        (e.g. ``'20220101'``).
    :param end_dt: last trade date, same format.
    :param engine_wm: connection accepted by ``pandas.read_sql``; it is also
        passed to ``make_highest_lowest_date`` (which reads ``trade_cal``).
    :param html_wm: path of the HTML file to write.

    NOTE(review): all query values are interpolated into the SQL text with
    ``str.format``; only call this with trusted arguments.
    """

    def macd(df_wm, period1=12, period2=26, period3=9):
        """Return a deduplicated, date-sorted copy of df_wm with dif/dea/macd.

        The input frame is not modified.
        """
        # EWMs depend on row order and the query has no ORDER BY, so sort
        # chronologically before computing them.  reset_index(drop=True)
        # avoids leaving a stray 'index' column in the result.
        data = (
            df_wm.drop_duplicates(subset=['ts_code', 'trade_date'], keep='first')
            .sort_values(by='trade_date', ascending=True)
            .reset_index(drop=True)
        )
        data['close'] = data['close'].astype('float')
        ema_fast = data['close'].ewm(adjust=False, alpha=2 / (period1 + 1), ignore_na=True).mean()
        ema_slow = data['close'].ewm(adjust=False, alpha=2 / (period2 + 1), ignore_na=True).mean()
        data['dif'] = ema_fast - ema_slow
        data['dea'] = data['dif'].ewm(adjust=False, alpha=2 / (period3 + 1), ignore_na=True).mean()
        data['macd'] = 2 * (data['dif'] - data['dea'])
        return data

    def calculate_ma(data, day_count: int) -> List[Union[float, str]]:
        """Simple moving average of close, rounded to 2 decimals.

        Positions before the first full ``day_count`` window are ``"-"``,
        which ECharts treats as an empty value.
        """
        closes = [float(row[1]) for row in data["datas"]]  # row = [open, close, ...]
        result: List[Union[float, str]] = []
        for i in range(len(closes)):
            # The first full window ends at index day_count - 1.
            if i < day_count - 1:
                result.append("-")
                continue
            window = closes[i - day_count + 1:i + 1]
            result.append(float("%.2f" % (sum(window) / day_count)))
        return result

    def split_data(origin_data) -> dict:
        """Split chart rows into the per-series lists pyecharts expects.

        Row layout: [trade_date, open, close, low, high, vol, k_form, macd, dif, dea].
        """
        return {
            "datas": [row[1:] for row in origin_data],
            "times": [row[0] for row in origin_data],
            "vols": [int(row[5]) for row in origin_data],
            "macds": [row[7] for row in origin_data],
            "difs": [row[8] for row in origin_data],
            "deas": [row[9] for row in origin_data],
        }

    def split_data_part(data) -> Sequence:
        """Build mark-line segments between rows flagged by k_form (datas[i][5]).

        Currently disabled: the segments are computed but an empty list is
        returned, so no mark lines / mark areas are drawn.
        """
        mark_line_data = []
        idx = 0
        tag = 0
        vols = 0
        for i in range(len(data["times"])):
            if data["datas"][i][5] != 0 and tag == 0:
                idx = i
                vols = data["datas"][i][4]
                tag = 1
            if tag == 1:
                vols += data["datas"][i][4]
            if data["datas"][i][5] != 0 or tag == 1:
                mark_line_data.append(
                    [
                        {
                            "xAxis": idx,
                            "yAxis": float("%.2f" % data["datas"][idx][3])
                            if data["datas"][idx][1] > data["datas"][idx][0]
                            else float("%.2f" % data["datas"][idx][2]),
                            "value": vols,
                        },
                        {
                            "xAxis": i,
                            "yAxis": float("%.2f" % data["datas"][i][3])
                            if data["datas"][i][1] > data["datas"][i][0]
                            else float("%.2f" % data["datas"][i][2]),
                        },
                    ]
                )
                idx = i
                vols = data["datas"][i][4]
                tag = 2
            if tag == 2:
                vols += data["datas"][i][4]
            if data["datas"][i][5] != 0 and tag == 2:
                mark_line_data.append(
                    [
                        {
                            "xAxis": idx,
                            "yAxis": float("%.2f" % data["datas"][idx][3])
                            if data["datas"][i][1] > data["datas"][i][0]
                            else float("%.2f" % data["datas"][i][2]),
                            "value": "",
                        },
                        {
                            "xAxis": i,
                            "yAxis": float("%.2f" % data["datas"][i][3])
                            if data["datas"][i][1] > data["datas"][i][0]
                            else float("%.2f" % data["datas"][i][2]),
                        },
                    ]
                )
                idx = i
                vols = data["datas"][i][4]
        # Disabled by the author; return mark_line_data to draw the segments.
        return []

    def draw_chart(data, html_wm):
        """Assemble the Kline/Line/Bar grid from *data* and render it to html_wm."""
        kline = (
            Kline()
            .add_xaxis(xaxis_data=data["times"])
            .add_yaxis(
                series_name="",
                y_axis=data["datas"],
                itemstyle_opts=opts.ItemStyleOpts(
                    color="#ef232a",
                    color0="#14b143",
                    border_color="#ef232a",
                    border_color0="#14b143",
                ),
                markpoint_opts=opts.MarkPointOpts(
                    data=[
                        opts.MarkPointItem(type_="max", name="最大值"),
                        opts.MarkPointItem(type_="min", name="最小值"),
                    ]
                ),
                markline_opts=opts.MarkLineOpts(
                    label_opts=opts.LabelOpts(
                        position="middle", color="blue", font_size=15
                    ),
                    data=split_data_part(data),
                    symbol=["circle", "none"],
                ),
            )
            .set_series_opts(
                markarea_opts=opts.MarkAreaOpts(is_silent=True, data=split_data_part(data))
            )
            .set_global_opts(
                title_opts=opts.TitleOpts(title="{}_K线高低点图表".format(ts_code), pos_left="0"),
                xaxis_opts=opts.AxisOpts(
                    type_="category",
                    is_scale=False,
                    boundary_gap=False,
                    axisline_opts=opts.AxisLineOpts(is_on_zero=False),
                    splitline_opts=opts.SplitLineOpts(is_show=False),
                    split_number=20,
                    min_="dataMin",
                    max_="dataMax",
                ),
                yaxis_opts=opts.AxisOpts(
                    is_scale=True, splitline_opts=opts.SplitLineOpts(is_show=False)
                ),
                tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="line"),
                datazoom_opts=[
                    opts.DataZoomOpts(
                        is_show=False, type_="inside", xaxis_index=[0, 0], range_end=100
                    ),
                    opts.DataZoomOpts(
                        is_show=False, xaxis_index=[0, 1], pos_top="97%", range_end=100
                    ),
                    opts.DataZoomOpts(is_show=False, xaxis_index=[0, 2], range_end=100),
                ],
            )
        )

        # Moving averages drawn on top of the candlesticks.
        kline_line = Line().add_xaxis(xaxis_data=data["times"])
        for day_count, opacity in ((5, 1), (10, 0.5), (20, 0.5), (60, 0.5)):
            kline_line.add_yaxis(
                series_name="MA{}".format(day_count),
                y_axis=calculate_ma(data, day_count=day_count),
                is_smooth=True,
                linestyle_opts=opts.LineStyleOpts(opacity=opacity),
                label_opts=opts.LabelOpts(is_show=False),
            )
        kline_line.set_global_opts(
            xaxis_opts=opts.AxisOpts(
                type_="category",
                grid_index=1,
                axislabel_opts=opts.LabelOpts(is_show=False),
            ),
            yaxis_opts=opts.AxisOpts(
                grid_index=1,
                split_number=3,
                axisline_opts=opts.AxisLineOpts(is_on_zero=False),
                axistick_opts=opts.AxisTickOpts(is_show=False),
                splitline_opts=opts.SplitLineOpts(is_show=False),
                axislabel_opts=opts.LabelOpts(is_show=False),
            ),
        )
        overlap_kline_line = kline.overlap(kline_line)

        # Volume bars; colour follows the candle (close > open -> red).
        # `barData` is injected into the page by add_js_funcs below.
        bar_1 = (
            Bar()
            .add_xaxis(xaxis_data=data["times"])
            .add_yaxis(
                series_name="Volume",
                y_axis=data["vols"],
                xaxis_index=1,
                yaxis_index=1,
                label_opts=opts.LabelOpts(is_show=False),
                itemstyle_opts=opts.ItemStyleOpts(
                    color=JsCode(
                        """
                    function(params) {
                        var colorList;
                        if (barData[params.dataIndex][1] > barData[params.dataIndex][0]) {
                            colorList = '#ef232a';
                        } else {
                            colorList = '#14b143';
                        }
                        return colorList;
                    }
                    """
                    )
                ),
            )
            .set_global_opts(
                xaxis_opts=opts.AxisOpts(
                    type_="category",
                    grid_index=1,
                    axislabel_opts=opts.LabelOpts(is_show=False),
                ),
                legend_opts=opts.LegendOpts(is_show=False),
            )
        )

        # MACD histogram (red when >= 0) overlapped with the DIF/DEA lines.
        bar_2 = (
            Bar()
            .add_xaxis(xaxis_data=data["times"])
            .add_yaxis(
                series_name="MACD",
                y_axis=data["macds"],
                xaxis_index=2,
                yaxis_index=2,
                label_opts=opts.LabelOpts(is_show=False),
                itemstyle_opts=opts.ItemStyleOpts(
                    color=JsCode(
                        """
                            function(params) {
                                var colorList;
                                if (params.data >= 0) {
                                  colorList = '#ef232a';
                                } else {
                                  colorList = '#14b143';
                                }
                                return colorList;
                            }
                            """
                    )
                ),
            )
            .set_global_opts(
                xaxis_opts=opts.AxisOpts(
                    type_="category",
                    grid_index=2,
                    axislabel_opts=opts.LabelOpts(is_show=False),
                ),
                yaxis_opts=opts.AxisOpts(
                    grid_index=2,
                    split_number=4,
                    axisline_opts=opts.AxisLineOpts(is_on_zero=False),
                    axistick_opts=opts.AxisTickOpts(is_show=False),
                    splitline_opts=opts.SplitLineOpts(is_show=False),
                    axislabel_opts=opts.LabelOpts(is_show=False),
                ),
                legend_opts=opts.LegendOpts(is_show=False),
            )
        )
        line_2 = (
            Line()
            .add_xaxis(xaxis_data=data["times"])
            .add_yaxis(
                series_name="DIF",
                y_axis=data["difs"],
                xaxis_index=2,
                yaxis_index=2,
                label_opts=opts.LabelOpts(is_show=False),
            )
            .add_yaxis(
                series_name="DEA",  # was mislabeled "DIF"
                y_axis=data["deas"],
                xaxis_index=2,
                yaxis_index=2,
                label_opts=opts.LabelOpts(is_show=False),
            )
            .set_global_opts(legend_opts=opts.LegendOpts(is_show=False))
        )
        overlap_bar_line = bar_2.overlap(line_2)

        grid_chart = Grid(init_opts=opts.InitOpts(width="1400px", height="800px"))
        grid_chart.add_js_funcs("var barData = {}".format(data["datas"]))
        grid_chart.add(
            overlap_kline_line,
            grid_opts=opts.GridOpts(pos_left="3%", pos_right="1%", height="60%"),
        )
        grid_chart.add(
            bar_1,
            grid_opts=opts.GridOpts(
                pos_left="3%", pos_right="1%", pos_top="71%", height="10%"
            ),
        )
        grid_chart.add(
            overlap_bar_line,
            grid_opts=opts.GridOpts(
                pos_left="3%", pos_right="1%", pos_top="82%", height="14%"
            ),
        )
        grid_chart.render(html_wm)

    query = "select ts_code,trade_date,open,close,low,high,vol,amount from {0} where trade_date >= '{1}' " \
            "and trade_date <= '{2}' and ts_code = '{3}';".format(db_nm, start_dt, end_dt, ts_code)
    print(query)
    df1 = pd.read_sql(query, con=engine_wm)
    print('df is ok ,first num:', len(df1))

    # Deduplicated, chronologically sorted quotes with dif/dea/macd columns.
    df3 = macd(df1)
    item_list = ['open', 'close', 'low', 'high', 'vol', 'macd', 'dif', 'dea']
    df3 = data_astype(df3, item_list, item5='float')
    print('df is ok ,first num:', len(df3))

    # Flag the swing high/low days.  make_highest_lowest_date returns
    # 'YYYY-MM-DD' strings; strip the dashes to match trade_date values
    # such as '20220104'.
    h_l_dt = make_highest_lowest_date(df3, engine_wm, ma_1=5, ma_2=26)
    h_l = {d.replace('-', '') for d in h_l_dt}
    df3['k_form'] = df3['trade_date'].isin(h_l).astype(int)

    columns = ['trade_date', 'open', 'close', 'low', 'high', 'vol', 'k_form', 'macd', 'dif', 'dea']
    echarts_data = df3[columns].values.tolist()

    data = split_data(origin_data=echarts_data)
    draw_chart(data, html_wm)

if __name__ == '__main__':
    # Example run: chart one code from the PostgreSQL source and write the
    # resulting page to an HTML file.
    db_nm = 'stock_index_daily'   # source table of daily quotes
    ts_code = '399300.SZ'
    start_dt, end_dt = '20220101', '20221024'
    html_wm = "/home/test/picture/pye_k66.html"   # output HTML file
    draw_k(
        db_nm=db_nm,
        ts_code=ts_code,
        start_dt=start_dt,
        end_dt=end_dt,
        engine_wm=engine_pg,
        html_wm=html_wm,
    )

    # Sample console output from one run:
    #     df is ok ,first num: 471
    #     df is ok ,first num: 193
    #     Process finished with exit code 0
    #
    # Pgabc 2022000011
    # author : Pgabc
    # www.wmdbsoft.com