297 lines
13 KiB
Python
297 lines
13 KiB
Python
import asyncio
|
||
import json
|
||
import time
|
||
from datetime import datetime
|
||
|
||
import bt
|
||
import numpy as np
|
||
import pandas as pd
|
||
from xtquant import xtdata
|
||
import matplotlib.pyplot as plt
|
||
|
||
from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
|
||
from src.models import wance_data_storage_backtest, wance_data_stock
|
||
from src.tortoises_orm_config import init_tortoise
|
||
|
||
|
||
async def create_dual_ma_strategy(data, stock_code: str, short_window: int = 50, long_window: int = 200,
                                  overbought: int = 70, oversold: int = 30):
    """
    Build a bt strategy that trades one stock on RSI signals.

    Despite the function name, this builds an RSI strategy (it delegates to ``rsi_strategy``), not a
    dual moving-average one; the name is kept so existing callers keep working.

    Parameters:
        data: price DataFrame handed to ``rsi_strategy`` and used as the signal source.
        stock_code: stock code, used to name the strategy ``'<stock_code> RSI策略'``.
        short_window: window length of the short-term RSI.
        long_window: window length of the long-term RSI.
        overbought: overbought RSI level, forwarded to ``rsi_strategy``.
        oversold: oversold RSI level, forwarded to ``rsi_strategy``.

    Returns:
        tuple ``(strategy, signal)``: the ``bt.Strategy`` and the signal DataFrame it uses as target weights.
    """
    weights = await rsi_strategy(data, short_window=short_window, long_window=long_window,
                                 overbought=overbought, oversold=oversold)

    algo_stack = [
        bt.algos.RunDaily(),
        bt.algos.SelectAll(),            # select every security in the data
        bt.algos.WeighTarget(weights),   # target weights come from the signal frame
        bt.algos.Rebalance(),            # rebalance toward those target weights
    ]
    strategy_name = f'{stock_code} RSI策略'
    return bt.Strategy(strategy_name, algo_stack), weights
|
||
|
||
|
||
async def rsi_strategy(df, short_window=14, long_window=28, overbought=70, oversold=30):
    """
    Generate RSI-based position signals (used as target weights).

    A position is opened (1) on bars where both the short- and long-window RSI are below ``oversold``
    and closed (0) on bars where both are above ``overbought``. On all other bars the previous signal is
    carried forward. Bars before the first signal are flat (0).

    Parameters:
        df: pd.DataFrame, price data with dates as the row index and one column per stock.
        short_window: int, window length of the short-term RSI.
        long_window: int, window length of the long-term RSI.
        overbought: int, RSI level above which the position is closed.
        oversold: int, RSI level below which the position is opened.

    Returns:
        signal: pd.DataFrame with the same index and columns as ``df``; 1.0 means hold, 0.0 means flat.
    """
    # The first diff is NaN; treat it as "no change" so every row takes part in the rolling windows.
    delta = df.diff().fillna(0)
    gains = delta.clip(lower=0)
    losses = (-delta).clip(lower=0)

    def _rsi(window):
        # Simple-moving-average RSI. Rows still inside the warm-up window stay NaN, as do rows where the
        # average gain and loss are both 0 (0/0). NaN compares False below, so such rows never signal.
        # Zero average loss with a positive average gain gives gain/0 = inf, i.e. RSI 100.
        rs = gains.rolling(window=window).mean() / losses.rolling(window=window).mean()
        return 100 - 100 / (1 + rs)

    short_rsi = _rsi(short_window)
    long_rsi = _rsi(long_window)

    buy = (short_rsi < oversold) & (long_rsi < oversold)
    sell = (short_rsi > overbought) & (long_rsi > overbought)

    # Mark only the bars that carry a signal (sell wins if both fire), leaving the rest NaN, then hold
    # the most recent signal forward until the opposite one appears.
    signal = pd.DataFrame(np.nan, index=df.index, columns=df.columns)
    signal = signal.mask(buy, 1.0).mask(sell, 0.0)

    return signal.ffill().fillna(0)
|
||
|
||
|
||
def _sanitize_stat_value(value):
    """
    Return a DB-storable version of one bt statistic.

    Float NaN becomes 0.0 and +/-inf becomes +/-99999.9999; finite floats and non-float values are
    returned unchanged.
    """
    if not isinstance(value, float):
        return value
    if np.isnan(value):
        return 0.0
    if np.isinf(value):
        return 99999.9999 if value > 0 else -99999.9999
    return value


async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, short_window: int,
                                long_window: int, overbought: int = 70,
                                oversold: int = 30):
    """
    Persist one stock's RSI backtest result to ``WanceDataStorageBacktest``.

    The row is identified by ``backtest_name`` (stock code, RSI windows and thresholds). An existing row
    with that name is updated in place; otherwise a new row is created.

    Parameters:
        source_column_name: key of this strategy inside ``result``.
        result: the object returned by ``bt.run``; ``result[source_column_name]`` and ``result.stats``
            are read.
        signal: signal DataFrame produced by the RSI strategy, stored as the position history.
        stock_code: stock code.
        stock_data_series: close-price DataFrame indexed by date.
        short_window, long_window, overbought, oversold: strategy parameters recorded with the row.

    Returns:
        dict: the field values written to the table.
    """
    # (Re)initialise the ORM; this runs on every call.
    await init_tortoise()

    # Columns written to the table. The first 11 (stock_code .. backtest_name) are filled explicitly
    # below; every remaining one ('rf' onwards) is copied from the same-named row of result.stats.
    fields_to_store = [
        'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
        'price', 'returns', 'data_start_time', 'data_end_time',
        'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
        'max_drawdown', 'calmar', 'mtd', 'three_month',
        'six_month', 'ytd', 'one_year', 'three_year',
        'five_year', 'ten_year', 'incep', 'daily_sharpe',
        'daily_sortino', 'daily_mean', 'daily_vol',
        'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
        'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
        'monthly_vol', 'monthly_skew', 'monthly_kurt',
        'best_month', 'worst_month', 'yearly_sharpe',
        'yearly_sortino', 'yearly_mean', 'yearly_vol',
        'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
        'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
        'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
    ]

    strategy_result = result[source_column_name]
    stats = result.stats

    # Close prices keyed by 'YYYYMMDD' date, one inner dict per row; missing prices are stored as 0.
    close_price_by_day = stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
        time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')

    data_to_store = {
        'stock_code': stock_code,
        'strategy_name': "RSI策略",
        'stock_close_price': json.dumps(close_price_by_day),
        'daily_price': convert_pandas_to_json_serializable(strategy_result.daily_prices),
        'price': convert_pandas_to_json_serializable(strategy_result.prices),
        'returns': convert_pandas_to_json_serializable(strategy_result.returns.fillna(0)),
        'data_start_time': pd.to_datetime(stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
        'data_end_time': pd.to_datetime(stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
        'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
        'position': convert_pandas_to_json_serializable(signal),
        'backtest_name': f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}-overbought{overbought}-oversold{oversold}',
        'indicator_type': 'RSI',
        'indicator_information': json.dumps(
            {'short_window': short_window, 'long_window': long_window, 'overbought': overbought, 'oversold': oversold})
    }

    # Fill every listed field not set above from result.stats. Selecting by membership instead of a
    # hard-coded slice offset ensures 'rf' (the first stats field) is not skipped.
    for field in fields_to_store:
        if field not in data_to_store:
            data_to_store[field] = _sanitize_stat_value(stats.loc[field].iloc[0])

    # Upsert keyed on backtest_name.
    backtest_model = wance_data_storage_backtest.WanceDataStorageBacktest
    existing_record = await backtest_model.filter(
        backtest_name=data_to_store['backtest_name']
    ).first()

    if existing_record:
        await backtest_model.filter(id=existing_record.id).update(**data_to_store)
    else:
        await backtest_model.create(**data_to_store)

    return data_to_store
|
||
|
||
|
||
async def run_rsi_backtest(field_list: list,
                           stock_list: list,
                           period: str = '1d',
                           start_time: str = '',
                           end_time: str = '',
                           count: int = 100,
                           dividend_type: str = 'none',
                           fill_data: bool = True,
                           data_dir: str = '',
                           short_window: int = 50,
                           long_window: int = 200,
                           overbought: int = 70,
                           oversold: int = 30
                           ):
    """
    Run (or reuse from the DB) an RSI backtest for each stock in ``stock_list``.

    For each stock, a stored result whose ``backtest_end_time`` is today is reused. Otherwise the
    backtest is run on the locally loaded ``close_<stock_code>`` column and persisted via
    ``storage_backtest_data``.

    Parameters:
        field_list, stock_list, period, start_time, end_time, count, dividend_type, fill_data, data_dir:
            passed through to ``get_local_data``.
        short_window, long_window: RSI window lengths.
        overbought, oversold: RSI thresholds.

    Returns:
        list of single-entry dicts ``{'<stock_code> RSI策略': value}``. ``value`` is either the stored ORM
        record (cache hit) or the dict written by ``storage_backtest_data`` (fresh run). If any exception
        occurs it is printed and the function returns None.
    """
    try:
        # Same initialisation storage_backtest_data performs; needed here because the cache lookup
        # below queries the DB before any result is stored.
        await init_tortoise()

        results_list = []

        data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
                                    fill_data, data_dir)
        today = int(datetime.now().strftime('%Y%m%d'))

        for stock_code in stock_list:
            data_column_name = f'close_{stock_code}'
            source_column_name = f'{stock_code} RSI策略'
            # Must match the backtest_name written by storage_backtest_data, otherwise the cache never hits.
            backtest_name = (f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}'
                             f'-overbought{overbought}-oversold{oversold}')

            db_result_data = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
                backtest_name=backtest_name)

            if db_result_data and db_result_data[0].backtest_end_time == today:
                # Already backtested today: reuse the stored record instead of re-running.
                results_list.append({source_column_name: db_result_data[0]})
            elif data_column_name in data.columns:
                # Single-column close-price frame, renamed so the strategy sees a 'close' column.
                stock_data_series = data[[data_column_name]].rename(columns={data_column_name: 'close'})

                strategy, signal = await create_dual_ma_strategy(stock_data_series, stock_code,
                                                                 short_window=short_window,
                                                                 long_window=long_window,
                                                                 overbought=overbought,
                                                                 oversold=oversold)
                backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
                result = bt.run(backtest)

                data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
                                                            stock_data_series, short_window, long_window,
                                                            overbought, oversold)
                results_list.append({source_column_name: data_to_store})
            else:
                print(f"数据中缺少列: {data_column_name}")

        return results_list

    except Exception as e:
        print(f"Error occurred: {e}")
|
||
|
||
|
||
async def start_rsi_backtest_service(field_list: list,
                                     stock_list: list,
                                     period: str = '1d',
                                     start_time: str = '',
                                     end_time: str = '',
                                     count: int = -1,
                                     dividend_type: str = 'none',
                                     fill_data: bool = True,
                                     data_dir: str = '',
                                     short_window: int = 50,
                                     long_window: int = 200,
                                     overbought: int = 70,
                                     oversold: int = 30
                                     ):
    """
    Return today's stored RSI backtests for ``stock_list``, running the backtest when any are missing.

    If every stock already has a stored result whose ``backtest_end_time`` is today, those stored ORM
    records are returned, concatenated in ``stock_list`` order. Otherwise ``run_rsi_backtest`` is called
    for the whole list with the same arguments, and its return value is passed through.

    Returns:
        list of stored ORM records (all cached), the ``run_rsi_backtest`` result, or None when
        ``stock_list`` is empty.
    """
    if not stock_list:
        return None

    today = int(datetime.now().strftime('%Y%m%d'))
    cached_records = []

    for stock_code in stock_list:
        # Must match the backtest_name written by storage_backtest_data.
        backtest_name = (f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}'
                         f'-overbought{overbought}-oversold{oversold}')
        db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
            backtest_name=backtest_name)

        if not (db_result and db_result[0].backtest_end_time == today):
            # At least one stock is missing or stale: run the whole list. run_rsi_backtest itself
            # reuses today's stored results for the stocks that already have one.
            return await run_rsi_backtest(
                field_list=field_list,
                stock_list=stock_list,
                period=period,
                start_time=start_time,
                end_time=end_time,
                count=count,
                dividend_type=dividend_type,
                fill_data=fill_data,
                data_dir=data_dir,
                short_window=short_window,
                long_window=long_window,
                overbought=overbought,
                oversold=oversold
            )

        cached_records.extend(db_result)

    return cached_records
|
||
|
||
|
||
async def init_backtest_db():
    """
    Populate the backtest table for every stock in ``WanceDataStock``.

    Each stock is backtested with every preset (short_window, long_window) pair below, using
    overbought=70 and oversold=30. A line is printed per run, reporting success, or failure when the
    service returns None (its error path).
    """
    rsi_window_presets = [
        {"short_window": 3, "long_window": 6},
        {"short_window": 6, "long_window": 12},
        {"short_window": 12, "long_window": 24},
        {"short_window": 14, "long_window": 18},
        # NOTE(review): short_window > long_window in this preset — confirm it is intended.
        {"short_window": 15, "long_window": 10},
    ]

    await init_tortoise()
    stocks = await wance_data_stock.WanceDataStock.all()

    for stock in stocks:
        for preset in rsi_window_presets:
            short_window = preset['short_window']
            long_window = preset['long_window']
            source_column_name = f'{stock.stock_code} RSI策略 RSI{short_window}-RSI{long_window}'

            result = await start_rsi_backtest_service(field_list=['close', 'time'],
                                                      stock_list=[stock.stock_code],
                                                      short_window=short_window,
                                                      long_window=long_window,
                                                      overbought=70,
                                                      oversold=30)

            if result is None:
                print(f"回测失败 {source_column_name}")
            else:
                print(f"回测成功 {source_column_name}")
|
||
|
||
|
||
if __name__ == '__main__':
    # Ad-hoc test run of the RSI backtest for two sample stocks.
    demo_params = dict(
        field_list=['close', 'time'],
        stock_list=['601222.SH', '601677.SH'],
        count=-1,
        short_window=10,
        long_window=30,
        overbought=70,
        oversold=30,
    )
    asyncio.run(run_rsi_backtest(**demo_params))

    # To (re)populate the backtest table for every stock instead, use:
    # asyncio.run(init_backtest_db())
|