diff --git a/.env b/.env
new file mode 100644
index 0000000..42bf6bc
--- /dev/null
+++ b/.env
@@ -0,0 +1,23 @@
+;使用本地docker数据库就将这段代码解开,如何将IP换成本机IP或者localhost
+; DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade"
+; DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade"
+
+DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
+DATABASE_CREATE_URL ="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
+
+REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380"
+
+SITE_DOMAIN=127.0.0.1
+SECURE_COOKIES=false
+
+ENVIRONMENT=LOCAL
+
+CORS_HEADERS=["*"]
+CORS_ORIGINS=["http://localhost:3000"]
+
+# postgres variables, must be the same as in DATABASE_URL
+POSTGRES_USER=app
+POSTGRES_PASSWORD=app
+POSTGRES_HOST=app_db
+POSTGRES_PORT=5432
+POSTGRES_DB=app
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/.name b/.idea/.name
new file mode 100644
index 0000000..7bfcf52
--- /dev/null
+++ b/.idea/.name
@@ -0,0 +1 @@
+wance_data
\ No newline at end of file
diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml
new file mode 100644
index 0000000..1e41358
--- /dev/null
+++ b/.idea/dataSources.xml
@@ -0,0 +1,12 @@
+
+
+
+
+ mysql.8
+ true
+ com.mysql.cj.jdbc.Driver
+ jdbc:mysql://lawyer5.cn:3308/wangche
+ $ProjectFileDir$
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..dd1720e
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,146 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..39831c4
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..4c5c71c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/other.xml b/.idea/other.xml
new file mode 100644
index 0000000..2e75c2e
--- /dev/null
+++ b/.idea/other.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml
new file mode 100644
index 0000000..6743fb4
--- /dev/null
+++ b/.idea/sqldialects.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/wance_data.iml b/.idea/wance_data.iml
new file mode 100644
index 0000000..d33acca
--- /dev/null
+++ b/.idea/wance_data.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/banner.txt b/banner.txt
new file mode 100644
index 0000000..d37d152
--- /dev/null
+++ b/banner.txt
@@ -0,0 +1,19 @@
+ # ## # # # ## ## # # # # # # # # # # # #
+ ## ## # ## ## ## # ########## # ## # # #### ## ## ## ## # ############## ### ## #############
+ ## ############# ## ## ## # ## # ############## # #### # ## ## ## # ############### ## ## ## ## ## ## ##
+ ## ## # ## ## ## # ## ############### ## # ## ## ## ## ## ########## ## ## ## ## ## ## ## # ##
+ # ### ########### ############### # # ## # ## ## # ## #### ###### ## ## # ## ## ####### ## ## ######### ##
+ ## ## # ## # ## ## ## ## ########### ## ## # ## ###### # ## ## ## # ########### ## # ## ## ## ### ##
+ ## ## ## ############### ### ## ## # ## ## ## ## ########### # ## # ## ## ######## ## ## ## ### ## ## ## # ### # # ##
+ ## ## ## # # #### ## ## ## ########## ## ## ## #### ## ## ## # # ## # ## ## ## # ####### ## ### ### ######## ##
+ ## ## ### ########## ##### ###### ## ## ## ## ## ## ## ##### ## ## # ### ## ## ######## ## ## ### # # ## # ## ## ##
+ ## ## ### ## ## # ## # ## ## # ########## ## ## # ### # # ##### ### # ## ## ## ## ## ## # # ## ## ## ## ##
+ ## ## ## ######### # ## ## ## ## # ## # ## ## # ## # ## #### #### #### ## ## ## ## ## ## ## ## ## ## ## ##
+ # ## # ## ## ## ## ## # ## # ## ## # # # ## ### ## # ## ## ######## ##### ## ### ## ## ####### ##
+ ## ######### ## ## ## ## ########### # ## ## ## ## ## ## #### ## ## ## # ## ### ## ### # # ##
+ ## ## ## ## ## ## # ## # # ## # ## ## ## ## ### ## ## ## #### ### ########### ##
+ #### ## #### ## ###### # ## # # # ### # ## ## ### ## ### ## ### ## #### ### ##### ## ####
+ # # # # # # ############### # # # # # # # # # # ## # # # # #
+==========================================================================================================================================================================================================================
+*********************************************************************************************聪明帅气可爱睿智的老板保佑我永不报错************************************************************************************************
+==========================================================================================================================================================================================================================
diff --git a/migrations/models/3_20240909094115_None.py b/migrations/models/3_20240909094115_None.py
new file mode 100644
index 0000000..98b6f97
--- /dev/null
+++ b/migrations/models/3_20240909094115_None.py
@@ -0,0 +1,239 @@
+from tortoise import BaseDBAsyncClient
+
+
+async def upgrade(db: BaseDBAsyncClient) -> str:
+ return """
+ CREATE TABLE IF NOT EXISTS `users` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `nickname` VARCHAR(30) COMMENT '用户昵称',
+ `avatar_url` VARCHAR(255) COMMENT '头像',
+ `member_type` INT COMMENT '会员类型',
+ `beta_account_type` VARCHAR(30) COMMENT '内测账号类型',
+ `pre_cost_time` INT COMMENT '预支付时间(单位年)',
+ `qr_code` VARCHAR(255) COMMENT '专属客服二维码',
+ `dedicated_id` INT COMMENT '专属客服id',
+ `invited_user_id` INT COMMENT '邀请人id',
+ `created_user_id` INT COMMENT '创建人id',
+ `is_deleted` BOOL NOT NULL COMMENT '是否删除' DEFAULT 0,
+ `login_at` DATETIME(6) COMMENT '最后一次登录时间',
+ `created_at` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
+ `updated_at` DATETIME(6) NOT NULL COMMENT '修改时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
+ `deleted_at` DATETIME(6) COMMENT '删除时间',
+ KEY `idx_users_is_dele_9cdc79` (`is_deleted`)
+) CHARACTER SET utf8mb4 COMMENT='用户';
+CREATE TABLE IF NOT EXISTS `stock` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `stock_code` VARCHAR(30) NOT NULL COMMENT '股票代码',
+ `stock_name` VARCHAR(30) COMMENT '股票名称',
+ `type` VARCHAR(2) COMMENT '类型',
+ `stock_pinyin` VARCHAR(30) NOT NULL COMMENT '股票拼音'
+) CHARACTER SET utf8mb4 COMMENT='股票相关信息';
+CREATE TABLE IF NOT EXISTS `security_account` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `securities_name` VARCHAR(30) NOT NULL COMMENT '证券公司名字昵称',
+ `fund_account` BIGINT COMMENT '资金账户',
+ `account_alias` VARCHAR(30) NOT NULL COMMENT '账户别名',
+ `money` DOUBLE COMMENT '账户金额',
+ `available_money` DOUBLE COMMENT '可用金额',
+ `available_proportion` DOUBLE COMMENT '可用资金占比',
+ `freeze` DOUBLE COMMENT '冻结金额'
+) CHARACTER SET utf8mb4 COMMENT='证券账户';
+CREATE TABLE IF NOT EXISTS `backtest` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键id',
+ `key` VARCHAR(30) NOT NULL COMMENT '回测key'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `Entrust` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键id',
+ `fund_account` INT COMMENT '资金账户',
+ `securities_alias` VARCHAR(30) COMMENT '账户别名',
+ `account_type` SMALLINT COMMENT '账户类型',
+ `stock_code` VARCHAR(30) NOT NULL COMMENT '证券代码',
+ `stock_name` VARCHAR(30) NOT NULL COMMENT '证券名称',
+ `limit_price` DOUBLE NOT NULL COMMENT '委托价',
+ `entrust_number` INT NOT NULL COMMENT '委托数量',
+ `deal_price` DOUBLE COMMENT '成交价格',
+ `deal_number` INT COMMENT '成家数量',
+ `order_type` VARCHAR(30) NOT NULL COMMENT '操作',
+ `entrust_date` DATE COMMENT '委托日期',
+ `entrust_money` DOUBLE NOT NULL COMMENT '委托金额',
+ `is_repair` BOOL NOT NULL COMMENT '是否补单' DEFAULT 0,
+ `entrust_time` TIME(6),
+ KEY `idx_Entrust_is_repa_5a195d` (`is_repair`)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `orders` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '订单ID',
+ `stock_code` VARCHAR(30) NOT NULL COMMENT '股票代码',
+ `stock_name` VARCHAR(30) NOT NULL COMMENT '股票名称',
+ `limit_price` DOUBLE NOT NULL COMMENT '限价',
+ `order_quantity` INT NOT NULL COMMENT '委托数量',
+ `order_amount` DOUBLE NOT NULL COMMENT '委托金额',
+ `order_type` VARCHAR(20) NOT NULL COMMENT '订单类型',
+ `position` VARCHAR(30) NOT NULL COMMENT '仓位',
+ `user_id` INT NOT NULL COMMENT '用户Id',
+ `entrust_date` DATETIME(6) COMMENT '委托日期',
+ `entrust_time` TIME(6)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `snowball` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `snowball_token` VARCHAR(10000) COMMENT '雪球用户的token'
+) CHARACTER SET utf8mb4 COMMENT='雪球相关信息';
+CREATE TABLE IF NOT EXISTS `Strategy` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `strategy_name` VARCHAR(255) NOT NULL COMMENT '策略名称',
+ `strategy_hash` VARCHAR(255) NOT NULL COMMENT '策略版本号',
+ `strategy_type` VARCHAR(255) COMMENT '策略类型',
+ `user_id` INT COMMENT '所属用户',
+ `backtest_count` INT COMMENT '回测次数',
+ `backtest_keys` JSON COMMENT '回测key列表',
+ `is_deleted` BOOL NOT NULL COMMENT '是否删除' DEFAULT 0,
+ `created_at` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
+ `updated_at` DATETIME(6) NOT NULL COMMENT '修改时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
+ `deleted_at` DATETIME(6) COMMENT '删除时间'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `Backtest` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(20) NOT NULL UNIQUE COMMENT '回测key',
+ `user_id` INT COMMENT '回测用户',
+ `backtest_at` DATETIME(6) NOT NULL COMMENT '回测时间' DEFAULT CURRENT_TIMESTAMP(6),
+ `backtest_code` LONGTEXT NOT NULL COMMENT '回测代码',
+ `is_running` BOOL NOT NULL COMMENT '回测状态' DEFAULT 1,
+ `updated_at` DATETIME(6) NOT NULL COMMENT '修改时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
+ `deleted_at` DATETIME(6) COMMENT '删除时间',
+ `is_active` BOOL NOT NULL COMMENT '是否可用' DEFAULT 1,
+ `strategy_id` INT NOT NULL,
+ CONSTRAINT `fk_Backtest_Strategy_7eabb20d` FOREIGN KEY (`strategy_id`) REFERENCES `Strategy` (`id`) ON DELETE CASCADE,
+ KEY `idx_Backtest_key_eb92b1` (`key`),
+ KEY `idx_Backtest_strateg_83ac3d` (`strategy_id`)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `stock_details` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `stock_code` VARCHAR(30) NOT NULL COMMENT '股票代码',
+ `stock_name` VARCHAR(30) COMMENT '股票名称',
+ `type` VARCHAR(2) COMMENT '类型',
+ `stock_pinyin` VARCHAR(30) NOT NULL COMMENT '股票拼音',
+ `latest_price` DOUBLE COMMENT '最新价',
+ `rise_fall` DOUBLE COMMENT '跌涨幅'
+) CHARACTER SET utf8mb4 COMMENT='股票相关信息';
+CREATE TABLE IF NOT EXISTS `Transaction` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `key` VARCHAR(20) NOT NULL COMMENT 'key,数据标识',
+ `cash` DOUBLE COMMENT '资金',
+ `transaction_name` VARCHAR(255) NOT NULL COMMENT '交易名称',
+ `transaction_type` VARCHAR(255) COMMENT '交易类型',
+ `user_id` INT COMMENT '用户id',
+ `is_running` BOOL NOT NULL COMMENT '运行状态' DEFAULT 1,
+ `is_deleted` BOOL NOT NULL COMMENT '是否删除' DEFAULT 0,
+ `process_id` INT COMMENT '进程号',
+ `bar` VARCHAR(10) COMMENT '频率K线',
+ `created_at` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
+ `updated_at` DATETIME(6) NOT NULL COMMENT '修改时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
+ `stopped_at` DATETIME(6) COMMENT '停止时间',
+ `deleted_at` DATETIME(6) COMMENT '删除时间',
+ `strategy_id` INT NOT NULL,
+ CONSTRAINT `fk_Transact_Strategy_849d0577` FOREIGN KEY (`strategy_id`) REFERENCES `Strategy` (`id`) ON DELETE CASCADE,
+ KEY `idx_Transaction_strateg_b8f0b2` (`strategy_id`)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `position` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
+ `key` VARCHAR(30) NOT NULL COMMENT '键',
+ `date` DATE NOT NULL COMMENT '日期',
+ `name` VARCHAR(30) NOT NULL COMMENT '名称',
+ `size` INT NOT NULL COMMENT '数量',
+ `price` DOUBLE NOT NULL COMMENT '价格',
+ `adjbase` DOUBLE NOT NULL COMMENT '复权基数',
+ `profit` DOUBLE NOT NULL COMMENT '利润'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `trand_info` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT 'id',
+ `key` VARCHAR(255) COMMENT '唯一索引',
+ `tran_info_data` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `tran_observer_data` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(255) NOT NULL,
+ `tran_observer_data` LONGBLOB COMMENT '存储大量数据'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `tranorders` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(255),
+ `order_return` VARCHAR(255)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `tran_return` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(255),
+ `tran_return_data` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `tran_trade_info` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(25),
+ `tran_trade_info` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `back_observed_data` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT 'id表',
+ `key` VARCHAR(255) COMMENT 'key',
+ `observed_data` LONGBLOB COMMENT '格式化后的json数据'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `back_observed_data_detail` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(255) NOT NULL,
+ `back_observed_data` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `back_position` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(25),
+ `back_position_data` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `back_result_indicator` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(25),
+ `indicator` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `back_trand_info` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(25),
+ `trade_info` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `tran_position` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `key` VARCHAR(25),
+ `tran_position_data` LONGBLOB
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `stock_bt_history` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `end_bt_time` VARCHAR(8) COMMENT '回测最终时间',
+ `bt_stock_code` VARCHAR(8) COMMENT '回测股票代码',
+ `bt_stock_name` VARCHAR(10) COMMENT '回测股票名称',
+ `bt_benchmark_code` VARCHAR(8) COMMENT '股票基准代码',
+ `bt_stock_period` VARCHAR(10) COMMENT '回测类型',
+ `bt_strategy_name` VARCHAR(10) COMMENT '回测策略名称',
+ `bt_stock_data` LONGBLOB COMMENT '回测股票数据'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `stock_history` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `stock_code` INT NOT NULL COMMENT '股票代码',
+ `stock_name` VARCHAR(10),
+ `start_time_to_market` VARCHAR(10) COMMENT '股票上市时间',
+ `end_bt_time` VARCHAR(10) COMMENT '最终回测时间',
+ `symbol_data` LONGBLOB COMMENT '股票数据'
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `stock_data_processing` (
+ `bt_benchmark_code` VARCHAR(6) COMMENT '基准代码',
+ `bt_stock_period` VARCHAR(10) COMMENT '数据类型',
+ `bt_strategy_name` VARCHAR(10) COMMENT '回测策略名',
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `processing_data` LONGBLOB COMMENT '清洗后的数据',
+ `prosessing_date` VARCHAR(10) COMMENT '当前回测时间',
+ `stock_name` VARCHAR(10),
+ `stocke_code` VARCHAR(10)
+) CHARACTER SET utf8mb4;
+CREATE TABLE IF NOT EXISTS `aerich` (
+ `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+ `version` VARCHAR(255) NOT NULL,
+ `app` VARCHAR(100) NOT NULL,
+ `content` JSON NOT NULL
+) CHARACTER SET utf8mb4;"""
+
+
+async def downgrade(db: BaseDBAsyncClient) -> str:
+ return """
+ """
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..4b13f38
Binary files /dev/null and b/models.py differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..62a63ce
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,4 @@
+[tool.aerich]
+tortoise_orm = "src.tortoises.TORTOISE_ORM"
+location = "./migrations"
+src_folder = "./."
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..8ab04ac
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,4 @@
+from dotenv import load_dotenv
+
+# load_dotenv(dotenv_path=Path("../.env"))
+load_dotenv()
\ No newline at end of file
diff --git a/src/__pycache__/__init__.cpython-311.pyc b/src/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..7d636cc
Binary files /dev/null and b/src/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/__pycache__/constants.cpython-311.pyc b/src/__pycache__/constants.cpython-311.pyc
new file mode 100644
index 0000000..7bcada3
Binary files /dev/null and b/src/__pycache__/constants.cpython-311.pyc differ
diff --git a/src/__pycache__/exceptions.cpython-311.pyc b/src/__pycache__/exceptions.cpython-311.pyc
new file mode 100644
index 0000000..36d44e7
Binary files /dev/null and b/src/__pycache__/exceptions.cpython-311.pyc differ
diff --git a/src/__pycache__/main.cpython-311.pyc b/src/__pycache__/main.cpython-311.pyc
new file mode 100644
index 0000000..251626c
Binary files /dev/null and b/src/__pycache__/main.cpython-311.pyc differ
diff --git a/src/__pycache__/responses.cpython-311.pyc b/src/__pycache__/responses.cpython-311.pyc
new file mode 100644
index 0000000..497bb68
Binary files /dev/null and b/src/__pycache__/responses.cpython-311.pyc differ
diff --git a/src/__pycache__/tortoises.cpython-311.pyc b/src/__pycache__/tortoises.cpython-311.pyc
new file mode 100644
index 0000000..07b7731
Binary files /dev/null and b/src/__pycache__/tortoises.cpython-311.pyc differ
diff --git a/src/__pycache__/tortoises_orm_config.cpython-311.pyc b/src/__pycache__/tortoises_orm_config.cpython-311.pyc
new file mode 100644
index 0000000..d4e3f0f
Binary files /dev/null and b/src/__pycache__/tortoises_orm_config.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/bollinger_bands.cpython-311.pyc b/src/backtest/__pycache__/bollinger_bands.cpython-311.pyc
new file mode 100644
index 0000000..df0bd0a
Binary files /dev/null and b/src/backtest/__pycache__/bollinger_bands.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/dual_moving_average.cpython-311.pyc b/src/backtest/__pycache__/dual_moving_average.cpython-311.pyc
new file mode 100644
index 0000000..4ae1978
Binary files /dev/null and b/src/backtest/__pycache__/dual_moving_average.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/reverse_dual_ma_strategy.cpython-311.pyc b/src/backtest/__pycache__/reverse_dual_ma_strategy.cpython-311.pyc
new file mode 100644
index 0000000..7173c25
Binary files /dev/null and b/src/backtest/__pycache__/reverse_dual_ma_strategy.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/router.cpython-311.pyc b/src/backtest/__pycache__/router.cpython-311.pyc
new file mode 100644
index 0000000..5dc9f4e
Binary files /dev/null and b/src/backtest/__pycache__/router.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/rsi_strategy.cpython-311.pyc b/src/backtest/__pycache__/rsi_strategy.cpython-311.pyc
new file mode 100644
index 0000000..8e81093
Binary files /dev/null and b/src/backtest/__pycache__/rsi_strategy.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/service.cpython-311.pyc b/src/backtest/__pycache__/service.cpython-311.pyc
new file mode 100644
index 0000000..ce8746d
Binary files /dev/null and b/src/backtest/__pycache__/service.cpython-311.pyc differ
diff --git a/src/backtest/__pycache__/until.cpython-311.pyc b/src/backtest/__pycache__/until.cpython-311.pyc
new file mode 100644
index 0000000..b62941e
Binary files /dev/null and b/src/backtest/__pycache__/until.cpython-311.pyc differ
diff --git a/src/backtest/backtest.py b/src/backtest/backtest.py
new file mode 100644
index 0000000..bf5ef49
--- /dev/null
+++ b/src/backtest/backtest.py
@@ -0,0 +1,106 @@
+import bt
+import pandas as pd
+import numpy as np
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+# 数据的列名
+columns = ['open', 'high', 'low', 'close', 'volume', 'amount', 'settelmentPrice',
+ 'openInterest', 'preClose', 'suspendFlag']
+
+
+# 获取本地数据并进行处理
+def get_local_data(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int, dividend_type: str, fill_data: bool, data_dir: str):
+ result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
+ data_dir=data_dir)
+ return data_processing(result)
+
+
+# 数据处理函数
+def data_processing(result_local):
+ # 初始化一个空的列表,用于存储每个股票的数据框
+ df_list = []
+
+ # 遍历字典中的 DataFrame
+ for stock_code, df in result_local.items():
+ # 确保 df 是一个 DataFrame
+ if isinstance(df, pd.DataFrame):
+ # 将时间戳转换为日期时间格式,只保留年-月-日
+ df['time'] = pd.to_datetime(df['time'], unit='ms').dt.date
+ # 将 'time' 列设置为索引,保留为日期格式
+ df.set_index('time', inplace=True)
+ # 指定列名
+ df.columns = columns
+ # 添加一列 'stock_code' 用于标识不同的股票
+ df['stock_code'] = stock_code
+ # 将 DataFrame 添加到列表中
+ df_list.append(df[['close']]) # 只保留 'close' 列
+ else:
+ print(f"数据格式错误: {stock_code} 不包含 DataFrame")
+
+ # 使用 pd.concat() 将所有 DataFrame 合并为一个大的 DataFrame
+ combined_df = pd.concat(df_list, axis=1)
+
+ # 确保返回的 DataFrame 索引是日期格式
+ combined_df.index = pd.to_datetime(combined_df.index)
+
+ # 打印最终的 DataFrame
+ print(combined_df)
+
+ return combined_df
+
+
+# 定义策略
+def moving_average_strategy(data, short_window=20, long_window=50):
+ data['Short_MA'] = data['close'].rolling(window=short_window, min_periods=1).mean()
+ data['Long_MA'] = data['close'].rolling(window=long_window, min_periods=1).mean()
+
+ data['Signal'] = 0
+ data.loc[data.index[short_window:], 'Signal'] = np.where(
+ data['Short_MA'][short_window:] > data['Long_MA'][short_window:], 1, 0)
+ data['Position'] = data['Signal'].diff()
+
+ return data
+
+
+# 定义策略函数
+def bt_strategy(data):
+ # 计算策略信号
+ # data = moving_average_strategy(data)
+
+ # 定义策略
+ dual_ma_strategy = bt.Strategy('Dual MA Strategy', [bt.algos.RunOnce(),
+ bt.algos.SelectAll(),
+ bt.algos.WeighEqually(),
+ bt.algos.Rebalance()])
+ return dual_ma_strategy
+
+
+# 运行回测
+def run_backtest():
+ # 生成数据
+ data = get_local_data(field_list=[], stock_list=["300391.SZ"], period='1d', start_time='', end_time='', count=-1,
+ dividend_type='none', fill_data=True, data_dir="")
+
+ # 创建策略
+ strategy = bt_strategy(data)
+
+ # 创建回测
+ portfolio = bt.Backtest(strategy, data)
+ result = bt.run(portfolio)
+
+ return result
+
+
+# 执行回测并显示结果
+if __name__ == "__main__":
+ result = run_backtest()
+ result.plot()
+ b = xtdata.get_sector_list()
+ print(b)
+ xtdata.download_sector_data()
+ a = result.stats
+ print(a)
+ plt.show()
diff --git a/src/backtest/bollinger.py b/src/backtest/bollinger.py
new file mode 100644
index 0000000..f6e020f
--- /dev/null
+++ b/src/backtest/bollinger.py
@@ -0,0 +1,79 @@
+import bt
+
+# 模拟从前端接收到的因子列表
+factors = [
+ {'name': 'SMA', 'period': 50},
+ {'name': 'EMA', 'period': 200},
+ {'name': 'RSI', 'period': 14},
+ {'name': 'BollingerBands', 'period': 20}
+]
+
+
+def create_indicator(name, **kwargs):
+ """
+ 根据名称和参数动态创建bt的技术指标。
+ """
+ if name == 'SMA':
+ return bt.indicators.SMA(kwargs['period'])
+ elif name == 'EMA':
+ return bt.indicators.EMA(kwargs['period'])
+ elif name == 'RSI':
+ return bt.indicators.RSI(kwargs['period'])
+ elif name == 'BollingerBands':
+ return bt.indicators.BollingerBands(kwargs['period'])
+
+ else:
+ raise ValueError(f"未知的指标名称: {name}")
+
+
+def create_dynamic_strategy(name, factors):
+ """
+ 根据传递的因子列表动态构建策略。
+ """
+ algos = [bt.algos.RunMonthly(), # 每月运行一次
+ bt.algos.SelectAll()] # 选择所有资产
+
+ # 动态生成指标选择逻辑
+ for factor in factors:
+ indicator = create_indicator(factor['name'], **factor)
+ if factor['name'] == 'RSI':
+ # RSI 特定的买入/卖出逻辑
+ buy_signal = bt.algos.SelectWhere(indicator < 30) # 超卖区域买入
+ sell_signal = bt.algos.SelectWhere(indicator > 70) # 超买区域卖出
+ algos.append(buy_signal)
+ algos.append(bt.algos.WeighTarget(1.0))
+ algos.append(bt.algos.Rebalance())
+ algos.append(sell_signal)
+ algos.append(bt.algos.WeighTarget(0.0))
+ algos.append(bt.algos.Rebalance())
+ elif factor['name'] == 'BollingerBands':
+ # 布林带特定的逻辑(买入接近下轨,卖出接近上轨)
+ buy_signal = bt.algos.SelectWhere(bt.data < indicator['lower'])
+ sell_signal = bt.algos.SelectWhere(bt.data > indicator['upper'])
+ algos.append(buy_signal)
+ algos.append(bt.algos.WeighTarget(1.0))
+ algos.append(bt.algos.Rebalance())
+ algos.append(sell_signal)
+ algos.append(bt.algos.WeighTarget(0.0))
+ algos.append(bt.algos.Rebalance())
+ # 可以为其他指标添加更多逻辑
+
+ return bt.Strategy(name, algos)
+
+
+# 创建动态策略
+dynamic_strategy = create_dynamic_strategy("Dynamic_Strategy", factors)
+
+# 获取数据
+data = bt.get('spy,agg', start='2020-01-01', end='2023-01-01')
+
+# 创建回测
+backtest = bt.Backtest(dynamic_strategy, data)
+
+# 运行回测
+result = bt.run(backtest)
+
+# 展示结果
+result.display()
+result.plot()
+
diff --git a/src/backtest/bollinger_bands.py b/src/backtest/bollinger_bands.py
new file mode 100644
index 0000000..576e26d
--- /dev/null
+++ b/src/backtest/bollinger_bands.py
@@ -0,0 +1,274 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+
+import bt
+import numpy as np
+import pandas as pd
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
+from src.models import wance_data_storage_backtest, wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+# 布林带策略函数
+async def create_bollinger_bands_strategy(data, stock_code: str, bollingerMA: int = 50, std_dev: int = 200):
+ # 生成布林带策略信号
+ signal = await bollinger_bands_strategy(data, bollingerMA, std_dev)
+
+ # 使用bt框架构建策略
+ strategy = bt.Strategy(f'{stock_code} 布林带策略',
+ [bt.algos.RunDaily(),
+ bt.algos.SelectAll(), # 选择所有股票
+ bt.algos.WeighTarget(signal), # 根据信号调整权重
+ bt.algos.Rebalance()]) # 调仓
+ return strategy, signal
+
+
+async def bollinger_bands_strategy(df, window=20, num_std_dev=2):
+ """
+ 基于布林带策略生成买卖信号。
+
+ 参数:
+ df: pd.DataFrame, 股票的价格数据,行索引为日期,列为股票代码。
+ window: int, 计算布林带中轨线的窗口期。
+ num_std_dev: float, 标准差的倍数,用于计算上下轨。
+
+ 返回:
+ signal: pd.DataFrame, 每只股票的买卖信号,1 表示买入,0 表示卖出。
+ """
+ # 计算中轨线(移动平均)
+ middle_band = df.rolling(window=window, min_periods=1).mean()
+
+ # 计算滚动标准差
+ rolling_std = df.rolling(window=window, min_periods=1).std()
+
+ # 计算上轨线和下轨线
+ upper_band = middle_band + (rolling_std * num_std_dev)
+ lower_band = middle_band - (rolling_std * num_std_dev)
+
+ # 初始化信号 DataFrame
+ signal = pd.DataFrame(index=df.index, columns=df.columns)
+
+ # 生成买入信号:当价格突破下轨时
+ for column in df.columns:
+ signal[column] = np.where(df[column] < lower_band[column], 1, np.nan) # 买入信号
+
+ # 生成卖出信号:当价格突破上轨时
+ for column in df.columns:
+ signal[column] = np.where(df[column] > upper_band[column], 0, signal[column]) # 卖出信号
+
+ # 前向填充信号,持仓不变
+ signal = signal.ffill()
+
+ # 将剩余的 NaN 替换为 0
+ signal = signal.fillna(0)
+
+ return signal
+
+
+async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, bollingerMA,
+ std_dev):
+ await init_tortoise()
+
+ # 要存储的字段列表
+ fields_to_store = [
+ 'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
+ 'price', 'returns', 'data_start_time', 'data_end_time',
+ 'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
+ 'max_drawdown', 'calmar', 'mtd', 'three_month',
+ 'six_month', 'ytd', 'one_year', 'three_year',
+ 'five_year', 'ten_year', 'incep', 'daily_sharpe',
+ 'daily_sortino', 'daily_mean', 'daily_vol',
+ 'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
+ 'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
+ 'monthly_vol', 'monthly_skew', 'monthly_kurt',
+ 'best_month', 'worst_month', 'yearly_sharpe',
+ 'yearly_sortino', 'yearly_mean', 'yearly_vol',
+ 'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
+ 'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
+ 'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
+ ]
+
+ # 准备要存储的数据
+ data_to_store = {
+ 'stock_code': stock_code,
+ 'strategy_name': "布林带策略",
+ 'stock_close_price': json.dumps(stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
+ time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')),
+ 'daily_price': convert_pandas_to_json_serializable(result[source_column_name].daily_prices),
+ 'price': convert_pandas_to_json_serializable(result[source_column_name].prices),
+ 'returns': convert_pandas_to_json_serializable(result[source_column_name].returns.fillna(0)),
+ 'data_start_time': pd.to_datetime(result.stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
+ 'data_end_time': pd.to_datetime(result.stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
+ 'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
+ 'position': convert_pandas_to_json_serializable(signal),
+ 'backtest_name': f'{stock_code} 布林带策略 MA{bollingerMA}-{std_dev}倍标准差',
+ 'indicator_type': 'Bollinger',
+ 'indicator_information': json.dumps({'bollingerMA': bollingerMA, 'std_dev': std_dev})
+ }
+
+ # 使用循环填充其他字段
+ for field in fields_to_store[12:]: # 从第10个字段开始
+ value = result.stats.loc[field].iloc[0]
+ data_to_store[field] = 0.0 if (isinstance(value, float) and np.isnan(value)) else value
+
+ # 检查是否存在该 backtest_name
+ existing_record = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=data_to_store['backtest_name']
+ ).first()
+
+ if existing_record:
+ # 如果存在,更新记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ id=existing_record.id
+ ).update(**data_to_store)
+ else:
+ # 如果不存在,创建新的记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.create(**data_to_store)
+
+ return data_to_store
+
+
+async def run_bollinger_backtest(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = 100,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ bollingerMA: int = 50,
+ std_dev: int = 200):
+ try:
+ # 初始化一个列表用于存储每只股票的回测结果字典
+ results_list = []
+
+ # 遍历每只股票的数据(每列代表一个股票的收盘价)
+ data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
+ fill_data,
+ data_dir)
+
+ for stock_code in stock_list:
+
+ data_column_name = f'close_{stock_code}'
+ source_column_name = f'{stock_code} 布林带策略'
+ backtest_name = f'{stock_code} 布林带策略 MA{bollingerMA}-{std_dev}倍标准差'
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ if db_result:
+ if db_result[0].backtest_end_time == now_time:
+ results_list.append({source_column_name: db_result[0]})
+
+ # elif data_column_name in data.columns:
+ if data_column_name in data.columns:
+ stock_data_series = data[[data_column_name]] # 提取该股票的收盘价 DataFrame
+ stock_data_series.columns = ['close'] # 重命名列为 'close'
+
+ # 创建布林带策略
+ strategy, signal = await create_bollinger_bands_strategy(stock_data_series, stock_code,
+ bollingerMA=bollingerMA,
+ std_dev=std_dev)
+ # 创建回测
+ backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
+ # 运行回测
+ result = bt.run(backtest)
+ # 存储回测结果
+ data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
+ stock_data_series,
+ bollingerMA, std_dev)
+ # # 绘制回测结果图表
+ # result.plot()
+ # # 绘制个别股票数据图表
+ # plt.figure(figsize=(12, 6))
+ # plt.plot(stock_data_series.index, stock_data_series['close'], label='Stock Price')
+ # plt.title(f'Stock Price for {stock_code}')
+ # plt.xlabel('Date')
+ # plt.ylabel('Price')
+ # plt.legend()
+ # plt.grid(True)
+ # plt.show()
+ # 将结果存储为字典并添加到列表中
+ results_list.append({source_column_name: data_to_store})
+
+ else:
+ print(f"数据中缺少列: {data_column_name}")
+
+ return results_list # 返回结果列表
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+
+async def start_bollinger_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ bollingerMA: int = 50,
+ std_dev: int = 200):
+ for stock_code in stock_list:
+ backtest_name = f'{stock_code} 布林带策略 MA{bollingerMA}-{std_dev}倍标准差'
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+
+ if db_result and db_result[0].backtest_end_time == now_time:
+ return db_result
+ else:
+ # 执行回测
+ result = await run_bollinger_backtest(
+ field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir,
+ bollingerMA=bollingerMA,
+ std_dev=std_dev,
+ )
+ return result
+
+
+async def init_backtest_db():
+ bollinger_list = [{"bollingerMA": 20, "std_dev": 2}, {"bollingerMA": 30, "std_dev": 2},
+ {"bollingerMA": 70, "std_dev": 2}, {"bollingerMA": 5, "std_dev": 1},
+ {"bollingerMA": 20, "std_dev": 3}, {"bollingerMA": 50, "std_dev": 2.5}]
+ await init_tortoise()
+ wance_db = await wance_data_stock.WanceDataStock.all()
+ bollinger_list_lenght = len(bollinger_list)
+
+ for stock_code in wance_db:
+ for i in range(bollinger_list_lenght):
+ bollingerMA = bollinger_list[i]['bollingerMA']
+ std_dev = bollinger_list[i]['std_dev']
+ source_column_name = f'{stock_code} 布林带策略 MA{bollingerMA}-{std_dev}倍标准差'
+ result = await run_bollinger_backtest(field_list=['close', 'time'],
+ stock_list=[stock_code.stock_code],
+ bollingerMA=bollingerMA,
+ std_dev=std_dev)
+
+ print(f"回测成功 {source_column_name}")
+
+
+if __name__ == '__main__':
+ # 测试类的回测
+ asyncio.run(run_bollinger_backtest(field_list=['close', 'time'],
+ stock_list=['601222.SH', '601677.SH'],
+ bollingerMA=20,
+ std_dev=2))
+
+ # # 初始化数据库表
+ # asyncio.run(init_backtest_db())
diff --git a/src/backtest/dual_moving_average.py b/src/backtest/dual_moving_average.py
new file mode 100644
index 0000000..d5c274f
--- /dev/null
+++ b/src/backtest/dual_moving_average.py
@@ -0,0 +1,267 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+
+import bt
+import numpy as np
+import pandas as pd
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
+from src.models import wance_data_storage_backtest, wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+# 双均线策略函数
+async def create_dual_ma_strategy(data, stock_code: str, short_window: int = 50, long_window: int = 200):
+ # 生成双均线策略信号
+ signal = await dual_ma_strategy(data, short_window, long_window)
+
+ # 使用bt框架构建策略
+ strategy = bt.Strategy(f'{stock_code} 双均线策略',
+ [bt.algos.RunDaily(),
+ bt.algos.SelectAll(), # 选择所有股票
+ bt.algos.WeighTarget(signal), # 根据信号调整权重
+ bt.algos.Rebalance()]) # 调仓
+ return strategy, signal
+
+
+async def dual_ma_strategy(df, short_window=20, long_window=50):
+ """
+ 基于双均线策略生成买卖信号。
+
+ 参数:
+ df: pd.DataFrame, 股票的价格数据,行索引为日期,列为股票代码。
+ short_window: int, 短期均线窗口期。
+ long_window: int, 长期均线窗口期。
+
+ 返回:
+ signal: pd.DataFrame, 每只股票的买卖信号,1 表示买入,0 表示卖出。
+ """
+ # 计算短期均线和长期均线
+ short_ma = df.rolling(window=short_window, min_periods=1).mean()
+ long_ma = df.rolling(window=long_window, min_periods=1).mean()
+
+ # 生成买入信号: 当短期均线从下方穿过长期均线
+ buy_signal = np.where(short_ma > long_ma, 1, np.nan)
+
+ # 生成卖出信号: 当短期均线从上方穿过长期均线
+ sell_signal = np.where(short_ma < long_ma, 0, np.nan)
+
+ # 合并买卖信号
+ signal = pd.DataFrame(buy_signal, index=df.index, columns=df.columns)
+ signal = np.where(short_ma < long_ma, 0, signal)
+
+ # 前向填充信号,持仓不变
+ signal = pd.DataFrame(signal, index=df.index, columns=df.columns).ffill()
+
+ # 将剩余的 NaN 替换为 0
+ signal = signal.fillna(0)
+
+ return signal
+
+
+async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, short_window,
+ long_window):
+ await init_tortoise()
+
+ # 要存储的字段列表
+ fields_to_store = [
+ 'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
+ 'price', 'returns', 'data_start_time', 'data_end_time',
+ 'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
+ 'max_drawdown', 'calmar', 'mtd', 'three_month',
+ 'six_month', 'ytd', 'one_year', 'three_year',
+ 'five_year', 'ten_year', 'incep', 'daily_sharpe',
+ 'daily_sortino', 'daily_mean', 'daily_vol',
+ 'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
+ 'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
+ 'monthly_vol', 'monthly_skew', 'monthly_kurt',
+ 'best_month', 'worst_month', 'yearly_sharpe',
+ 'yearly_sortino', 'yearly_mean', 'yearly_vol',
+ 'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
+ 'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
+ 'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
+ ]
+
+ # 准备要存储的数据
+ data_to_store = {
+ 'stock_code': stock_code,
+ 'strategy_name': "双均线策略",
+ 'stock_close_price': json.dumps(stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
+ time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')),
+ 'daily_price': convert_pandas_to_json_serializable(result[source_column_name].daily_prices),
+ 'price': convert_pandas_to_json_serializable(result[source_column_name].prices),
+ 'returns': convert_pandas_to_json_serializable(result[source_column_name].returns.fillna(0)),
+ 'data_start_time': pd.to_datetime(result.stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
+ 'data_end_time': pd.to_datetime(result.stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
+ 'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
+ 'position': convert_pandas_to_json_serializable(signal),
+ 'backtest_name': f'{stock_code} 双均线策略 MA{short_window}-{long_window}日',
+ 'indicator_type': 'SMA',
+ 'indicator_information': json.dumps({'short_window': short_window, 'long_window': long_window})
+ }
+
+ # 使用循环填充其他字段
+ for field in fields_to_store[12:]: # 从第10个字段开始
+ value = result.stats.loc[field].iloc[0]
+ data_to_store[field] = 0.0 if (isinstance(value, float) and np.isnan(value)) else value
+
+ # 检查是否存在该 backtest_name
+ existing_record = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=data_to_store['backtest_name']
+ ).first()
+
+ if existing_record:
+ # 如果存在,更新记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ id=existing_record.id
+ ).update(**data_to_store)
+ else:
+ # 如果不存在,创建新的记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.create(**data_to_store)
+
+ return data_to_store
+
+
+async def run_sma_backtest(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = 100,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ try:
+ # 初始化一个列表用于存储每只股票的回测结果字典
+ results_list = []
+
+ # 遍历每只股票的数据(每列代表一个股票的收盘价)
+ data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
+ fill_data,
+ data_dir)
+
+ for stock_code in stock_list:
+
+ data_column_name = f'close_{stock_code}'
+ source_column_name = f'{stock_code} 双均线策略'
+ backtest_name = f'{stock_code} 双均线策略 MA{short_window}-{long_window}日'
+ now_data = int(datetime.now().strftime('%Y%m%d'))
+ db_result_data = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+
+ if db_result_data:
+ if db_result_data[0].backtest_end_time == now_data:
+ results_list.append({source_column_name: db_result_data[0]})
+
+ if data_column_name in data.columns:
+ stock_data_series = data[[data_column_name]] # 提取该股票的收盘价 DataFrame
+ stock_data_series.columns = ['close'] # 重命名列为 'close'
+
+ # 创建双均线策略
+ strategy, signal = await create_dual_ma_strategy(stock_data_series, stock_code,
+ short_window=short_window,
+ long_window=long_window)
+ # 创建回测
+ backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
+ # 运行回测
+ result = bt.run(backtest)
+ # 存储回测结果
+ data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
+ stock_data_series,
+ short_window, long_window)
+ # # 绘制回测结果图表
+ # result.plot()
+ # # 绘制个别股票数据图表
+ # plt.figure(figsize=(12, 6))
+ # plt.plot(stock_data_series.index, stock_data_series['close'], label='Stock Price')
+ # plt.title(f'Stock Price for {stock_code}')
+ # plt.xlabel('Date')
+ # plt.ylabel('Price')
+ # plt.legend()
+ # plt.grid(True)
+ # plt.show()
+ # 将结果存储为字典并添加到列表中
+ results_list.append({source_column_name: data_to_store})
+
+ else:
+ print(f"数据中缺少列: {data_column_name}")
+
+ return results_list # 返回结果列表
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+
+async def start_sma_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ for stock_code in stock_list:
+ backtest_name = f'{stock_code} 双均线策略 MA{short_window}-{long_window}日'
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+
+ if db_result and db_result[0].backtest_end_time == now_time:
+ return db_result
+ else:
+ # 执行回测
+ result = await run_sma_backtest(
+ field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir,
+ short_window=short_window,
+ long_window=long_window,
+ )
+ return result
+
+
+async def init_backtest_db():
+ sma_list = [{"short_window": 5, "long_window": 10}, {"short_window": 10, "long_window": 30},
+ {"short_window": 30, "long_window": 60}, {"short_window": 30, "long_window": 90},
+ {"short_window": 70, "long_window": 140}, {"short_window": 120, "long_window": 250}]
+ await init_tortoise()
+ wance_db = await wance_data_stock.WanceDataStock.all()
+ sma_list_lenght = len(sma_list)
+
+ for stock_code in wance_db:
+ for i in range(sma_list_lenght):
+ short_window = sma_list[i]['short_window']
+ long_window = sma_list[i]['long_window']
+ source_column_name = f'{stock_code} 双均线策略 MA{short_window}-{long_window}日'
+ result = await run_sma_backtest(field_list=['close', 'time'],
+ stock_list=[stock_code.stock_code],
+ short_window=short_window,
+ long_window=long_window)
+
+ print(f"回测成功 {source_column_name}")
+
+
+if __name__ == '__main__':
+ # 测试类的回测
+ # asyncio.run(run_sma_backtest(field_list=['close', 'time'],
+ # stock_list=['601222.SH', '601677.SH'],
+ # short_window=10,
+ # long_window=30))
+
+ # 初始化数据库表
+ asyncio.run(init_backtest_db())
diff --git a/src/backtest/living_backtesting.py b/src/backtest/living_backtesting.py
new file mode 100644
index 0000000..6da418c
--- /dev/null
+++ b/src/backtest/living_backtesting.py
@@ -0,0 +1,92 @@
+import bt
+import pandas as pd
+import numpy as np
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+# 数据的列名
+columns = ['open', 'high', 'low', 'close', 'volume', 'amount', 'settelmentPrice',
+ 'openInterest', 'preClose', 'suspendFlag']
+
+
+def get_local_data(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int, dividend_type: str, fill_data: bool, data_dir: str):
+ result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
+ data_dir=data_dir)
+ return data_processing(result)
+
+
+# 数据处理函数
+def data_processing(result_local):
+ # 初始化一个空的列表,用于存储每个股票的数据框
+ df_list = []
+
+ # 遍历字典中的 DataFrame
+ for stock_code, df in result_local.items():
+ # 确保 df 是一个 DataFrame
+ if isinstance(df, pd.DataFrame):
+ # 将时间戳转换为日期时间格式,只保留年-月-日
+ df['time'] = pd.to_datetime(df['time'], unit='ms').dt.date
+ # 将 'time' 列设置为索引,保留为日期格式
+ df.set_index('time', inplace=True)
+ # 指定列名
+ df.columns = columns
+ # 添加一列 'stock_code' 用于标识不同的股票
+ df['stock_code'] = stock_code
+ # 将 DataFrame 添加到列表中
+ df_list.append(df[['close']]) # 只保留 'close' 列
+ else:
+ print(f"数据格式错误: {stock_code} 不包含 DataFrame")
+
+ # 使用 pd.concat() 将所有 DataFrame 合并为一个大的 DataFrame
+ combined_df = pd.concat(df_list, axis=1)
+
+ # 确保返回的 DataFrame 索引是日期格式
+ combined_df.index = pd.to_datetime(combined_df.index)
+
+ # 打印最终的 DataFrame
+ print(combined_df)
+
+ return combined_df
+
+
+def bt_strategy(data):
+
+ # 计算策略信号
+ # data = moving_average_strategy(data)
+
+ # 定义策略
+ dual_ma_strategy = bt.Strategy('Dual MA Strategy', [bt.algos.RunOnce(),
+ bt.algos.SelectAll(),
+ bt.algos.WeighEqually(),
+ bt.algos.Rebalance()])
+ return dual_ma_strategy
+
+
+# 运行回测
+def run_backtest():
+ # 生成数据
+ data = get_local_data(field_list=[], stock_list=["300391.SZ"], period='1d', start_time='', end_time='', count=-1,
+ dividend_type='none', fill_data=True, data_dir="")
+
+ # 创建策略
+ strategy = bt_strategy(data)
+
+ # 创建回测
+ portfolio = bt.Backtest(strategy, data)
+ result = bt.run(portfolio)
+
+ return result
+
+
+# 执行回测并显示结果
+if __name__ == "__main__":
+ result = run_backtest()
+ result.plot()
+ b = xtdata.get_sector_list()
+ print(b)
+ xtdata.download_sector_data()
+ a = result.stats
+ print(a)
+ plt.show()
diff --git a/src/backtest/macd_strategy.py b/src/backtest/macd_strategy.py
new file mode 100644
index 0000000..7d23359
--- /dev/null
+++ b/src/backtest/macd_strategy.py
@@ -0,0 +1,271 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+
+import bt
+import numpy as np
+import pandas as pd
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
+from src.models import wance_data_storage_backtest, wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+# MACD策略函数
+async def create_dual_ma_strategy(data, stock_code: str, short_window: int = 50, long_window: int = 200):
+ # 生成MACD策略信号
+ signal = await macd_strategy(data, short_window, long_window)
+
+ # 使用bt框架构建策略
+ strategy = bt.Strategy(f'{stock_code} MACD策略',
+ [bt.algos.RunDaily(),
+ bt.algos.SelectAll(), # 选择所有股票
+ bt.algos.WeighTarget(signal), # 根据信号调整权重
+ bt.algos.Rebalance()]) # 调仓
+ return strategy, signal
+
+
+# 定义 MACD 策略的函数
+def macd_strategy(data, short_window=12, long_window=26, signal_window=9):
+ """
+ MACD 策略,当 MACD 线穿过信号线时买入,反之卖出。
+
+ 参数:
+ data: pd.DataFrame, 股票的价格数据,行索引为日期,列为股票代码。
+ short_window: int, 短期 EMA 的窗口期。
+ long_window: int, 长期 EMA 的窗口期。
+ signal_window: int, 信号线 EMA 的窗口期。
+
+ 返回:
+ signal: pd.DataFrame, 每只股票的买卖信号,1 表示买入,-1 表示卖出。
+ """
+ # 计算短期和长期的 EMA
+ short_ema = data.ewm(span=short_window, adjust=False).mean()
+ long_ema = data.ewm(span=long_window, adjust=False).mean()
+
+ # 计算 MACD 线
+ macd_line = short_ema - long_ema
+
+ # 计算信号线
+ signal_line = macd_line.ewm(span=signal_window, adjust=False).mean()
+
+ # 生成买入和卖出信号
+ signal = pd.DataFrame(index=data.index, columns=data.columns)
+ for column in data.columns:
+ signal[column] = 0 # 初始化信号为 0
+ # 买入信号:MACD 线从下方穿过信号线
+ signal[column] = (macd_line[column] > signal_line[column]) & (macd_line[column].shift(1) <= signal_line[column].shift(1)).astype(int)
+ # 卖出信号:MACD 线从上方穿过信号线
+ signal[column] = (macd_line[column] < signal_line[column]) & (macd_line[column].shift(1) >= signal_line[column].shift(1)).astype(int) * -1 + signal[column]
+
+ # 前向填充信号,保持持仓不变
+ signal = signal.ffill().fillna(0)
+
+ return signal
+
+
+async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, short_window,
+ long_window):
+ await init_tortoise()
+
+ # 要存储的字段列表
+ fields_to_store = [
+ 'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
+ 'price', 'returns', 'data_start_time', 'data_end_time',
+ 'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
+ 'max_drawdown', 'calmar', 'mtd', 'three_month',
+ 'six_month', 'ytd', 'one_year', 'three_year',
+ 'five_year', 'ten_year', 'incep', 'daily_sharpe',
+ 'daily_sortino', 'daily_mean', 'daily_vol',
+ 'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
+ 'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
+ 'monthly_vol', 'monthly_skew', 'monthly_kurt',
+ 'best_month', 'worst_month', 'yearly_sharpe',
+ 'yearly_sortino', 'yearly_mean', 'yearly_vol',
+ 'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
+ 'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
+ 'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
+ ]
+
+ # 准备要存储的数据
+ data_to_store = {
+ 'stock_code': stock_code,
+ 'strategy_name': "MACD策略",
+ 'stock_close_price': json.dumps(stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
+ time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')),
+ 'daily_price': convert_pandas_to_json_serializable(result[source_column_name].daily_prices),
+ 'price': convert_pandas_to_json_serializable(result[source_column_name].prices),
+ 'returns': convert_pandas_to_json_serializable(result[source_column_name].returns.fillna(0)),
+ 'data_start_time': pd.to_datetime(result.stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
+ 'data_end_time': pd.to_datetime(result.stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
+ 'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
+ 'position': convert_pandas_to_json_serializable(signal),
+ 'backtest_name': f'{stock_code} MACD策略 MA{short_window}-{long_window}日',
+ 'indicator_type': 'MACD',
+ 'indicator_information': json.dumps({'short_window': short_window, 'long_window': long_window})
+ }
+
+ # 使用循环填充其他字段
+ for field in fields_to_store[12:]: # 从第10个字段开始
+ value = result.stats.loc[field].iloc[0]
+ data_to_store[field] = 0.0 if (isinstance(value, float) and np.isnan(value)) else value
+
+ # 检查是否存在该 backtest_name
+ existing_record = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=data_to_store['backtest_name']
+ ).first()
+
+ if existing_record:
+ # 如果存在,更新记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ id=existing_record.id
+ ).update(**data_to_store)
+ else:
+ # 如果不存在,创建新的记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.create(**data_to_store)
+
+ return data_to_store
+
+
+async def run_macd_backtest(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = 100,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ try:
+ # 初始化一个列表用于存储每只股票的回测结果字典
+ results_list = []
+
+ # 遍历每只股票的数据(每列代表一个股票的收盘价)
+ data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
+ fill_data,
+ data_dir)
+
+ for stock_code in stock_list:
+
+ data_column_name = f'close_{stock_code}'
+ source_column_name = f'{stock_code} MACD策略'
+ backtest_name = f'{stock_code} MACD策略 MA{short_window}-{long_window}日'
+ now_data = int(datetime.now().strftime('%Y%m%d'))
+ db_result_data = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+
+ if db_result_data:
+ if db_result_data[0].backtest_end_time == now_data:
+ results_list.append({source_column_name: db_result_data[0]})
+
+ elif data_column_name in data.columns:
+ stock_data_series = data[[data_column_name]] # 提取该股票的收盘价 DataFrame
+ stock_data_series.columns = ['close'] # 重命名列为 'close'
+
+ # 创建MACD策略
+ strategy, signal = await create_dual_ma_strategy(stock_data_series, stock_code,
+ short_window=short_window,
+ long_window=long_window)
+ # 创建回测
+ backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
+ # 运行回测
+ result = bt.run(backtest)
+ # 存储回测结果
+ data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
+ stock_data_series,
+ short_window, long_window)
+ # # 绘制回测结果图表
+ # result.plot()
+ # # 绘制个别股票数据图表
+ # plt.figure(figsize=(12, 6))
+ # plt.plot(stock_data_series.index, stock_data_series['close'], label='Stock Price')
+ # plt.title(f'Stock Price for {stock_code}')
+ # plt.xlabel('Date')
+ # plt.ylabel('Price')
+ # plt.legend()
+ # plt.grid(True)
+ # plt.show()
+ # 将结果存储为字典并添加到列表中
+ results_list.append({source_column_name: data_to_store})
+
+ else:
+ print(f"数据中缺少列: {data_column_name}")
+
+ return results_list # 返回结果列表
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+
+async def start_macd_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ for stock_code in stock_list:
+ backtest_name = f'{stock_code} MACD策略 MA{short_window}-{long_window}日'
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+
+ if db_result and db_result[0].backtest_end_time == now_time:
+ return db_result
+ else:
+ # 执行回测
+ result = await run_macd_backtest(
+ field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir,
+ short_window=short_window,
+ long_window=long_window,
+ )
+ return result
+
+
+async def init_backtest_db():
+ MACD_list = [{"short_window": 5, "long_window": 10}, {"short_window": 10, "long_window": 30},
+ {"short_window": 30, "long_window": 60}, {"short_window": 30, "long_window": 90},
+ {"short_window": 70, "long_window": 140}, {"short_window": 120, "long_window": 250}]
+ await init_tortoise()
+ wance_db = await wance_data_stock.WanceDataStock.all()
+ MACD_list_lenght = len(MACD_list)
+
+ for stock_code in wance_db:
+ for i in range(MACD_list_lenght):
+ short_window = MACD_list[i]['short_window']
+ long_window = MACD_list[i]['long_window']
+ source_column_name = f'{stock_code} MACD策略 MA{short_window}-{long_window}日'
+ result = await run_macd_backtest(field_list=['close', 'time'],
+ stock_list=[stock_code.stock_code],
+ short_window=short_window,
+ long_window=long_window)
+
+ print(f"回测成功 {source_column_name}")
+
+
+if __name__ == '__main__':
+ # 测试类的回测
+ # asyncio.run(run_macd_backtest(field_list=['close', 'time'],
+ # stock_list=['601222.SH', '601677.SH'],
+ # short_window=10,
+ # long_window=30))
+
+ # 初始化数据库表
+ asyncio.run(init_backtest_db())
diff --git a/src/backtest/reverse_dual_ma_strategy.py b/src/backtest/reverse_dual_ma_strategy.py
new file mode 100644
index 0000000..2417c53
--- /dev/null
+++ b/src/backtest/reverse_dual_ma_strategy.py
@@ -0,0 +1,264 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+
+import bt
+import numpy as np
+import pandas as pd
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
+from src.models import wance_data_storage_backtest, wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+# 反双均线策略函数
+async def create_dual_ma_strategy(data, stock_code: str, short_window: int = 50, long_window: int = 200):
+ # 生成反双均线策略信号
+ signal = await reverse_dual_ma_strategy(data, short_window, long_window)
+
+ # 使用bt框架构建策略
+ strategy = bt.Strategy(f'{stock_code} 反双均线策略',
+ [bt.algos.RunDaily(),
+ bt.algos.SelectAll(), # 选择所有股票
+ bt.algos.WeighTarget(signal), # 根据信号调整权重
+ bt.algos.Rebalance()]) # 调仓
+ return strategy, signal
+
+
+
+# 定义反反双均线策略的函数
+def reverse_dual_ma_strategy(data, short_window=50, long_window=200):
+ """
+ 反反双均线策略,当短期均线跌破长期均线时买入,穿过长期均线时卖出。
+
+ 参数:
+ data: pd.DataFrame, 股票的价格数据,行索引为日期,列为股票代码。
+ short_window: int, 短期均线的窗口期。
+ long_window: int, 长期均线的窗口期。
+
+ 返回:
+ signal: pd.DataFrame, 每只股票的买卖信号,1 表示买入,0 表示卖出。
+ """
+ # 计算短期均线和长期均线
+ short_ma = data.rolling(window=short_window).mean()
+ long_ma = data.rolling(window=long_window).mean()
+
+ # 初始化信号 DataFrame
+ signal = pd.DataFrame(index=data.index, columns=data.columns)
+
+ # 生成买入信号:短期均线从上往下穿过长期均线
+ for column in data.columns:
+ signal[column] = (short_ma[column] < long_ma[column]).astype(int) # 跌破时买入,信号为1
+ signal[column] = (short_ma[column] > long_ma[column]).astype(int) * -1 + signal[column] # 穿过时卖出,信号为0
+
+ # 前向填充信号,保持持仓不变
+ signal = signal.ffill()
+
+ return signal
+
+
+async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, short_window,
+ long_window):
+ await init_tortoise()
+
+ # 要存储的字段列表
+ fields_to_store = [
+ 'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
+ 'price', 'returns', 'data_start_time', 'data_end_time',
+ 'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
+ 'max_drawdown', 'calmar', 'mtd', 'three_month',
+ 'six_month', 'ytd', 'one_year', 'three_year',
+ 'five_year', 'ten_year', 'incep', 'daily_sharpe',
+ 'daily_sortino', 'daily_mean', 'daily_vol',
+ 'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
+ 'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
+ 'monthly_vol', 'monthly_skew', 'monthly_kurt',
+ 'best_month', 'worst_month', 'yearly_sharpe',
+ 'yearly_sortino', 'yearly_mean', 'yearly_vol',
+ 'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
+ 'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
+ 'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
+ ]
+
+ # 准备要存储的数据
+ data_to_store = {
+ 'stock_code': stock_code,
+ 'strategy_name': "反双均线策略",
+ 'stock_close_price': json.dumps(stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
+ time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')),
+ 'daily_price': convert_pandas_to_json_serializable(result[source_column_name].daily_prices),
+ 'price': convert_pandas_to_json_serializable(result[source_column_name].prices),
+ 'returns': convert_pandas_to_json_serializable(result[source_column_name].returns.fillna(0)),
+ 'data_start_time': pd.to_datetime(result.stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
+ 'data_end_time': pd.to_datetime(result.stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
+ 'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
+ 'position': convert_pandas_to_json_serializable(signal),
+ 'backtest_name': f'{stock_code} 反双均线策略 MA{short_window}-{long_window}日',
+ 'indicator_type': 'reverse_SMA',
+ 'indicator_information': json.dumps({'short_window': short_window, 'long_window': long_window})
+ }
+
+ # 使用循环填充其他字段
+ for field in fields_to_store[12:]: # 从第10个字段开始
+ value = result.stats.loc[field].iloc[0]
+ data_to_store[field] = 0.0 if (isinstance(value, float) and np.isnan(value)) else value
+
+ # 检查是否存在该 backtest_name
+ existing_record = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=data_to_store['backtest_name']
+ ).first()
+
+ if existing_record:
+ # 如果存在,更新记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ id=existing_record.id
+ ).update(**data_to_store)
+ else:
+ # 如果不存在,创建新的记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.create(**data_to_store)
+
+ return data_to_store
+
+
+async def run_reverse_reverse_SMA_backtest(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = 100,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ try:
+ # 初始化一个列表用于存储每只股票的回测结果字典
+ results_list = []
+
+ # 遍历每只股票的数据(每列代表一个股票的收盘价)
+ data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
+ fill_data,
+ data_dir)
+
+ for stock_code in stock_list:
+
+ data_column_name = f'close_{stock_code}'
+ source_column_name = f'{stock_code} 反双均线策略'
+ backtest_name = f'{stock_code} 反双均线策略 MA{short_window}-{long_window}日'
+ now_data = int(datetime.now().strftime('%Y%m%d'))
+ db_result_data = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+
+ if db_result_data:
+ if db_result_data[0].backtest_end_time == now_data:
+ results_list.append({source_column_name: db_result_data[0]})
+
+ elif data_column_name in data.columns:
+ stock_data_series = data[[data_column_name]] # 提取该股票的收盘价 DataFrame
+ stock_data_series.columns = ['close'] # 重命名列为 'close'
+
+ # 创建反双均线策略
+ strategy, signal = await create_dual_ma_strategy(stock_data_series, stock_code,
+ short_window=short_window,
+ long_window=long_window)
+ # 创建回测
+ backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
+ # 运行回测
+ result = bt.run(backtest)
+ # 存储回测结果
+ data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
+ stock_data_series,
+ short_window, long_window)
+ # # 绘制回测结果图表
+ # result.plot()
+ # # 绘制个别股票数据图表
+ # plt.figure(figsize=(12, 6))
+ # plt.plot(stock_data_series.index, stock_data_series['close'], label='Stock Price')
+ # plt.title(f'Stock Price for {stock_code}')
+ # plt.xlabel('Date')
+ # plt.ylabel('Price')
+ # plt.legend()
+ # plt.grid(True)
+ # plt.show()
+ # 将结果存储为字典并添加到列表中
+ results_list.append({source_column_name: data_to_store})
+
+ else:
+ print(f"数据中缺少列: {data_column_name}")
+
+ return results_list # 返回结果列表
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+
+async def start_reverse_SMA_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200):
+ for stock_code in stock_list:
+ backtest_name = f'{stock_code} 反双均线策略 MA{short_window}-{long_window}日'
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+
+ if db_result and db_result[0].backtest_end_time == now_time:
+ return db_result
+ else:
+ # 执行回测
+ result = await run_reverse_reverse_SMA_backtest(
+ field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir,
+ short_window=short_window,
+ long_window=long_window,
+ )
+ return result
+
+
+async def init_backtest_db():
+ reverse_SMA_list = [{"short_window": 5, "long_window": 10}, {"short_window": 10, "long_window": 30},
+ {"short_window": 30, "long_window": 60}, {"short_window": 30, "long_window": 90},
+ {"short_window": 70, "long_window": 140}, {"short_window": 120, "long_window": 250}]
+ await init_tortoise()
+ wance_db = await wance_data_stock.WanceDataStock.all()
+ reverse_SMA_list_lenght = len(reverse_SMA_list)
+
+ for stock_code in wance_db:
+ for i in range(reverse_SMA_list_lenght):
+ short_window = reverse_SMA_list[i]['short_window']
+ long_window = reverse_SMA_list[i]['long_window']
+ source_column_name = f'{stock_code} 反双均线策略 MA{short_window}-{long_window}日'
+ result = await run_reverse_reverse_SMA_backtest(field_list=['close', 'time'],
+ stock_list=[stock_code.stock_code],
+ short_window=short_window,
+ long_window=long_window)
+
+ print(f"回测成功 {source_column_name}")
+
+
+if __name__ == '__main__':
+ # 测试类的回测
+ # asyncio.run(run_reverse_SMA_backtest(field_list=['close', 'time'],
+ # stock_list=['601222.SH', '601677.SH'],
+ # short_window=10,
+ # long_window=30))
+
+ # 初始化数据库表
+ asyncio.run(init_backtest_db())
diff --git a/src/backtest/router.py b/src/backtest/router.py
new file mode 100644
index 0000000..4eed6ea
--- /dev/null
+++ b/src/backtest/router.py
@@ -0,0 +1,23 @@
+from fastapi import APIRouter, HTTPException # 从 FastAPI 中导入 APIRouter,用于创建 API 路由器
+
+from src.backtest.service import start_backtest_service
+from src.pydantic.backtest_request import BackRequest
+
+router = APIRouter() # 创建一个 FastAPI 路由器实例
+
+
+@router.get("/start_backtest")
+async def start_backtest(request: BackRequest):
+ result = await start_backtest_service(field_list=['close', 'time'],
+ stock_list=request.stock_list,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ count=request.count,
+ dividend_type=request.dividend_type,
+ fill_data=request.fill_data,
+ ma_type=request.ma_type,
+ short_window=request.short_window,
+ long_window=request.long_window
+ )
+ return result
diff --git a/src/backtest/rsi_strategy.py b/src/backtest/rsi_strategy.py
new file mode 100644
index 0000000..9832db2
--- /dev/null
+++ b/src/backtest/rsi_strategy.py
@@ -0,0 +1,296 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+
+import bt
+import numpy as np
+import pandas as pd
+from xtquant import xtdata
+import matplotlib.pyplot as plt
+
+from src.backtest.until import get_local_data, convert_pandas_to_json_serializable
+from src.models import wance_data_storage_backtest, wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+# RSI策略函数
+async def create_dual_ma_strategy(data, stock_code: str, short_window: int = 50, long_window: int = 200,
+ overbought: int = 70, oversold: int = 30):
+ # 生成RSI策略信号
+ signal = await rsi_strategy(data, short_window, long_window, overbought, oversold)
+
+ # 使用bt框架构建策略
+ strategy = bt.Strategy(f'{stock_code} RSI策略',
+ [bt.algos.RunDaily(),
+ bt.algos.SelectAll(), # 选择所有股票
+ bt.algos.WeighTarget(signal), # 根据信号调整权重
+ bt.algos.Rebalance()]) # 调仓
+ return strategy, signal
+
+
+async def rsi_strategy(df, short_window=14, long_window=28, overbought=70, oversold=30):
+ """
+ 基于RSI的策略生成买卖信号。
+
+ 参数:
+ df: pd.DataFrame, 股票的价格数据,行索引为日期,列为股票代码。
+ short_window: int, 短期RSI的窗口期。
+ long_window: int, 长期RSI的窗口期。
+ overbought: int, 超买水平。
+ oversold: int, 超卖水平。
+
+ 返回:
+ signal: pd.DataFrame, 每只股票的买卖信号,1 表示买入,0 表示卖出。
+ """
+ delta = df.diff().fillna(0)
+
+ gain = (delta.where(delta > 0, 0).rolling(window=short_window).mean()).fillna(0)
+ loss = (-delta.where(delta < 0, 0).rolling(window=short_window).mean()).fillna(0)
+
+ short_rsi = (100 - (100 / (1 + (gain / loss)))).fillna(0)
+
+ long_gain = (delta.where(delta > 0, 0).rolling(window=long_window).mean()).fillna(0)
+ long_loss = (-delta.where(delta < 0, 0).rolling(window=long_window).mean()).fillna(0)
+
+ long_rsi = (100 - (100 / (1 + (long_gain / long_loss)))).fillna(0)
+
+ signal = pd.DataFrame(index=df.index, columns=df.columns)
+
+ for column in df.columns:
+ signal[column] = np.where((short_rsi[column] < 30) & (long_rsi[column] < 30) & (short_rsi[column] != 0) & (long_rsi[column] != 0), 1, 0)
+ signal[column] = np.where((short_rsi[column] > 70) & (long_rsi[column] > 70) & (short_rsi[column] != 0) & (long_rsi[column] != 0), 0, signal[column])
+
+ return signal.ffill().fillna(0)
+
+
+async def storage_backtest_data(source_column_name, result, signal, stock_code, stock_data_series, short_window: int,
+ long_window: int, overbought: int = 70,
+ oversold: int = 30):
+ await init_tortoise()
+
+ # 要存储的字段列表
+ fields_to_store = [
+ 'stock_code', 'strategy_name', 'stock_close_price', 'daily_price',
+ 'price', 'returns', 'data_start_time', 'data_end_time',
+ 'backtest_end_time', 'position', 'backtest_name', 'rf', 'total_return', 'cagr',
+ 'max_drawdown', 'calmar', 'mtd', 'three_month',
+ 'six_month', 'ytd', 'one_year', 'three_year',
+ 'five_year', 'ten_year', 'incep', 'daily_sharpe',
+ 'daily_sortino', 'daily_mean', 'daily_vol',
+ 'daily_skew', 'daily_kurt', 'best_day', 'worst_day',
+ 'monthly_sharpe', 'monthly_sortino', 'monthly_mean',
+ 'monthly_vol', 'monthly_skew', 'monthly_kurt',
+ 'best_month', 'worst_month', 'yearly_sharpe',
+ 'yearly_sortino', 'yearly_mean', 'yearly_vol',
+ 'yearly_skew', 'yearly_kurt', 'best_year', 'worst_year',
+ 'avg_drawdown', 'avg_drawdown_days', 'avg_up_month',
+ 'avg_down_month', 'win_year_perc', 'twelve_month_win_perc'
+ ]
+
+ # 准备要存储的数据
+ data_to_store = {
+ 'stock_code': stock_code,
+ 'strategy_name': "RSI策略",
+ 'stock_close_price': json.dumps(stock_data_series.fillna(0).rename_axis('time').reset_index().assign(
+ time=stock_data_series.index.strftime('%Y%m%d')).set_index('time').to_dict(orient='index')),
+ 'daily_price': convert_pandas_to_json_serializable(result[source_column_name].daily_prices),
+ 'price': convert_pandas_to_json_serializable(result[source_column_name].prices),
+ 'returns': convert_pandas_to_json_serializable(result[source_column_name].returns.fillna(0)),
+ 'data_start_time': pd.to_datetime(result.stats.loc["start"].iloc[0]).strftime('%Y%m%d'),
+ 'data_end_time': pd.to_datetime(result.stats.loc["end"].iloc[0]).strftime('%Y%m%d'),
+ 'backtest_end_time': int(datetime.now().strftime('%Y%m%d')),
+ 'position': convert_pandas_to_json_serializable(signal),
+ 'backtest_name': f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}-overbought{overbought}-oversold{oversold}',
+ 'indicator_type': 'RSI',
+ 'indicator_information': json.dumps(
+ {'short_window': short_window, 'long_window': long_window, 'overbought': overbought, 'oversold': oversold})
+ }
+
+ # 使用循环填充其他字段
+ for field in fields_to_store[12:]: # 从第12个字段开始
+ value = result.stats.loc[field].iloc[0]
+
+ if isinstance(value, float):
+ if np.isnan(value):
+ data_to_store[field] = 0.0 # NaN 处理为 0
+ elif np.isinf(value): # 判断是否为无穷大或无穷小
+ if value > 0:
+ data_to_store[field] = 99999.9999 # 正无穷处理
+ else:
+ data_to_store[field] = -99999.9999 # 负无穷处理
+ else:
+ data_to_store[field] = value # 正常的浮点值
+ else:
+ data_to_store[field] = value # 非浮点类型保持不变
+
+ # 检查是否存在该 backtest_name
+ existing_record = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=data_to_store['backtest_name']
+ ).first()
+
+ if existing_record:
+ # 如果存在,更新记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ id=existing_record.id
+ ).update(**data_to_store)
+ else:
+ # 如果不存在,创建新的记录
+ await wance_data_storage_backtest.WanceDataStorageBacktest.create(**data_to_store)
+
+ return data_to_store
+
+
+async def run_rsi_backtest(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = 100,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200,
+ overbought: int = 70,
+ oversold: int = 30
+ ):
+ try:
+ # 初始化一个列表用于存储每只股票的回测结果字典
+ results_list = []
+
+ # 遍历每只股票的数据(每列代表一个股票的收盘价)
+ data = await get_local_data(field_list, stock_list, period, start_time, end_time, count, dividend_type,
+ fill_data,
+ data_dir)
+
+ for stock_code in stock_list:
+
+ data_column_name = f'close_{stock_code}'
+ source_column_name = f'{stock_code} RSI策略'
+ backtest_name = f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}'
+ now_data = int(datetime.now().strftime('%Y%m%d'))
+ db_result_data = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+
+ if db_result_data:
+ if db_result_data[0].backtest_end_time == now_data:
+ results_list.append({source_column_name: db_result_data[0]})
+
+ # elif data_column_name in data.columns:
+ if data_column_name in data.columns:
+ stock_data_series = data[[data_column_name]] # 提取该股票的收盘价 DataFrame
+ stock_data_series.columns = ['close'] # 重命名列为 'close'
+
+ # 创建RSI策略
+ strategy, signal = await create_dual_ma_strategy(stock_data_series, stock_code,
+ short_window=short_window, long_window=long_window)
+ # 创建回测
+ backtest = bt.Backtest(strategy=strategy, data=stock_data_series, initial_capital=100000)
+ # 运行回测
+ result = bt.run(backtest)
+ # 存储回测结果
+ data_to_store = await storage_backtest_data(source_column_name, result, signal, stock_code,
+ stock_data_series, short_window, long_window, overbought,
+ oversold)
+ # # 绘制回测结果图表
+ # result.plot()
+ # # 绘制个别股票数据图表
+ # plt.figure(figsize=(12, 6))
+ # plt.plot(stock_data_series.index, stock_data_series['close'], label='Stock Price')
+ # plt.title(f'Stock Price for {stock_code}')
+ # plt.xlabel('Date')
+ # plt.ylabel('Price')
+ # plt.legend()
+ # plt.grid(True)
+ # plt.show()
+ # 将结果存储为字典并添加到列表中
+ results_list.append({source_column_name: data_to_store})
+
+ else:
+ print(f"数据中缺少列: {data_column_name}")
+
+ return results_list # 返回结果列表
+
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+
+async def start_rsi_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ short_window: int = 50,
+ long_window: int = 200,
+ overbought: int = 70,
+ oversold: int = 30
+ ):
+ for stock_code in stock_list:
+ backtest_name = f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}'
+ db_result = await wance_data_storage_backtest.WanceDataStorageBacktest.filter(
+ backtest_name=backtest_name)
+ now_time = int(datetime.now().strftime('%Y%m%d'))
+
+ if db_result and db_result[0].backtest_end_time == now_time:
+ return db_result
+ else:
+ # 执行回测
+ result = await run_rsi_backtest(
+ field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir,
+ short_window=short_window,
+ long_window=long_window,
+ overbought=overbought,
+ oversold=oversold
+ )
+ return result
+
+
+async def init_backtest_db():
+ sma_list = [{"short_window": 3, "long_window": 6}, {"short_window": 6, "long_window": 12},
+ {"short_window": 12, "long_window": 24}, {"short_window": 14, "long_window": 18},
+ {"short_window": 15, "long_window": 10}]
+ await init_tortoise()
+ wance_db = await wance_data_stock.WanceDataStock.all()
+ sma_list_lenght = len(sma_list)
+
+ for stock_code in wance_db:
+ for i in range(sma_list_lenght):
+ short_window = sma_list[i]['short_window']
+ long_window = sma_list[i]['long_window']
+ source_column_name = f'{stock_code} RSI策略 RSI{short_window}-RSI{long_window}'
+ result = await start_rsi_backtest_service(field_list=['close', 'time'],
+ stock_list=[stock_code.stock_code],
+ short_window=short_window,
+ long_window=long_window,
+ overbought=70,
+ oversold=30)
+
+ print(f"回测成功 {source_column_name}")
+
+
+if __name__ == '__main__':
+ # 测试类的回测
+ asyncio.run(run_rsi_backtest(field_list=['close', 'time'],
+ stock_list=['601222.SH', '601677.SH'],
+ count=-1,
+ short_window=10,
+ long_window=30,
+ overbought=70,
+ oversold=30
+ ))
+
+ # # 初始化数据库表
+ # asyncio.run(init_backtest_db())
diff --git a/src/backtest/service.py b/src/backtest/service.py
new file mode 100644
index 0000000..e07e9b8
--- /dev/null
+++ b/src/backtest/service.py
@@ -0,0 +1,74 @@
+from src.backtest.bollinger_bands import run_bollinger_backtest, start_bollinger_backtest_service
+from src.backtest.dual_moving_average import run_sma_backtest, start_sma_backtest_service
+from src.backtest.reverse_dual_ma_strategy import start_reverse_SMA_backtest_service
+from src.backtest.rsi_strategy import start_rsi_backtest_service
+from src.backtest.until import data_check
+
+
+async def start_backtest_service(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '',
+ ma_type: str = 'SMA',
+ short_window: int = 50,
+ long_window: int = 200,
+ bollingerMA: int = 200,
+ std_dev: int = 200,
+ overbought: int = 70,
+ oversold: int = 30,
+ signal_window: int = 9):
+ # 数据检查
+ await data_check(field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=data_dir)
+
+ # 策略映射
+ strategies = {
+ 'SMA': start_sma_backtest_service,
+ 'Bollinger': start_bollinger_backtest_service,
+ 'RSI': start_rsi_backtest_service,
+ 'RESMA': start_reverse_SMA_backtest_service,
+ 'MACD': start_rsi_backtest_service
+ }
+
+ # 通用参数
+ base_params = {
+ 'field_list': field_list,
+ 'stock_list': stock_list,
+ 'period': period,
+ 'start_time': start_time,
+ 'end_time': end_time,
+ 'count': count,
+ 'dividend_type': dividend_type,
+ 'fill_data': fill_data,
+ 'data_dir': data_dir,
+ }
+
+ # 特定策略参数
+ strategy_params = {
+ 'SMA': {'short_window': short_window, 'long_window': long_window},
+ 'Bollinger': {'bollingerMA': bollingerMA, 'std_dev': std_dev},
+ 'RSI': {'short_window': short_window, 'long_window': long_window, 'overbought': overbought,
+ 'oversold': oversold},
+ 'RESMA': {'short_window': short_window, 'long_window': long_window},
+ 'MACD': {'short_window': short_window, 'long_window': signal_window}
+ }
+
+ # 选择策略并执行
+ strategy_func = strategies.get(ma_type)
+ if strategy_func:
+ result = await strategy_func(**base_params, **strategy_params[ma_type])
+ return result
+ else:
+ return None
diff --git a/src/backtest/until.py b/src/backtest/until.py
new file mode 100644
index 0000000..253919c
--- /dev/null
+++ b/src/backtest/until.py
@@ -0,0 +1,99 @@
+import json
+from datetime import datetime
+
+import numpy as np
+from xtquant import xtdata
+import pandas as pd
+
+# 数据的列名
+columns = ['open', 'high', 'low', 'close', 'volume', 'amount', 'settelmentPrice',
+ 'openInterest', 'preClose', 'suspendFlag']
+
+
+# 获取本地数据并进行处理
+async def get_local_data(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int, dividend_type: str, fill_data: bool, data_dir: str):
+ result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
+ data_dir=data_dir)
+ return await data_processing(result)
+
+
+async def data_processing(result_local):
+ # 初始化一个空的列表,用于存储每个股票的数据框
+ df_list = []
+
+ # 遍历字典中的 DataFrame
+ for stock_code, df in result_local.items():
+ # 确保 df 是一个 DataFrame
+ if isinstance(df, pd.DataFrame):
+ # 将时间戳转换为日期时间格式,只保留年-月-日
+ df['time'] = pd.to_datetime(df['time'], unit='ms').dt.date
+ # 将 'time' 列设置为索引,保留为日期格式
+ df.set_index('time', inplace=True)
+ # 将 'close' 列重命名为 'close_股票代码'
+ df.rename(columns={'close': f'close_{stock_code}'}, inplace=True)
+ # 将 DataFrame 添加到列表中
+ df_list.append(df[[f'close_{stock_code}']]) # 只保留 'close_股票代码' 列
+ else:
+ print(f"数据格式错误: {stock_code} 不包含 DataFrame")
+
+ # 使用 pd.concat() 将所有 DataFrame 合并为一个大的 DataFrame
+ combined_df = pd.concat(df_list, axis=1)
+
+ # 确保返回的 DataFrame 索引是日期格式
+ combined_df.index = pd.to_datetime(combined_df.index)
+
+ return combined_df
+
+
+def convert_pandas_to_json_serializable(data: pd.Series) -> str:
+ """
+ 将 Pandas Series 或 DataFrame 中的 Timestamp 索引转换为字符串,并返回 JSON 可序列化的结果。
+
+ 参数:
+ data: pd.Series 或 pd.DataFrame, 带有时间戳索引的 pandas 数据
+
+ 返回:
+ JSON 字符串,键为日期字符串,值为原数据的值。
+ """
+ # 判断数据类型
+ if isinstance(data, (pd.Series, pd.DataFrame)):
+ # 如果索引是时间戳类型,则转换为 YYYYMMDD 格式
+ if isinstance(data.index, pd.DatetimeIndex):
+ data.index = data.index.strftime('%Y%m%d')
+
+ # 处理 NaN 和 None 的情况,替换为 0 或其他合适的默认值
+ data = data.replace([np.nan, None], 0)
+
+ # 将索引重置为普通列,然后转换为字典
+ json_serializable_data = data.rename_axis('date').reset_index().to_dict(orient='records')
+
+ # 将字典转换为 JSON 格式字符串
+ json_string = json.dumps(json_serializable_data)
+ return json_string
+ else:
+ raise ValueError("输入必须为 Pandas Series 或 DataFrame")
+
+
+async def data_check(field_list: list,
+ stock_list: list,
+ period: str = '1d',
+ start_time: str = '',
+ end_time: str = '',
+ count: int = -1,
+ dividend_type: str = 'none',
+ fill_data: bool = True,
+ data_dir: str = '', ):
+ result_data = xtdata.get_local_data(field_list=[], stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
+ data_dir=data_dir)
+ time_now = int(datetime.now().strftime('%Y%m%d'))
+ for i in stock_list:
+ close = int(result_data.get(i).index[-1])
+ if close != 0 and close < time_now:
+ xtdata.download_history_data(stock_code=i,
+ period='1d',
+ start_time='',
+ end_time='',
+ incrementally=True)
diff --git a/src/constants.py b/src/constants.py
new file mode 100644
index 0000000..e1eab81
--- /dev/null
+++ b/src/constants.py
@@ -0,0 +1,43 @@
+from enum import Enum
+
+THISPOLICYNAME = ''
+def setThisPolicyName(name):
+ THISPOLICYNAME = name
+
+def getThisPolicyName():
+ return THISPOLICYNAME
+
+DB_NAMING_CONVENTION = {
+ "ix": "%(column_0_label)s_idx",
+ "uq": "%(table_name)s_%(column_0_name)s_key",
+ "ck": "%(table_name)s_%(constraint_name)s_check",
+ "fk": "%(table_name)s_%(column_0_name)s_fkey",
+ "pk": "%(table_name)s_pkey",
+}
+
+
+class RedisKeyConstants:
+ SMS_CODE_KEY = "sms:code:"
+ TOKEN_KEY = "Youcailogin:"
+ ACCESS_TOKEN_KEY = "applet_access_token_youcai"
+ SESSION_KEY = "session_id_youcai:"
+ OFFICIAL_ACCESS_TOKEN_KEY = "official_access_token_youcai"
+
+
+class Environment(str, Enum):
+ LOCAL = "LOCAL"
+ STAGING = "STAGING"
+ TESTING = "TESTING"
+ PRODUCTION = "PRODUCTION"
+
+ @property
+ def is_debug(self):
+ return self in (self.LOCAL, self.STAGING, self.TESTING)
+
+ @property
+ def is_testing(self):
+ return self == self.TESTING
+
+ @property
+ def is_deployed(self) -> bool:
+ return self in (self.STAGING, self.PRODUCTION)
diff --git a/readme.md b/src/data_processing/__init__.py
similarity index 100%
rename from readme.md
rename to src/data_processing/__init__.py
diff --git a/src/data_processing/history_data_processing.py b/src/data_processing/history_data_processing.py
new file mode 100644
index 0000000..63c1531
--- /dev/null
+++ b/src/data_processing/history_data_processing.py
@@ -0,0 +1,412 @@
+import asyncio
+from datetime import datetime
+
+import akshare as ak
+import pandas as pd
+from xtquant import xtdata
+
+from src.models import wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+from src.utils.history_data_processing_utils import translation_dict, get_best_match, safe_get_value, on_progress
+from src.utils.split_stock_utils import split_stock_code, join_stock_code, percent_to_float
+from src.xtdata.service import download_history_data_service, get_full_tick_keys_service, download_history_data2_service
+
+# period - 周期,用于表示要获取的周期和具体数据类型
+period_list = ["1d", "1h", "30m", "15m", "5m", "1m", "tick", "1w", "1mon", "1q", "1hy", "1y"]
+# 数据的列名
+columns = ['open', 'high', 'low', 'close', 'volume', 'amount', 'settelmentPrice',
+ 'openInterest', 'preClose', 'suspendFlag']
+
+
+def processing_data(field_list: list, stock_list: list, period: str, start_time: str, end_time: str, count: int,
+ dividend_type: str, fill_data: bool):
+ """
+
+ :param field_list: []
+ :param stock_list: ["186511.SH","173312.SH","231720.SH","173709.SH","019523.SH"]
+ :param period: "1d"
+ :param start_time: "20240506"
+ :param end_time: ""
+ :param count: -1
+ :param dividend_type: "none"
+ :param fill_data: False
+ :return:
+ """
+ try:
+ # 获取本地数据
+ result_local = xtdata.get_local_data(field_list=field_list,
+ stock_list=stock_list,
+ period=period,
+ start_time=start_time,
+ end_time=end_time,
+ count=count,
+ dividend_type=dividend_type,
+ fill_data=fill_data,
+ data_dir=""
+ )
+
+ # 初始化一个空的列表,用于存储每个股票的数据框
+ df_list = []
+ # 遍历字典中的 DataFrame
+ for stock_code, df in result_local.items():
+ # 确保 df 是一个 DataFrame
+ if isinstance(df, pd.DataFrame):
+ # 将时间戳转换为日期时间格式,并格式化为字符串 'YYYYMMDD'
+ df['time'] = pd.to_datetime(df['time'], unit='ms').dt.strftime('%Y%m%d')
+ # 将 'time' 列设置为索引
+ df.set_index('time', inplace=True)
+ # 指定列名
+ df.columns = columns
+ # 添加一列 'stock_code' 用于标识不同的股票
+ df['stock_code'] = stock_code
+ # 将 DataFrame 添加到列表中
+ df_list.append(df)
+ else:
+ print(f"数据格式错误: {stock_code} 不包含 DataFrame")
+
+ # 使用 pd.concat() 将所有 DataFrame 合并为一个大的 DataFrame
+ combined_df = pd.concat(df_list)
+
+ # 打印合并后的 DataFrame
+ print(combined_df)
+
+ print(f"开始获取股票数据{result_local}")
+ except Exception as e:
+ print(f"处理数据发生错误: {str(e)}")
+
+
+def history_data_processing():
+ """
+ 本地路径 D:\\e海方舟-量化交易版\\userdata_mini\\datadir
+ :return:
+ """
+
+ try:
+ result = xtdata.get_full_tick(code_list=['SH', 'SZ'])
+ result = list(result.keys())
+ # for key in result:
+ # xtdata.download_history_data(stock_code=key, period="1d", start_time="", end_time="", incrementally="")
+ datas = xtdata.get_local_data(field_list=[], stock_list=result, period='1d', start_time='', end_time='',
+ count=-1,
+ dividend_type='none', fill_data=True)
+ # datas = xtdata.download_history_data2(stock_list=result, period="1d", start_time="", end_time="",
+ # callback=on_progress)
+ print(datas, "这里是返回的数据")
+ except Exception as e:
+ print(f"处理数据发生错误: {str(e)}")
+
+
+async def init_indicator():
+ await init_tortoise()
+ # 从数据库中获取股票列表
+ stock_list = await wance_data_stock.WanceDataStock.filter(stock_type__contains=["stock"]).all()
+ stock_zh_a_spot_em_df = ak.stock_zh_a_spot_em()
+
+ # 遍历股票列表拿到股票实体
+ for stock in stock_list:
+ try:
+ # 使用 akshare 获取股票指标数据
+ stock_code = stock.stock_code # 假设 stock_code 是股票代码的字段
+ stock_code_suffix = split_stock_code(stock_code) # 提取股票代码部分
+ stock_code_xq = join_stock_code(stock_code_suffix)
+ stock_code_front = stock_code_suffix[0]
+ # 筛选匹配的行
+ match = stock_zh_a_spot_em_df.loc[stock_zh_a_spot_em_df['代码'] == stock_code_front, '涨跌幅']
+
+ # 从akshare中获取数据
+ stock_a_indicator_lg_df = ak.stock_a_indicator_lg(symbol=stock_code_front) # 乐咕乐股-A 股个股指标: 市盈率, 市净率, 股息率
+ stock_financial_abstract_ths_df = ak.stock_financial_abstract_ths(symbol=stock_code_front,
+ indicator="按报告期") # 同花顺-财务指标-主要指标
+ stock_financial_abstract_df = ak.stock_financial_abstract(symbol=stock_code_front) # 新浪财经-财务报表-关键指标
+ stock_individual_spot_xq_df = ak.stock_individual_spot_xq(symbol=stock_code_xq) # 雪球-行情中心-个股
+ stock_zh_valuation_baidu_df = ak.stock_zh_valuation_baidu(symbol=stock_code_front, indicator="市现率",
+ period="近一年") # 百度股市通-A 股-财务报表-估值数据
+ stock_fhps_detail_em_df = ak.stock_fhps_detail_em(symbol=stock_code_front) # 东方财富网-数据中心-分红送配-分红送配详情
+
+ # 查询数据库中是否已有该股票的数据
+ existing_record = await wance_data_stock.WanceDataStock.filter(stock_code=stock_code).first()
+
+ if existing_record is None:
+ print(f"未找到股票记录: {stock_code}")
+ continue
+
+ # 处理并插入数据到数据库
+ last_indicator_row = stock_a_indicator_lg_df.iloc[-1]
+ last_abstract_row = stock_financial_abstract_ths_df.iloc[0]
+ last_financial_row = stock_financial_abstract_df.iloc[:, 2]
+ last_spot_xq_row = stock_individual_spot_xq_df.iloc[:, 1]
+ last_baidu_df_row = stock_zh_valuation_baidu_df.iloc[-1]
+ last_detail_em_row = stock_fhps_detail_em_df.iloc[-1]
+
+ # 更新字段的对应数据
+ # 每股指标模块
+ existing_record.financial_dividend = safe_get_value(last_detail_em_row.get('现金分红-现金分红比例'))
+ existing_record.financial_ex_gratia = safe_get_value(last_spot_xq_row[8])
+ existing_record.financial_cash_flow = safe_get_value(last_financial_row[10])
+ existing_record.financial_asset_value = safe_get_value(last_financial_row[9])
+ existing_record.financial_reserve_per = safe_get_value(last_abstract_row.get('每股资本公积金'))
+ existing_record.financial_undistributed_profit = safe_get_value(last_abstract_row.get('每股未分配利润'))
+
+ # 盈利能力模块
+ existing_record.profit_asset_value = safe_get_value(last_financial_row.iloc[33])
+ existing_record.profit_sale_ratio = safe_get_value(last_financial_row.iloc[43])
+ existing_record.profit_gross_rate = safe_get_value(last_financial_row.iloc[44])
+ existing_record.profit_business_increase = safe_get_value(last_financial_row.iloc[53])
+ existing_record.profit_dividend_rate = safe_get_value(last_spot_xq_row.iloc[26])
+
+ # 成长能力
+ existing_record.growth_Income_rate = percent_to_float(
+ safe_get_value(last_abstract_row.get('营业总收入同比增长率', 0.0)))
+ existing_record.growth_growth_rate = safe_get_value(last_financial_row.iloc[46])
+ existing_record.growth_nonnet_profit = percent_to_float(
+ safe_get_value(last_abstract_row.get('扣非净利润同比增长率', 0.0)))
+ existing_record.growth_attributable_rate = safe_get_value(last_financial_row.iloc[54])
+
+ # 估值指标
+ existing_record.valuation_PEGTTM_ratio = safe_get_value(last_indicator_row.get('pe_ttm'))
+ existing_record.valuation_PEG_percentile = safe_get_value(last_indicator_row.get('pe'))
+ existing_record.valuation_PB_TTM = safe_get_value(last_indicator_row.get('ps'))
+ existing_record.valuation_PB_percentile = safe_get_value(last_indicator_row.get('pb'))
+ existing_record.valuation_PTS_TTM = safe_get_value(last_indicator_row.get('dv_ratio'))
+ existing_record.valuation_PTS_percentile = safe_get_value(last_indicator_row.get('ps_ttm'))
+ existing_record.valuation_market_TTM = safe_get_value(last_baidu_df_row[-1])
+ existing_record.valuation_market_percentile = safe_get_value(1 / last_baidu_df_row[-1] * 100) if \
+ last_baidu_df_row[-1] != 0 else 0
+
+ # 行情指标
+ existing_record.market_indicator = safe_get_value(match.values[0]) if not match.empty else 0
+
+ # 保存更改
+ await existing_record.save()
+
+ print(f"更新股票指标成功!: {stock_code}")
+
+ except Exception as e:
+ print(f"处理股票 {stock_code} 时发生错误: {e}")
+ continue # 继续处理下一个股票
+
+
+async def init_stock_pool(incremental: bool = False):
+ """
+ 初始化股票池参数,包括股票名、股票代码、股票上市时间、股票板块等信息。
+ @param incremental: 是否执行增量下载
+ @type incremental: bool
+ """
+ await init_tortoise()
+
+ # 获取所有现有股票代码
+ existing_stocks = set()
+ if incremental:
+ existing_stocks = {stock.stock_code for stock in await wance_data_stock.WanceDataStock.all()}
+
+ # 初始化股票池
+ tick_result = xtdata.get_full_tick(['SH', 'SZ'])
+ stocks_to_create = [] # 使用一个列表批量创建
+ for key in tick_result.keys():
+ if incremental and key in existing_stocks:
+ print(f"股票代码 {key} 已经存在,跳过...\n")
+ continue
+
+ detail_result = xtdata.get_instrument_detail(key, False)
+ InstrumentName_result = detail_result.get("InstrumentName")
+ start_time = detail_result.get("OpenDate") or detail_result.get("CreateDate")
+ end_time = datetime.now().strftime('%Y%m%d')
+ time_expire = detail_result.get("ExpireDate")
+ type_dict = xtdata.get_instrument_type(key)
+ type_list = []
+ if type_dict:
+ for i in type_dict.keys():
+ type_list.append(i)
+
+ # 只在 type_list 包含 "stock" 时继续执行
+ if "stock" in type_list:
+ # 检查股票是否已存在
+ existing_stock = await wance_data_stock.WanceDataStock.filter(stock_code=key).first()
+ if not existing_stock:
+ stocks_to_create.append(
+ wance_data_stock.WanceDataStock(
+ stock_code=key,
+ stock_name=InstrumentName_result,
+ stock_sector=[], # 初始化为空列表,后续会加入板块
+ stock_type=type_list,
+ time_start=start_time,
+ time_end=end_time,
+ time_expire=time_expire,
+ market_sector=key.rsplit('.', 1)[-1]
+ )
+ )
+ print(f"加载成功 股票名称 {InstrumentName_result} 股票代码 {key} 股票类型 {type_list} \n")
+ else:
+ print(f"股票代码 {key} 已经存在,跳过...\n")
+ else:
+ print(f"跳过非股票类型:{key} 类型:{type_list} \n")
+
+ # 如果有新的股票,批量创建所有新股票记录
+ if stocks_to_create:
+ bulk_db_result = await wance_data_stock.WanceDataStock.bulk_create(stocks_to_create)
+ print(bulk_db_result, "股票池创建完成 \n")
+ else:
+ print("没有新的股票需要创建 \n")
+
+ # 获取并更新sector模块
+ sector_list = xtdata.get_sector_list()
+
+ for sector in sector_list:
+ # 使用模糊匹配找到最佳的中文板块匹配
+ best_match = get_best_match(sector, translation_dict)
+ if best_match:
+ translated_sector = translation_dict[best_match] # 获取对应的英文名称
+ else:
+ print(f"没有找到合适的板块匹配:{sector} \n")
+ continue # 如果没有找到匹配,跳过该板块
+
+ # 获取板块对应的股票列表
+ sector_stock = xtdata.get_stock_list_in_sector(sector)
+ # 获取所有相关股票
+ stocks_to_update = await wance_data_stock.WanceDataStock.filter(stock_code__in=sector_stock)
+
+ # 遍历并更新每个股票的sector,避免重复添加相同的英文板块
+ for stock in stocks_to_update:
+ if translated_sector not in stock.stock_sector: # 检查是否已经存在该英文板块
+ stock.stock_sector.append(translated_sector)
+ await stock.save() # 保存更新后的数据
+ else:
+ print(f"{stock.stock_code} 已经包含板块 {translated_sector}, 跳过重复添加 \n")
+
+ print(f"更新板块完成 {sector}: {translated_sector} \n")
+
+ print(f"所有股票已经加载完成 \n")
+
+
+ """
+ # 初始化股票池参数 股票名、股票代码、股票上市时间、股票板块、股票最后回测时间、股票退市时间、股票所属市场
+ @param incremental:是否执行增量下载
+ @type incremental:bool
+ @return:
+ @rtype:
+ """
+
+
+"""
+async def init_stock_pool(incremental: bool = False):
+ # 这行代码存在一个问题就是,存入的板块信息,是中文的,在查询的时候应为是中文的数据所以没有办法被选中
+
+ await init_tortoise()
+
+ # 获取所有现有股票代码
+ existing_stocks = set()
+ if incremental:
+ existing_stocks = {stock.stock_code for stock in await wance_data_stock.WanceDataStock.all()}
+
+ # 初始化股票池
+ tick_result = xtdata.get_full_tick(['SH', 'SZ'])
+ stocks_to_create = [] # 使用一个列表批量创建
+ for key in tick_result.keys():
+ # 如果是增量更新,且该股票已经存在,则跳过
+ if incremental and key in existing_stocks:
+ continue
+
+ detail_result = xtdata.get_instrument_detail(key, False)
+ InstrumentName_result = detail_result.get("InstrumentName")
+ start_time = detail_result.get("OpenDate") or detail_result.get("CreateDate")
+ end_time = datetime.now().strftime('%Y%m%d')
+ time_expire = detail_result.get("ExporeDate")
+ type_dict = xtdata.get_instrument_type(key)
+ type_list = []
+ if type_dict:
+ for i in type_dict.keys():
+ type_list.append(i)
+
+ stocks_to_create.append(
+ wance_data_stock.WanceDataStock(
+ stock_code=key,
+ stock_name=InstrumentName_result,
+ stock_sector=[],
+ stock_type=type_list,
+ time_start=start_time,
+ time_end=end_time,
+ time_expire=time_expire,
+ market_sector=key.rsplit('.', 1)[-1]
+ )
+ )
+ print(f"加载成功 股票名称 {InstrumentName_result} 股票代码 {key} 股票类型 {type_list} \n")
+
+ # 如果有新的股票,批量创建所有新股票记录
+ if stocks_to_create:
+ bulk_db_result = await wance_data_stock.WanceDataStock.bulk_create(stocks_to_create)
+ print(bulk_db_result, "股票池创建完成 \n")
+ else:
+ print("没有新的股票需要创建 \n")
+
+ # 获取并更新sector模块
+ sector_list = xtdata.get_sector_list()
+ if incremental:
+ # 获取已经更新过的sector
+ updated_sectors = set()
+ # 获取所有已经存在的股票及其相关板块
+ existing_stock_sectors = await wance_data_stock.WanceDataStock.all().values('stock_code', 'stock_sector')
+ for stock_info in existing_stock_sectors:
+ for sector in stock_info['stock_sector']:
+ updated_sectors.add(sector)
+
+ # 过滤出没有更新过的sector
+ sector_list = [sector for sector in sector_list if sector not in updated_sectors]
+
+ for sector in sector_list:
+ sector_stock = xtdata.get_stock_list_in_sector(sector)
+ # 获取所有相关股票
+ stocks_to_update = await wance_data_stock.WanceDataStock.filter(stock_code__in=sector_stock)
+
+ # 遍历并更新每个股票的sector,避免重复添加
+ for stock in stocks_to_update:
+ if sector not in stock.stock_sector:
+ stock.stock_sector.append(sector)
+ await stock.save() # 保存更新后的数据
+ print(f"更新板块完成 {stock.stock_code}: {stock.stock_sector} \n")
+
+ print(f"所有股票已经加载完成 \n")
+"""
+
+"""
+# 每次都会操作数据库
+async def init_stock_pool():
+ await init_tortoise()
+
+ # 初始化股票池
+ tick_result = xtdata.get_full_tick(['SH', 'SZ'])
+ tick_keys = list(tick_result.keys())
+ for key in tick_keys:
+ detail_result = xtdata.get_instrument_detail(key, False)
+ InstrumentName_result = detail_result.get("InstrumentName")
+ type_result = xtdata.get_instrument_type(key)
+ # 创建新的股票记录
+ await wance_data_stock.WanceDataStock.create(
+ stock_code=key,
+ stock_name=InstrumentName_result,
+ stock_sector="",
+ stock_type=type_result
+ )
+
+ # 获取并更新sector模块
+ sector_list = xtdata.get_sector_list()
+ for sector in sector_list:
+ sector_stock = xtdata.get_stock_list_in_sector(sector)
+ for stock in sector_stock:
+ # 更新已有股票的sector
+ await wance_data_stock.WanceDataStock.filter(stock_code=stock).update(stock_sector=sector)
+"""
+
+if __name__ == '__main__':
+ # processing_data(field_list=[],
+ # stock_list=["000062.SZ","600611.SH"],
+ # period="1d",
+ # start_time="20240506",
+ # end_time="",
+ # count=-1,
+ # dividend_type="none",
+ # fill_data=False,
+ # )
+ # history_data_processing()
+ asyncio.run(init_stock_pool(False))
+ asyncio.run(init_indicator())
+# xtdata.run()
diff --git a/src/data_processing/response_factor.py b/src/data_processing/response_factor.py
new file mode 100644
index 0000000..6f6f000
--- /dev/null
+++ b/src/data_processing/response_factor.py
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/exceptions.py b/src/exceptions.py
new file mode 100644
index 0000000..0e45d3f
--- /dev/null
+++ b/src/exceptions.py
@@ -0,0 +1,92 @@
+from typing import Any
+
+from fastapi import FastAPI, HTTPException, Request, status
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+from src.utils.helpers import first
+
+
+def register_exception_handler(app: FastAPI):
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ """请求参数验证错误处理"""
+ error = first(exc.errors())
+ field = error.get('loc')[1] or ''
+ return JSONResponse(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ content=jsonable_encoder({
+ "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
+ "message": "{} param {}".format(field, error.get('msg')),
+ "detail": exc.errors()}),
+ )
+
+ @app.exception_handler(StarletteHTTPException)
+ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
+ """Http 异常处理"""
+ return JSONResponse(
+ status_code=exc.status_code,
+ content=jsonable_encoder({
+ "status_code": exc.status_code,
+ "message": exc.detail}),
+ )
+
+ @app.exception_handler(Exception)
+ async def exception_callback(request: Request, exc: Exception):
+
+ """其他异常处理,遇到其他问题再自定义异常"""
+ return JSONResponse(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ content=jsonable_encoder({
+ "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
+ "message": "Internal Server Error",
+ # "detail": ''.join(exc.args)
+ }),
+ )
+
+
+class CommonHttpException(HTTPException):
+ def __init__(self, detail, status_code, **kwargs: dict[str, Any]) -> None:
+ super().__init__(status_code=status_code, detail=detail, **kwargs)
+
+
+class DetailedHTTPException(HTTPException):
+ STATUS_CODE = status.HTTP_500_INTERNAL_SERVER_ERROR
+ DETAIL = "Server error"
+
+ def __init__(self, **kwargs: dict[str, Any]) -> None:
+ super().__init__(status_code=self.STATUS_CODE, detail=self.DETAIL, **kwargs)
+
+
+class PermissionDenied(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_403_FORBIDDEN
+ DETAIL = "Permission denied"
+
+
+class NotFound(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_404_NOT_FOUND
+
+
+class BadRequest(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_400_BAD_REQUEST
+ DETAIL = "Bad Request"
+
+
+class UnprocessableEntity(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_422_UNPROCESSABLE_ENTITY
+ DETAIL = "Unprocessable entity"
+
+
+class NotAuthenticated(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_401_UNAUTHORIZED
+ DETAIL = "User not authenticated"
+
+
+class WxResponseError(DetailedHTTPException):
+ STATUS_CODE = status.HTTP_400_BAD_REQUEST
+ DETAIL = "请求微信异常"
+
+ def __init__(self) -> None:
+ super().__init__(headers={"WWW-Authenticate": "Bearer"})
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..9f9cac5
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,52 @@
+import sentry_sdk
+import uvicorn
+from fastapi import FastAPI
+from starlette.middleware.cors import CORSMiddleware
+
+from src.exceptions import register_exception_handler
+from src.tortoises import register_tortoise_orm
+from src.xtdata.router import router as xtdata_router
+from src.backtest.router import router as backtest_router
+
+from xtquant import xtdata
+from src.settings.config import app_configs, settings
+
+
+
+
+app = FastAPI(**app_configs)
+
+register_tortoise_orm(app)
+
+register_exception_handler(app)
+
+app.include_router(xtdata_router, prefix="/getwancedata", tags=["盘口数据"])
+app.include_router(backtest_router, prefix="/backtest", tags=["盘口数据"])
+
+
+if settings.ENVIRONMENT.is_deployed:
+ sentry_sdk.init(
+ dsn=settings.SENTRY_DSN,
+ environment=settings.ENVIRONMENT,
+ )
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=settings.CORS_ORIGINS,
+ allow_origin_regex=settings.CORS_ORIGINS_REGEX,
+ allow_credentials=True,
+ allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
+ allow_headers=settings.CORS_HEADERS,
+)
+
+
+
+
+@app.get("/")
+async def root():
+ return {"message": "Hello, FastAPI!"}
+
+
+
+if __name__ == "__main__":
+ uvicorn.run('src.main:app', host="0.0.0.0", port=8011, reload=True)
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..508eccf
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,292 @@
+from datetime import datetime
+from decimal import Decimal
+from enum import Enum, IntEnum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from src.settings.config import settings
+
+
+def with_table_name(table_name: str):
+ return f"{settings.DATABASE_PREFIX}{table_name}"
+
+
+class Decimals:
+ MAX_DIGITS = 15
+ DECIMAL_PLACES = 3
+ DECIMAL_PLACES_MORE = 6
+
+
+class SecurityPriceCode(Enum):
+ TRANSFER_FEE = "fee_1" # 过户费
+
+ STAMP_TAX = "fee_2" # 印花税
+
+ TOTAL_COMMISSION = "csf_fee" # 总证券佣金
+
+
+class UserSecurityExtraReminded(BaseModel):
+ """
+ reminded_target_buy_price: 最近提醒目标买入价
+ reminded_target_buy_time: 目标买入价最近提醒时间
+ reminded_buy_price: 最近提醒买入价
+ reminded_buy_fluctuation: 最近提醒买入波动
+ reminded_buy_time: 买入最近提醒时间
+ reminded_sell_price: 最近提醒目标卖出价
+ reminded_sell_fluctuation: 最近提醒卖出波动
+ reminded_sell_time: 卖出最近提醒时间
+ buy_up_first_interval: 买入上涨初次提醒波动间隔
+ buy_up_continue_interval: 买入继续上涨提醒波动间隔
+ buy_down_first_interval: 买入下跌初次提醒波动间隔
+ buy_down_continue_interval: 买入继续下跌提醒波动间隔
+ sell_down_first_interval: 卖出下跌初次提醒波动间隔
+ sell_down_continue_interval: 卖出继续下跌提醒波动间隔
+ browse_msg_alarm_flag: 浏览器消息提醒通知设置
+ wx_msg_alarm_flag: 微信公众号消息提醒通知设置
+ """
+ reminded_target_buy_price: int = 0
+ reminded_target_buy_time: Optional[datetime] = None
+ reminded_buy_price: Decimal = 0.000
+ reminded_buy_fluctuation: int = 0
+ reminded_buy_time: Optional[datetime] = None
+ reminded_sell_price: Decimal = 0.000
+ reminded_sell_fluctuation: int = 0
+ reminded_sell_time: Optional[datetime] = None
+ buy_up_first_interval: int = 1
+ buy_up_continue_interval: int = 1
+ buy_down_first_interval: int = -1
+ buy_down_continue_interval: int = -1
+ sell_down_first_interval: int = -1
+ sell_down_continue_interval: int = -1
+ browse_msg_alarm_flag: bool = True
+ wx_msg_alarm_flag: bool = True
+
+ class Config:
+ json_encoders = {Decimal: str}
+
+
+class Status(IntEnum):
+ """
+ 状态
+ 0-正常
+ 1-删除
+ 2-禁用
+ 3-退市
+ """
+ # 正常
+ NORMAL = 0
+ STOP = 1
+ DISABLED = 2
+ DELISTED = 3
+
+
+class StockType(str, Enum):
+ """
+ 股票类型
+ """
+ SH = "sh"
+ SZ = "sz"
+
+
+class UserSecurityAccountType(IntEnum):
+ """
+ 用户中国证券账户账户类型
+ """
+
+ NORMAL = 0 # 普通账户
+
+ MARGIN = 1 # 融资账户
+
+
+class ExchangeType(str, Enum):
+ """
+ 交易类型
+ """
+ SH = "SH"
+ SZ = "SZ"
+ NQ = "NQ"
+ BJ = "BJ"
+
+
+class SellReferType(IntEnum):
+ """
+ 参考卖出类型
+ highest-最高卖出价
+ latest-最近卖出价
+ """
+ HIGHEST = 0 # 最高卖出价
+ LATEST = 1 # 最近卖出价
+
+
+class InvestmentSource(IntEnum):
+ """
+ 投资资金来源
+ """
+ CASH = 0 # 现金
+ FINANCING = 1 # 场内融资
+ LOANED = 2 # 场外借贷
+
+
+class SecuritySource(IntEnum):
+ """
+ 证券获得方式
+ """
+ BUY = 0 # 正常买入
+ BONUS_SHARES = 1 # 送股
+ CONVERT_SHARES = 2 # 转增股
+
+
+class StatusSell(IntEnum):
+ """
+ 售出状态
+ """
+
+ UNSOLD = 0 # 未售出
+ PARTIAL_SOLD = 1 # 部分售出
+ ALL_SOLD = 2 # 全部售出
+
+
+class BuyAndSellType(IntEnum):
+ """
+ 卖出卖出记录类型
+ """
+ COMMISSION = 0 # 委托单
+ CONDITION = 1 # 条件单
+ CONTRACT = 2 # 成交单
+
+
+class SecurityType(IntEnum):
+ """
+ 证券类型
+ """
+
+ security_stock_a = 0 # A 股股票
+
+ security_stock_b = 1 # B 股股票
+
+ security_etf = 2 # 场内基金-ETF
+
+ security_lof = 3 # 场内基金-LOF
+
+
+class PushConfigModel(BaseModel):
+ """
+ 驼峰转下线
+ """
+ device_id: int = None
+ device_name: str = None
+ eable_flag: str = None
+ update_at: Optional[datetime] = None
+
+
+class CashFlowType(IntEnum):
+ INTO = 0 # 转入
+ OUT = 1 # 转出
+
+
+class AdminStatus(str, Enum):
+ """
+ 管理员账户状态
+ """
+
+ DELETE = "delete" # 删除
+ NORMAL = "normal" # 正常
+ FORBIDDEN = "forbidden" # 禁用
+ UNAUTHORIZED = "unauthorized" # 未验证
+
+
+class CustomerSecurityAccountSecurityAdminStatus(str, Enum):
+ """
+ 客户的证券账户中的券商状态
+ """
+ JOIN_GROUP = "join_group" # 新入群
+ ACCOUNT_READY_NOT_DEPOSIT = "account_ready_not_deposit" # 已开户但未入金
+ ACCOUNT_SUBSTANDARD = "account_substandard" # 非有效户(20个交易日里平均市值1万)
+ ACCOUNT_UNSETTLED = "account_unsettled" # 待结算
+ ACCOUNT_SETTLED = "account_settled" # 已结算
+
+
+class CustomerSecurityAccountVipStatus(str, Enum):
+ """
+ 客户的证券账户中的vip状态
+ """
+ REWARDED = "rewarded" # 已奖励
+ UNREWARD = "unreward" # 未奖励
+
+
+class UserBenefitStatus(IntEnum):
+ """
+ 用户福利领取状态
+ """
+ UNRECEIVED = 0 # 未领取
+ RECEIVED = 1 # 已领取
+
+
+class SecurityAdminTransferStatus(IntEnum):
+ """
+ 券商转账的状态
+ """
+ PAYED = 1 # 已支付
+ CONFIRMED = 2 # 已确认
+ DISAGREED = 3 # 有异议
+
+
+class YuanMaTalkQuestionPublicStatus(IntEnum):
+ """
+ 问题是否公开
+ """
+ PUBLIC = 1 # 公开
+ PRIVATE = 0 # 私密
+
+
+class YuanMaTalkQuestionStatus(str, Enum):
+ """
+ 问题状态
+ """
+ NEW = "new" # 新问题
+ REPLIED = "replied" # 已回复
+ WAITING = "waiting" # 待回复
+ CLOSED = "closed" # 关闭
+ DELETED = "deleted" # 删除
+ FORBIDDEN = "forbidden" # 禁止
+
+
+class LastReplySourceType(str, Enum):
+ """
+ 问题状态
+ """
+ ADMIN = "admin" # 管理后台
+ USER = "user" # 用户端
+
+
+class AccountType(IntEnum):
+ """
+ 账户类型
+ """
+ LIANGRONG = 1 # 两融账户
+ ORDINARY = 0 # 普通用户
+
+
+class supervisionFee(IntEnum):
+ """
+ 深市是否申请监管费
+ """
+ NOT_CHARGE = 1 # 不收
+ CHARGE = 0 # 收
+
+
+class transferFee(IntEnum):
+ """
+ 账户类型
+ """
+ NOT_CHARGE = 1 # 不收
+ CHARGE = 0 # 收
+
+
+class IsType(IntEnum):
+ """
+ 是否默认
+ """
+ YES = 1 # 是
+ NO = 0 # 否
diff --git a/src/models/__pycache__/__init__.cpython-311.pyc b/src/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..c0f0082
Binary files /dev/null and b/src/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/models/__pycache__/back_observed_data.cpython-311.pyc b/src/models/__pycache__/back_observed_data.cpython-311.pyc
new file mode 100644
index 0000000..eaefe67
Binary files /dev/null and b/src/models/__pycache__/back_observed_data.cpython-311.pyc differ
diff --git a/src/models/__pycache__/back_observed_data_detail.cpython-311.pyc b/src/models/__pycache__/back_observed_data_detail.cpython-311.pyc
new file mode 100644
index 0000000..2b335e8
Binary files /dev/null and b/src/models/__pycache__/back_observed_data_detail.cpython-311.pyc differ
diff --git a/src/models/__pycache__/back_position.cpython-311.pyc b/src/models/__pycache__/back_position.cpython-311.pyc
new file mode 100644
index 0000000..83fae24
Binary files /dev/null and b/src/models/__pycache__/back_position.cpython-311.pyc differ
diff --git a/src/models/__pycache__/back_result_indicator.cpython-311.pyc b/src/models/__pycache__/back_result_indicator.cpython-311.pyc
new file mode 100644
index 0000000..66b76ba
Binary files /dev/null and b/src/models/__pycache__/back_result_indicator.cpython-311.pyc differ
diff --git a/src/models/__pycache__/back_trand_info.cpython-311.pyc b/src/models/__pycache__/back_trand_info.cpython-311.pyc
new file mode 100644
index 0000000..621c257
Binary files /dev/null and b/src/models/__pycache__/back_trand_info.cpython-311.pyc differ
diff --git a/src/models/__pycache__/backtest.cpython-311.pyc b/src/models/__pycache__/backtest.cpython-311.pyc
new file mode 100644
index 0000000..5ea8bc4
Binary files /dev/null and b/src/models/__pycache__/backtest.cpython-311.pyc differ
diff --git a/src/models/__pycache__/order.cpython-311.pyc b/src/models/__pycache__/order.cpython-311.pyc
new file mode 100644
index 0000000..d6eb0c3
Binary files /dev/null and b/src/models/__pycache__/order.cpython-311.pyc differ
diff --git a/src/models/__pycache__/position.cpython-311.pyc b/src/models/__pycache__/position.cpython-311.pyc
new file mode 100644
index 0000000..75caafd
Binary files /dev/null and b/src/models/__pycache__/position.cpython-311.pyc differ
diff --git a/src/models/__pycache__/security_account.cpython-311.pyc b/src/models/__pycache__/security_account.cpython-311.pyc
new file mode 100644
index 0000000..48bb598
Binary files /dev/null and b/src/models/__pycache__/security_account.cpython-311.pyc differ
diff --git a/src/models/__pycache__/snowball.cpython-311.pyc b/src/models/__pycache__/snowball.cpython-311.pyc
new file mode 100644
index 0000000..540cc20
Binary files /dev/null and b/src/models/__pycache__/snowball.cpython-311.pyc differ
diff --git a/src/models/__pycache__/stock.cpython-311.pyc b/src/models/__pycache__/stock.cpython-311.pyc
new file mode 100644
index 0000000..b70ea63
Binary files /dev/null and b/src/models/__pycache__/stock.cpython-311.pyc differ
diff --git a/src/models/__pycache__/stock_bt_history.cpython-311.pyc b/src/models/__pycache__/stock_bt_history.cpython-311.pyc
new file mode 100644
index 0000000..4044df9
Binary files /dev/null and b/src/models/__pycache__/stock_bt_history.cpython-311.pyc differ
diff --git a/src/models/__pycache__/stock_data_processing.cpython-311.pyc b/src/models/__pycache__/stock_data_processing.cpython-311.pyc
new file mode 100644
index 0000000..1f6b4f5
Binary files /dev/null and b/src/models/__pycache__/stock_data_processing.cpython-311.pyc differ
diff --git a/src/models/__pycache__/stock_details.cpython-311.pyc b/src/models/__pycache__/stock_details.cpython-311.pyc
new file mode 100644
index 0000000..dae2e16
Binary files /dev/null and b/src/models/__pycache__/stock_details.cpython-311.pyc differ
diff --git a/src/models/__pycache__/stock_history.cpython-311.pyc b/src/models/__pycache__/stock_history.cpython-311.pyc
new file mode 100644
index 0000000..dd84a1a
Binary files /dev/null and b/src/models/__pycache__/stock_history.cpython-311.pyc differ
diff --git a/src/models/__pycache__/strategy.cpython-311.pyc b/src/models/__pycache__/strategy.cpython-311.pyc
new file mode 100644
index 0000000..abe28ae
Binary files /dev/null and b/src/models/__pycache__/strategy.cpython-311.pyc differ
diff --git a/src/models/__pycache__/test_table.cpython-311.pyc b/src/models/__pycache__/test_table.cpython-311.pyc
new file mode 100644
index 0000000..b609422
Binary files /dev/null and b/src/models/__pycache__/test_table.cpython-311.pyc differ
diff --git a/src/models/__pycache__/tran_observer_data.cpython-311.pyc b/src/models/__pycache__/tran_observer_data.cpython-311.pyc
new file mode 100644
index 0000000..1eb549f
Binary files /dev/null and b/src/models/__pycache__/tran_observer_data.cpython-311.pyc differ
diff --git a/src/models/__pycache__/tran_orders.cpython-311.pyc b/src/models/__pycache__/tran_orders.cpython-311.pyc
new file mode 100644
index 0000000..f7d5c7a
Binary files /dev/null and b/src/models/__pycache__/tran_orders.cpython-311.pyc differ
diff --git a/src/models/__pycache__/tran_position.cpython-311.pyc b/src/models/__pycache__/tran_position.cpython-311.pyc
new file mode 100644
index 0000000..ecc3158
Binary files /dev/null and b/src/models/__pycache__/tran_position.cpython-311.pyc differ
diff --git a/src/models/__pycache__/tran_return.cpython-311.pyc b/src/models/__pycache__/tran_return.cpython-311.pyc
new file mode 100644
index 0000000..99ac910
Binary files /dev/null and b/src/models/__pycache__/tran_return.cpython-311.pyc differ
diff --git a/src/models/__pycache__/tran_trade_info.cpython-311.pyc b/src/models/__pycache__/tran_trade_info.cpython-311.pyc
new file mode 100644
index 0000000..c92430b
Binary files /dev/null and b/src/models/__pycache__/tran_trade_info.cpython-311.pyc differ
diff --git a/src/models/__pycache__/trand_info.cpython-311.pyc b/src/models/__pycache__/trand_info.cpython-311.pyc
new file mode 100644
index 0000000..7337dfd
Binary files /dev/null and b/src/models/__pycache__/trand_info.cpython-311.pyc differ
diff --git a/src/models/__pycache__/transaction.cpython-311.pyc b/src/models/__pycache__/transaction.cpython-311.pyc
new file mode 100644
index 0000000..6f4971f
Binary files /dev/null and b/src/models/__pycache__/transaction.cpython-311.pyc differ
diff --git a/src/models/__pycache__/wance_data_stock.cpython-311.pyc b/src/models/__pycache__/wance_data_stock.cpython-311.pyc
new file mode 100644
index 0000000..d6231a5
Binary files /dev/null and b/src/models/__pycache__/wance_data_stock.cpython-311.pyc differ
diff --git a/src/models/__pycache__/wance_data_storage_backtest.cpython-311.pyc b/src/models/__pycache__/wance_data_storage_backtest.cpython-311.pyc
new file mode 100644
index 0000000..596712c
Binary files /dev/null and b/src/models/__pycache__/wance_data_storage_backtest.cpython-311.pyc differ
diff --git a/src/models/back_observed_data.py b/src/models/back_observed_data.py
new file mode 100644
index 0000000..ce6662e
--- /dev/null
+++ b/src/models/back_observed_data.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class BackObservedData(Model):
+ id = fields.IntField(pk=True, description='id表', )
+ key = fields.CharField(max_length=255, null=True, description='key', )
+ observed_data = fields.BinaryField(null=True, description='格式化后的json数据', )
+
+
+ class Meta:
+ table = with_table_name("back_observed_data")
\ No newline at end of file
diff --git a/src/models/back_observed_data_detail.py b/src/models/back_observed_data_detail.py
new file mode 100644
index 0000000..3b23197
--- /dev/null
+++ b/src/models/back_observed_data_detail.py
@@ -0,0 +1,12 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+class BackObservedDataDetail(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=255, )
+ back_observed_data = fields.BinaryField(null=True, )
+
+
+ class Meta:
+ table = with_table_name("back_observed_data_detail")
diff --git a/src/models/back_position.py b/src/models/back_position.py
new file mode 100644
index 0000000..4646eaa
--- /dev/null
+++ b/src/models/back_position.py
@@ -0,0 +1,14 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class BackPosition(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=25, null=True, )
+ back_position_data = fields.BinaryField(null=True, )
+
+
+
+ class Meta:
+ table = with_table_name("back_position")
\ No newline at end of file
diff --git a/src/models/back_result_indicator.py b/src/models/back_result_indicator.py
new file mode 100644
index 0000000..ce3b07b
--- /dev/null
+++ b/src/models/back_result_indicator.py
@@ -0,0 +1,14 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+
+class BackResultIndicator(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=25, null=True, )
+ indicator = fields.BinaryField(null=True, )
+
+
+ class Meta:
+ table = with_table_name("back_result_indicator")
+
diff --git a/src/models/back_trand_info.py b/src/models/back_trand_info.py
new file mode 100644
index 0000000..f80f8f9
--- /dev/null
+++ b/src/models/back_trand_info.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+
+class BackTrandInfo(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=25, null=True, )
+ trade_info = fields.BinaryField(null=True, )
+
+
+ class Meta:
+ table = with_table_name("back_trand_info")
\ No newline at end of file
diff --git a/src/models/backtest.py b/src/models/backtest.py
new file mode 100644
index 0000000..6ec1224
--- /dev/null
+++ b/src/models/backtest.py
@@ -0,0 +1,25 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class Backtest(Model):
+ id = fields.IntField(pk=True)
+ strategy: fields.ForeignKeyRelation[Strategy] = fields.ForeignKeyField(
+ model_name="models.Strategy",
+ related_name="integration_strategy",
+ # db_constraint=False,
+ on_delete=fields.CASCADE,
+ index=True,
+ )
+ key = fields.CharField(max_length=20, unique=True, index=True, description="回测key")
+ user_id = fields.IntField(null=True, description="回测用户")
+ backtest_at = fields.DatetimeField(auto_now_add=True, description="回测时间")
+ backtest_code = fields.TextField(description="回测代码")
+ is_running = fields.BooleanField(default=True, description="回测状态")
+ updated_at = fields.DatetimeField(auto_now=True, description="修改时间")
+ deleted_at = fields.DatetimeField(null=True, description="删除时间")
+ is_active = fields.BooleanField(default=True, description="是否可用")
+
+ class Meta:
+ table = with_table_name("Backtest")
diff --git a/src/models/backtest_parameters.py b/src/models/backtest_parameters.py
new file mode 100644
index 0000000..0c1ee0b
--- /dev/null
+++ b/src/models/backtest_parameters.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Field
+from typing import Dict
+
+
+class BenchmarkCodes(BaseModel):
+ """
+ 模型类用于表示基准名称和对应代码的映射。
+ """
+ codes: Dict[str, str] = Field(
+ default_factory=lambda: {
+ "沪深300": "000300", # 使用字符串来表示代码
+ "中证500": "000905"
+ },
+ title="Benchmark Codes",
+ description="A mapping of benchmark names to their respective codes."
+ )
+
+ # 添加一个属性或方法来根据基准名称返回代码
+ async def get_code(self, benchmark_name: str) -> str:
+ if benchmark_name in self.codes:
+ return self.codes[benchmark_name]
+ else:
+ raise ValueError(f"Unknown benchmark name: {benchmark_name}")
diff --git a/src/models/observed_data.py b/src/models/observed_data.py
new file mode 100644
index 0000000..3ba07ae
--- /dev/null
+++ b/src/models/observed_data.py
@@ -0,0 +1,123 @@
+from objectbox.model import Entity, Id, Property, PropertyType
+
+
+@Entity(id=1, uid=1)
+class ObservedData:
+ id = Id(id=1, uid=1001)
+ key = Property(str, id=2, uid=1002)
+ date = Property(str, id=3, uid=1003)
+ time_return = Property(float, type=PropertyType.float, id=4, uid=1004)
+ cash = Property(float, type=PropertyType.float, id=5, uid=1005)
+ benchmark_return = Property(float, type=PropertyType.float, id=6, uid=1006)
+ is_running = Property(bool, id=7, uid=1007)
+ benchmark_ratio = Property(float, type=PropertyType.float, id=8, uid=1008)
+ max_draw_down = Property(float, type=PropertyType.float, id=9, uid=1009)
+ strategy_ratio = Property(float, type=PropertyType.float, id=10, uid=1010)
+
+
+
+@Entity(id=2, uid=2)
+class BackResultIndicator:
+ id = Id(id=5, uid=2001)
+ key = Property(str, id=6, uid=2002)
+ indicator = Property(str, id=7, uid=2003)
+
+
+@Entity(id=3, uid=3)
+class Position:
+ id = Id(id=1, uid=3001)
+ key = Property(str, id=2, uid=3002)
+ date = Property(str, id=3, uid=3003)
+ name = Property(str, id=4, uid=3004)
+ size = Property(int, id=5, uid=3005)
+ price = Property(float, id=6, uid=3006)
+ adjbase = Property(float, id=7, uid=3007)
+ profit_loss = Property(float, id=8, uid=3008)
+
+
+@Entity(id=4, uid=4)
+class TradeInfo:
+ id = Id(id=1, uid=4001)
+ key = Property(str, id=2, uid=4002)
+ stock_code = Property(str, id=3, uid=4003)
+ ordtype = Property(str, id=4, uid=4004)
+ executed_size = Property(int, id=5, uid=4005)
+ executed_price = Property(float, id=6, uid=4006)
+ value = Property(float, id=7, uid=4007)
+ size = Property(int, id=8, uid=4008)
+ price = Property(float, id=9, uid=4009)
+ status = Property(str, id=10, uid=4010)
+ pnl = Property(float, id=11, uid=4011)
+ date = Property(str, id=13, uid=4013)
+ commission = Property(float, id=14, uid=4014)
+
+
+@Entity(id=5, uid=5)
+class TranObserverData:
+ id = Id(id=1, uid=5001)
+ key = Property(str, id=2, uid=5002)
+ date = Property(str, id=3, uid=5003)
+ time_return = Property(float, type=PropertyType.float, id=4, uid=5004)
+ cumulative_return = Property(float, type=PropertyType.float, id=5, uid=5005)
+ trades = Property(float, type=PropertyType.float, id=6, uid=5006)
+ total_revenue = Property(float, type=PropertyType.float, id=7, uid=5007)
+ current_price = Property(float, type=PropertyType.float, id=8, uid=5008)
+ open = Property(float, type=PropertyType.float, id=9, uid=5009)
+ close = Property(float, type=PropertyType.float, id=10, uid=50010)
+ high = Property(float, type=PropertyType.float, id=11, uid=50011)
+ low = Property(float, type=PropertyType.float, id=12, uid=50012)
+ volume = Property(float, type=PropertyType.float, id=13, uid=50013)
+ cash = Property(float, type=PropertyType.float, id=14, uid=50014)
+ annualized_return = Property(float, type=PropertyType.float, id=15, uid=50015)
+
+
+@Entity(id=6, uid=6)
+class TranReturn:
+ id = Id(id=1, uid=6001)
+ key = Property(str, id=2, uid=6002)
+ date = Property(str, id=3, uid=6003)
+ time_return = Property(float, type=PropertyType.float, id=4, uid=6004)
+ benchmark_return = Property(float, type=PropertyType.float, id=5, uid=6005)
+ cumulative_return = Property(float, type=PropertyType.float, id=6, uid=6006)
+ max_draw_down = Property(float, type=PropertyType.float, id=7, uid=6007)
+
+
+@Entity(id=7, uid=7)
+class TranOrders:
+ id = Id(id=1, uid=7001)
+ key = Property(str, id=2, uid=7002)
+ order_return = Property(str, id=3, uid=7003)
+
+
+@Entity(id=8, uid=8)
+class TranTradeInfo:
+ id = Id(id=1, uid=8001)
+ key = Property(str, id=2, uid=8002)
+ stock_code = Property(str, id=3, uid=8003)
+ ordtype = Property(str, id=4, uid=8004)
+ executed_size = Property(int, id=5, uid=8005)
+ executed_price = Property(float, id=6, uid=8006)
+ value = Property(float, id=7, uid=8007)
+ size = Property(int, id=8, uid=8008)
+ price = Property(float, id=9, uid=8009)
+ status = Property(str, id=10, uid=8010)
+ pnl = Property(float, id=11, uid=8011)
+ date = Property(str, id=13, uid=8013)
+ commission = Property(float, id=14, uid=8014)
+ exec_type = Property(str, id=15, uid=8015)
+ trail_stop = Property(float, id=16, uid=8016)
+ trail_price_limit = Property(float, id=17, uid=8017)
+
+
+@Entity(id=9, uid=9)
+class TranData:
+ id = Id(id=1, uid=9001)
+ key = Property(str, id=2, uid=9002)
+ alpha = Property(float,type=PropertyType.float, id=3,uid=9003)
+ beta = Property(float, type=PropertyType.float, id=4,uid=9004)
+ sharpe_ratio = Property(float, id=5,uid=9005)
+ start_date = Property(str,id=6, uid=9006)
+ close_date = Property(str, id=7, uid=9007)
+ use_date = Property(str, id=8, uid=9008)
+ filename = Property(str, id=11, uid=1011)
+ day_alpha_array = Property(str,id=12,uid=1012)
\ No newline at end of file
diff --git a/src/models/order.py b/src/models/order.py
new file mode 100644
index 0000000..a428d5b
--- /dev/null
+++ b/src/models/order.py
@@ -0,0 +1,68 @@
+from tortoise.models import Model
+from tortoise import fields, models
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+from src.models import with_table_name, AccountType
+
+
+class Order(Model):
+ id = fields.IntField(pk=True, description="订单ID")
+ stock_code = fields.CharField(max_length=30, description="股票代码")
+ stock_name = fields.CharField(max_length=30, description="股票名称")
+ limit_price = fields.FloatField(description="限价")
+ order_quantity = fields.IntField(description="委托数量")
+ order_amount = fields.FloatField(description="委托金额")
+ order_type = fields.CharField(max_length=20, description="订单类型") # 例如:买入、卖出
+ position = fields.CharField(max_length=30, description="仓位")
+ user_id = fields.IntField(null=False, description="用户Id")
+ entrust_date = fields.DatetimeField(null=True, description="委托日期")
+ entrust_time = fields.TimeField(null=True)
+
+ class Meta:
+ table = with_table_name("orders")
+
+
+OrderResponse = pydantic_model_creator(
+ Order
+)
+
+
+class Entrust(Model):
+ id = fields.IntField(pk=True, description="主键id")
+ fund_account = fields.IntField(null=True, description="资金账户")
+ securities_alias = fields.CharField(max_length=30, null=True, description="账户别名")
+ account_type = fields.IntEnumField(AccountType, null=True, description="账户类型")
+ stock_code = fields.CharField(max_length=30, description="证券代码")
+ stock_name = fields.CharField(max_length=30, null=False, description="证券名称")
+ limit_price = fields.FloatField(description="委托价")
+ entrust_number = fields.IntField(description="委托数量")
+ deal_price = fields.FloatField(null=True, max_digits=10, description="成交价格")
+ deal_number = fields.IntField(null=True, description="成家数量")
+ order_type = fields.CharField(max_length=30, description="操作") # 买入卖出
+ entrust_date = fields.DateField(null=True, description="委托日期")
+ entrust_money = fields.FloatField(description="委托金额")
+ # Python
+ is_repair = fields.BooleanField(default=False, index=True, description="是否补单")
+
+ entrust_time = fields.TimeField(null=True) # 委托时间
+
+ class Meta:
+ table = with_table_name("Entrust")
+
+
+EntrustResponse = pydantic_model_creator(
+ Entrust
+)
+
+
+class Backtesting(Model):
+ id = fields.IntField(pk=True, description="主键id")
+ key = fields.CharField(max_length=30, description="回测key")
+
+ class Meta:
+ table = with_table_name("backtest")
+
+
+BacktestingResponse = pydantic_model_creator(
+ Backtesting
+)
diff --git a/src/models/page_Info.py b/src/models/page_Info.py
new file mode 100644
index 0000000..35e7293
--- /dev/null
+++ b/src/models/page_Info.py
@@ -0,0 +1,65 @@
+from datetime import datetime
+
+from pydantic import BaseModel, Field
+from typing import Optional, List
+
+from src.pydantics.transaction import TransactionsPydantic, TransactionPydantic
+
+
+class PageInfo(BaseModel):
+ page: Optional[int]
+ pageSize: Optional[int]
+ searchname: Optional[str]
+
+
+
+# class PageResponse(BaseModel):
+# current_page: int
+# page_size: int
+# total_pages: int
+# total_items: int
+# items: List[dict] # 使用 List[dict] 类型
+#
+# # 单个事务的 Pydantic 模型
+# class TransactionPagePydantic(BaseModel):
+# id: int
+# key: str
+# cash: float
+# transaction_name: str
+# transaction_type: Optional[str] # 可以为 None
+# user_id: Optional[int] = None
+# is_running: bool
+# is_deleted: bool
+# process_id: Optional[int] = None
+# bar: Optional[str] = None
+# created_at: datetime
+# updated_at: datetime
+# stopped_at: Optional[str] = None
+# deleted_at: Optional[str] = None
+# strategy_id: int
+#
+# class Config:
+# orm_mode = True
+# from_attributes = True
+#
+# # 分页的 Pydantic 模型
+# class PageData(BaseModel):
+# results: List[TransactionPydantic]
+# total: int
+# page: int
+# pages: int
+# size: int
+# # next: Optional[str] = None
+# # previous: Optional[str] = None
+# # total_pages: int
+#
+# # 包装响应的 Pydantic 模型
+# class EntityResponse(BaseModel):
+# status_code: int
+# message: str
+# data: PageData
+
+
+
+
+
diff --git a/src/models/position.py b/src/models/position.py
new file mode 100644
index 0000000..16991fe
--- /dev/null
+++ b/src/models/position.py
@@ -0,0 +1,36 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class Position(Model):
+ id = fields.IntField(pk=True, description="主键")
+ key = fields.CharField(max_length=30, description="键")
+ date = fields.DateField(description="日期")
+ name = fields.CharField(max_length=30, description="名称")
+ size = fields.IntField(description="数量")
+ price = fields.FloatField(description="价格")
+ adjbase = fields.FloatField(description="复权基数")
+ profit = fields.FloatField(description="利润")
+
+ class Meta:
+ table = with_table_name("position")
+
+
+def create(id,pk,key, date, name, size, price, adjbase, profit):
+ try:
+ position = Position(id=id,pk=pk,key=key, date=date, name=name, size=size, price=price, adjbase=adjbase, profit=profit)
+ position.save()
+ return position
+ except Exception as e:
+ print(f"Position create error: {str(e)}")
+ return None
+
+
+def bulk_create(param):
+ try:
+ Position.bulk_create(param)
+ return "Position bulk create success"
+ except Exception as e:
+ print(f"Position bulk create error: {str(e)}")
+ return None
\ No newline at end of file
diff --git a/src/models/security_account.py b/src/models/security_account.py
new file mode 100644
index 0000000..1ff3711
--- /dev/null
+++ b/src/models/security_account.py
@@ -0,0 +1,31 @@
+from tortoise import Model, fields
+
+from src.models import with_table_name, AccountType, supervisionFee, transferFee, IsType
+
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+
+class SecurityAccount(Model):
+ """
+ 证券账户
+ """
+ id = fields.IntField(pk=True, description="主键")
+ securities_name = fields.CharField(max_length=30, null=False, description="证券公司名字昵称")
+
+ fund_account = fields.BigIntField(null=True, description="资金账户")
+
+
+ account_alias = fields.CharField(max_length=30, null=False, description="账户别名")
+
+ money = fields.FloatField(null=True, description="账户金额")
+ available_money = fields.FloatField(null=True, description="可用金额")
+ available_proportion = fields.FloatField(null=True, description="可用资金占比")
+ freeze = fields.FloatField(null=True, description="冻结金额")
+
+ class Meta:
+ table = with_table_name("security_account")
+
+
+SecurityAccountResponse = pydantic_model_creator(
+ SecurityAccount
+)
diff --git a/src/models/snowball.py b/src/models/snowball.py
new file mode 100644
index 0000000..767819e
--- /dev/null
+++ b/src/models/snowball.py
@@ -0,0 +1,21 @@
+from tortoise import Model, fields
+
+from src.models import with_table_name
+
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+
+class Snowball(Model):
+ """
+ 雪球相关信息
+ """
+ id = fields.IntField(pk=True, description="主键")
+ snowball_token = fields.CharField(max_length=10000,null=True, description="雪球用户的token")
+
+ class Meta:
+ table = with_table_name("snowball")
+
+
+SnowballResponse = pydantic_model_creator(
+ Snowball
+)
diff --git a/src/models/stock.py b/src/models/stock.py
new file mode 100644
index 0000000..74fdeea
--- /dev/null
+++ b/src/models/stock.py
@@ -0,0 +1,23 @@
+from tortoise import Model, fields
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+from src.models import with_table_name, StockType
+
+
+class Stock(Model):
+ """
+ 股票相关信息
+ """
+ id = fields.IntField(pk=True, description="主键")
+ stock_code = fields.CharField(max_length=30, description="股票代码")
+ stock_name = fields.CharField(max_length=30, null=True, description="股票名称")
+ type = fields.CharEnumField(StockType, null=True, description="类型")
+ stock_pinyin = fields.CharField(max_length=30, description="股票拼音")
+
+ class Meta:
+ table = with_table_name("stock")
+
+
+StockResponse = pydantic_model_creator(
+ Stock
+)
diff --git a/src/models/stock_bt_history.py b/src/models/stock_bt_history.py
new file mode 100644
index 0000000..011670e
--- /dev/null
+++ b/src/models/stock_bt_history.py
@@ -0,0 +1,17 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+class StockBtHistory(Model):
+ id = fields.IntField(pk=True, )
+ end_bt_time = fields.CharField(max_length=8, null=True,description='回测最终时间', )
+ bt_stock_code = fields.CharField(max_length=8, null=True,description='回测股票代码', )
+ bt_stock_name = fields.CharField(max_length=10, null=True, description='回测股票名称', )
+ bt_benchmark_code = fields.CharField(max_length=8, null=True,description='股票基准代码', )
+ bt_stock_period = fields.CharField(max_length=10, null=True, description='回测类型', )
+ bt_strategy_name = fields.CharField(max_length=10, null=True, description='回测策略名称', )
+ bt_stock_data = fields.BinaryField(null=True, description='回测股票数据', )
+
+
+
+ class Meta:
+ table = with_table_name("stock_bt_history")
\ No newline at end of file
diff --git a/src/models/stock_data_processing.py b/src/models/stock_data_processing.py
new file mode 100644
index 0000000..bcf16f7
--- /dev/null
+++ b/src/models/stock_data_processing.py
@@ -0,0 +1,19 @@
+from tortoise import Model, fields
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+from src.models import with_table_name, StockType
+
+
+class StockDataProcessing(Model):
+ bt_benchmark_code = fields.CharField(max_length=6, null=True, description='基准代码', )
+ bt_stock_period = fields.CharField(max_length=10, null=True, description='数据类型', )
+ bt_strategy_name = fields.CharField(max_length=10, null=True, description='回测策略名', )
+ id = fields.IntField(pk=True, )
+ processing_data = fields.BinaryField(null=True, description='清洗后的数据', )
+ prosessing_date = fields.CharField(max_length=10, null=True, description='当前回测时间', )
+ stock_name = fields.CharField(max_length=10, null=True, )
+ stocke_code = fields.CharField(max_length=10, null=True, )
+
+
+ class Meta:
+ table = with_table_name("stock_data_processing")
diff --git a/src/models/stock_details.py b/src/models/stock_details.py
new file mode 100644
index 0000000..2f53022
--- /dev/null
+++ b/src/models/stock_details.py
@@ -0,0 +1,25 @@
+from tortoise import Model, fields
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+from src.models import with_table_name, StockType
+
+
+class StockDetails(Model):
+ """
+ 股票相关信息
+ """
+ id = fields.IntField(pk=True, description="主键")
+ stock_code = fields.CharField(max_length=30, description="股票代码")
+ stock_name = fields.CharField(max_length=30, null=True, description="股票名称")
+ type = fields.CharEnumField(StockType, null=True, description="类型")
+ stock_pinyin = fields.CharField(max_length=30, description="股票拼音")
+ latest_price = fields.FloatField(null=True, description="最新价")
+ rise_fall = fields.FloatField(null=True, description="跌涨幅")
+
+ class Meta:
+ table = with_table_name("stock_details")
+
+
+StockDetailsResponse = pydantic_model_creator(
+ StockDetails
+)
diff --git a/src/models/stock_history.py b/src/models/stock_history.py
new file mode 100644
index 0000000..50ba323
--- /dev/null
+++ b/src/models/stock_history.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+class StockHistory(Model):
+ id = fields.IntField(pk=True, )
+ stock_code = fields.IntField(description='股票代码', )
+ stock_name = fields.CharField(max_length=10, null=True, )
+ start_time_to_market = fields.CharField(max_length=10, null=True, description='股票上市时间', )
+ end_bt_time = fields.CharField(max_length=10, null=True, description='最终回测时间', )
+ symbol_data = fields.BinaryField(null=True, description='股票数据', )
+
+
+ class Meta:
+ table = with_table_name("stock_history")
\ No newline at end of file
diff --git a/src/models/strategy.py b/src/models/strategy.py
new file mode 100644
index 0000000..fc01fdd
--- /dev/null
+++ b/src/models/strategy.py
@@ -0,0 +1,19 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+class Strategy(Model):
+ id = fields.IntField(pk=True, description='主键')
+ strategy_name = fields.CharField(max_length=255, description='策略名称')
+ strategy_hash = fields.CharField(max_length=255, description='策略版本号')
+ strategy_type = fields.CharField(max_length=255, null=True, description='策略类型')
+ user_id = fields.IntField(null=True, description='所属用户')
+ backtest_count = fields.IntField(null=True, description='回测次数')
+ backtest_keys = fields.JSONField(null=True, description='回测key列表')
+ is_deleted = fields.BooleanField(max_length=255, default=False, description='是否删除')
+ created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
+ updated_at = fields.DatetimeField(auto_now=True, description="修改时间")
+ deleted_at = fields.DatetimeField(null=True, description="删除时间")
+
+ class Meta:
+ table = with_table_name("Strategy")
diff --git a/src/models/test_table.py b/src/models/test_table.py
new file mode 100644
index 0000000..50caed8
--- /dev/null
+++ b/src/models/test_table.py
@@ -0,0 +1,35 @@
+from tortoise import Model, fields
+from tortoise.contrib.pydantic import pydantic_model_creator
+
+from src.models import with_table_name
+
+
+class TestTable(Model):
+ """
+ 用户
+ """
+ id = fields.IntField(pk=True, description="主键")
+ nickname = fields.CharField(max_length=30, null=True, description="用户昵称")
+ avatar_url = fields.CharField(max_length=255, null=True, description="头像")
+ member_type = fields.IntField(null=True, description="会员类型")
+ beta_account_type = fields.CharField(null=True, max_length=30, description="内测账号类型")
+ pre_cost_time = fields.IntField(null=True, description="预支付时间(单位年)")
+
+ qr_code = fields.CharField(max_length=255, null=True, description="专属客服二维码")
+ dedicated_id = fields.IntField(null=True, description="专属客服id")
+
+ invited_user_id = fields.IntField(null=True, description="邀请人id")
+ created_user_id = fields.IntField(null=True, description="创建人id")
+ is_deleted = fields.BooleanField(default=False, index=True, description="是否删除")
+ login_at = fields.DatetimeField(null=True, description="最后一次登录时间")
+ created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
+ updated_at = fields.DatetimeField(auto_now=True, description="修改时间")
+ deleted_at = fields.DatetimeField(null=True, description="删除时间")
+
+ class Meta:
+ table = with_table_name("users")
+
+
+UserResponse = pydantic_model_creator(
+ TestTable
+)
diff --git a/src/models/tran_observer_data.py b/src/models/tran_observer_data.py
new file mode 100644
index 0000000..01d145e
--- /dev/null
+++ b/src/models/tran_observer_data.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class TranObserverData(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=255, )
+ tran_observer_data = fields.BinaryField(null=True, description='存储大量数据', )
+
+ class Meta:
+ table = with_table_name("tran_observer_data")
+
diff --git a/src/models/tran_orders.py b/src/models/tran_orders.py
new file mode 100644
index 0000000..868ea50
--- /dev/null
+++ b/src/models/tran_orders.py
@@ -0,0 +1,14 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+
+
+class TranOrders(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=255, null=True, )
+ order_return = fields.CharField(max_length=255, null=True, )
+
+class meta:
+ table = with_table_name('tradorders')
\ No newline at end of file
diff --git a/src/models/tran_position.py b/src/models/tran_position.py
new file mode 100644
index 0000000..8e1f6ee
--- /dev/null
+++ b/src/models/tran_position.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+
+class TranPosition(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=25, null=True, )
+ tran_position_data = fields.BinaryField(null=True, )
+
+ class Meta:
+ table = with_table_name("tran_position")
diff --git a/src/models/tran_return.py b/src/models/tran_return.py
new file mode 100644
index 0000000..f89f4a3
--- /dev/null
+++ b/src/models/tran_return.py
@@ -0,0 +1,13 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class TranReturn(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=255, null=True, )
+ tran_return_data = fields.BinaryField(null=True, )
+
+ class Meta:
+ table = with_table_name('tran_return')
+
diff --git a/src/models/tran_trade_info.py b/src/models/tran_trade_info.py
new file mode 100644
index 0000000..dc3e742
--- /dev/null
+++ b/src/models/tran_trade_info.py
@@ -0,0 +1,12 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class TranTradeInfo(Model):
+ id = fields.IntField(pk=True, )
+ key = fields.CharField(max_length=25, null=True, )
+ tran_trade_info = fields.BinaryField(null=True, )
+
+ class Meta:
+ table = with_table_name("tran_trade_info")
diff --git a/src/models/trand_info.py b/src/models/trand_info.py
new file mode 100644
index 0000000..dd836c6
--- /dev/null
+++ b/src/models/trand_info.py
@@ -0,0 +1,30 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+class TrandInfo(Model):
+ id = fields.IntField(pk=True, description='id', )
+ key = fields.CharField(max_length=255, null=True, description='唯一索引', )
+ tran_info_data = fields.BinaryField(null=True, )
+
+
+ class Meta:
+ table = with_table_name("trand_info")
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/models/transaction.py b/src/models/transaction.py
new file mode 100644
index 0000000..4cf6516
--- /dev/null
+++ b/src/models/transaction.py
@@ -0,0 +1,30 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+from src.models.strategy import Strategy
+
+
+class Transaction(Model):
+ id = fields.IntField(pk=True, description='主键')
+ key = fields.CharField(max_length=20, description='key,数据标识')
+ cash = fields.FloatField(description='资金', null=True)
+ transaction_name = fields.CharField(max_length=255, description='交易名称')
+ transaction_type = fields.CharField(max_length=255, null=True, description='交易类型')
+ user_id = fields.IntField(null=True, description='用户id')
+ is_running = fields.BooleanField(max_length=255, default=True, description='运行状态')
+ is_deleted = fields.BooleanField(max_length=255, default=False, description='是否删除')
+ process_id = fields.IntField(null=True, description='进程号')
+ bar = fields.CharField(max_length=10, description='频率K线', null=True)
+ created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
+ updated_at = fields.DatetimeField(auto_now=True, description="修改时间")
+ stopped_at = fields.DatetimeField(null=True, description="停止时间")
+ deleted_at = fields.DatetimeField(null=True, description="删除时间")
+ strategy: fields.ForeignKeyRelation[Strategy] = fields.ForeignKeyField(
+ model_name="models.Strategy",
+ related_name="transaction_strategy",
+ # db_constraint=False,
+ on_delete=fields.CASCADE,
+ index=True,
+ )
+
+ class Meta:
+ table = with_table_name("Transaction")
diff --git a/src/models/transaction_strategy.py b/src/models/transaction_strategy.py
new file mode 100644
index 0000000..f472707
--- /dev/null
+++ b/src/models/transaction_strategy.py
@@ -0,0 +1,19 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+class Strategy(Model):
+ id = fields.IntField(pk=True, description='主键')
+ strategy_name = fields.CharField(max_length=255, description='策略名称')
+ strategy_hash = fields.CharField(max_length=255, description='策略版本号')
+ strategy_type = fields.CharField(max_length=255, null=True, description='策略类型')
+ user_id = fields.IntField(null=True, description='所属用户')
+ backtest_count = fields.IntField(null=True, description='回测次数')
+ backtest_keys = fields.JSONField(null=True, description='回测key列表')
+ is_deleted = fields.BooleanField(max_length=255, default=False, description='是否删除')
+ created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
+ updated_at = fields.DatetimeField(auto_now=True, description="修改时间")
+ deleted_at = fields.DatetimeField(null=True, description="删除时间")
+
+ class Meta:
+ table = with_table_name("strategy")
diff --git a/src/models/user.py b/src/models/user.py
new file mode 100644
index 0000000..97f2195
--- /dev/null
+++ b/src/models/user.py
@@ -0,0 +1,27 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+
+
+class Users(Model):
+ avatar_url = fields.CharField(max_length=255, null=True, )
+ beta_account_type = fields.CharField(max_length=30, null=True, )
+ created_at = fields.DatetimeField(auto_now_add=True, )
+ created_user_id = fields.IntField()
+ dedicated_id = fields.IntField()
+ deleted_at = fields.DatetimeField(null=True, )
+ id = fields.IntField(pk=True, )
+ invitation_code = fields.CharField(max_length=10, null=True, )
+ invited_user_id = fields.IntField()
+ is_deleted = fields.IntField(index=True, )
+ login_at = fields.DatetimeField(null=True, )
+ member_type = fields.IntField()
+ nickname = fields.CharField(max_length=30, null=True, )
+ pre_cost_time = fields.IntField()
+ qr_code = fields.CharField(max_length=255, null=True, )
+ status = fields.BooleanField(null=True, default=False, )
+ updated_at = fields.DatetimeField(auto_now_add=True, )
+
+ class Meta:
+ table = with_table_name("users")
\ No newline at end of file
diff --git a/src/models/wance_data_stock.py b/src/models/wance_data_stock.py
new file mode 100644
index 0000000..17d6e63
--- /dev/null
+++ b/src/models/wance_data_stock.py
@@ -0,0 +1,43 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+
+class WanceDataStock(Model):
+ id = fields.IntField(pk=True, )
+ stock_name = fields.CharField(max_length=50, null=True, description='股票名称', )
+ stock_code = fields.CharField(max_length=50, null=True, description='股票代码', )
+ stock_sector = fields.JSONField(null=True, description='股票板块', )
+ stock_type = fields.JSONField(null=True, description='股票类型', )
+ time_start = fields.CharField(max_length=10, null=True, description='上市时间', )
+ time_expire = fields.CharField(max_length=10, null=True, description='退市时间', )
+ time_end = fields.CharField(max_length=10, null=True, description='上一次回测结束时间', )
+ market_sector = fields.CharField(max_length=20, null=True, description='所属市场', )
+ financial_dividend = fields.FloatField(null=True, default=0, description='分红率', )
+ financial_ex_gratia = fields.FloatField(null=True, default=0, description='扣非后每股收益', )
+ financial_cash_flow = fields.FloatField(null=True, default=0, description='每股现金流', )
+ financial_asset_value = fields.FloatField(null=True, default=0, description='每股净资产', )
+ financial_reserve_per = fields.FloatField(null=True, default=0, description='每股资本公积金', )
+ financial_undistributed_profit = fields.FloatField(null=True, default=0, description='每股未分配利润', )
+ profit_asset_value = fields.FloatField(null=True, default=0, description='净资产收益率', )
+ profit_sale_ratio = fields.FloatField(null=True, default=0, description='盈利销售净利率', )
+ profit_gross_rate = fields.FloatField(null=True, default=0, description='销售毛利率', )
+ profit_business_increase = fields.FloatField(null=True, default=0, description='营业收入增长率', )
+ profit_dividend_rate = fields.FloatField(null=True, default=0, description='股息率', )
+ growth_Income_rate = fields.FloatField(null=True, default=0, description='营业总收入同比增长率', )
+ growth_growth_rate = fields.FloatField(null=True, default=0, description='营业利润同比增长率', )
+ growth_nonnet_profit = fields.FloatField(null=True, default=0, description='扣非净利润增长率', )
+ growth_attributable_rate = fields.FloatField(null=True, default=0, description='归母净利润同比增长率', )
+ valuation_PEGTTM_ratio = fields.FloatField(null=True, default=0, description='市盈率TTM', )
+ valuation_PEG_percentile = fields.FloatField(null=True, default=0, description='市盈率百分位', )
+ valuation_PB_TTM = fields.FloatField(null=True, default=0, description='市净率TTM', )
+ valuation_PB_percentile = fields.FloatField(null=True, default=0, description='市净率百分比', )
+ valuation_PTS_TTM = fields.FloatField(null=True, default=0, description='市销率TTM', )
+ valuation_PTS_percentile = fields.FloatField(null=True, default=0, description='市销率百分位', )
+ valuation_market_TTM = fields.FloatField(null=True, default=0, description='市现率TTM', )
+ valuation_market_percentile = fields.FloatField(null=True, default=0, description='市现率百分比', )
+ market_indicator = fields.FloatField(null=True, default=0, description='行情指标', )
+
+
+
+ class Meta:
+ table = with_table_name("wance_data_stock")
\ No newline at end of file
diff --git a/src/models/wance_data_storage_backtest.py b/src/models/wance_data_storage_backtest.py
new file mode 100644
index 0000000..02914cc
--- /dev/null
+++ b/src/models/wance_data_storage_backtest.py
@@ -0,0 +1,67 @@
+from tortoise import Model, fields
+from src.models import with_table_name
+
+class WanceDataStorageBacktest(Model):
+ avg_down_month = fields.FloatField(null=True, default=0, description='回测期间每个月下跌时的平均损失', )
+ avg_drawdown = fields.FloatField(null=True, default=0, description='平均回撤,表示每次回撤的平均幅度', )
+ avg_drawdown_days = fields.FloatField(null=True, default=0, description='平均回撤持续的天数', )
+ avg_up_month = fields.FloatField(null=True, default=0, description='回测期间每个月上涨时的平均收益', )
+ backtest_end_time = fields.IntField(description='回测结束时间', )
+ backtest_name = fields.CharField(max_length=100, null=True, )
+ best_day = fields.FloatField(null=True, default=0, description='回测期间单日的最大回报', )
+ best_month = fields.FloatField(null=True, default=0, description='回测期间的最佳月份回报', )
+ best_year = fields.FloatField(null=True, default=0, description='回测期间的最佳年度回报(若为空,表示回测不足一年)', )
+ cagr = fields.FloatField(null=True, default=0, description='年化复合增长率,表示投资在整个回测期间的年均增长率', )
+ calmar = fields.FloatField(null=True, default=0, description='Calmar 比率,表示年化收益率与最大回撤的比率,用于衡量风险调整后的收益。Calmar 比率越高,意味着风险调整后的表现越好', )
+ daily_kurt = fields.FloatField(null=True, default=0, description='每日回报的峰度,表示回报分布的尖锐程度', )
+ daily_mean = fields.FloatField(null=True, default=0, description='每日的平均回报率', )
+ daily_price = fields.JSONField(null=True, description='每日的价格', )
+ daily_sharpe = fields.FloatField(null=True, default=0, description='日频率的夏普比率,表示每单位风险所获得的超额收益', )
+ daily_skew = fields.FloatField(null=True, default=0, description='每日回报的偏度,表示回报分布的对称性', )
+ daily_sortino = fields.FloatField(null=True, default=0, description='日频率的 Sortino 比率,衡量投资组合的下行风险,考虑的是负波动率', )
+ daily_vol = fields.FloatField(null=True, default=0, description='每日回报率的标准差(波动率),表示收益的波动性', )
+ data_end_time = fields.CharField(max_length=10, null=True, description='回测数据结束时间', )
+ data_start_time = fields.CharField(max_length=10, null=True, description='回测数据开始时间', )
+ five_year = fields.FloatField(null=True, default=0, description='过去五年的回报', )
+ id = fields.IntField(pk=True, )
+ incep = fields.FloatField(null=True, default=0, description='自策略开始运行以来的年化回报率(如果策略没有完整的一年,则等于 cagr)', )
+ indicator_information = fields.JSONField(null=True, description='指标信息', )
+ indicator_type = fields.CharField(max_length=20, null=True, description='指标类型', )
+ max_drawdown = fields.FloatField(null=True, default=0, description='最大回撤,表示投资组合在回测期间从最高点到最低点的最大亏损百分比', )
+ monthly_kurt = fields.FloatField(null=True, default=0, description='月度回报的峰度', )
+ monthly_mean = fields.FloatField(null=True, default=0, description='月度平均回报率', )
+ monthly_sharpe = fields.FloatField(null=True, default=0, description='月频率的夏普比率', )
+ monthly_skew = fields.FloatField(null=True, default=0, description='月度回报的偏度', )
+ monthly_sortino = fields.FloatField(null=True, default=0, description='月频率的 Sortino 比率', )
+ monthly_vol = fields.FloatField(null=True, default=0, description='月度回报率的标准差(波动率)', )
+ mtd = fields.FloatField(null=True, default=0, description='月内截至目前的回报', )
+ one_year = fields.FloatField(null=True, default=0, description='的回报', )
+ position = fields.JSONField(null=True, description='持仓记录', )
+ price = fields.JSONField(null=True, description='每日的余额', )
+ returns = fields.JSONField(null=True, description='每日的回报率', )
+ rf = fields.FloatField(null=True, default=0, description='无风险收益率,通常用来计算夏普比率等指标,通常为0表示忽略无风险收益', )
+ six_month = fields.FloatField(null=True, default=0, description='过去六个月的回报(如果为空,表示无相关数据', )
+ stock_close_price = fields.JSONField(null=True, description='股票收盘价', )
+ stock_code = fields.CharField(max_length=20, null=True, )
+ strategy_name = fields.CharField(max_length=50, null=True, description='回测的策略名称', )
+ ten_year = fields.IntField(description='过去十年的回报', )
+ three_month = fields.FloatField(null=True, default=0, description='过去三个月的回报', )
+ three_year = fields.FloatField(null=True, default=0, description='过去三年的回报', )
+ total_return = fields.FloatField(null=True, default=0, description='总回报率', )
+ twelve_month_win_perc = fields.FloatField(null=True, default=0, description='过去12个月的胜率,表示在过去12个月中有多少月的回报为正', )
+ win_year_perc = fields.FloatField(null=True, default=0, description='策略年度胜率,表示回测期间策略表现优于无风险收益率的年份百分比', )
+ worst_day = fields.FloatField(null=True, default=0, description='回测期间单日的最差回报', )
+ worst_month = fields.FloatField(null=True, default=0, description='回测期间的最差月份回报', )
+ worst_year = fields.FloatField(null=True, default=0, description='回测期间的最差年度回报(若为空,表示回测不足一年)', )
+ yearly_kurt = fields.FloatField(null=True, default=0, description='年度的峰度', )
+ yearly_mean = fields.FloatField(null=True, default=0, description='年度的平均回报', )
+ yearly_sharpe = fields.FloatField(null=True, default=0, description='年度的夏普比率', )
+ yearly_skew = fields.FloatField(null=True, default=0, description='年度的偏度', )
+ yearly_sortino = fields.FloatField(null=True, default=0, description='年度的Sortino 比率', )
+ yearly_vol = fields.FloatField(null=True, default=0, description='年度的波动率', )
+ ytd = fields.FloatField(null=True, default=0, description='年初至今的回报', )
+
+
+
+ class Meta:
+ table = with_table_name("wance_data_storage_backtest")
\ No newline at end of file
diff --git a/src/pydantic/__pycache__/backtest_request.cpython-311.pyc b/src/pydantic/__pycache__/backtest_request.cpython-311.pyc
new file mode 100644
index 0000000..8b44f58
Binary files /dev/null and b/src/pydantic/__pycache__/backtest_request.cpython-311.pyc differ
diff --git a/src/pydantic/__pycache__/codelistrequest.cpython-311.pyc b/src/pydantic/__pycache__/codelistrequest.cpython-311.pyc
new file mode 100644
index 0000000..bc8a46f
Binary files /dev/null and b/src/pydantic/__pycache__/codelistrequest.cpython-311.pyc differ
diff --git a/src/pydantic/__pycache__/factor_request.cpython-311.pyc b/src/pydantic/__pycache__/factor_request.cpython-311.pyc
new file mode 100644
index 0000000..3718b05
Binary files /dev/null and b/src/pydantic/__pycache__/factor_request.cpython-311.pyc differ
diff --git a/src/pydantic/__pycache__/request_data.cpython-311.pyc b/src/pydantic/__pycache__/request_data.cpython-311.pyc
new file mode 100644
index 0000000..8b512bc
Binary files /dev/null and b/src/pydantic/__pycache__/request_data.cpython-311.pyc differ
diff --git a/src/pydantic/backtest_request.py b/src/pydantic/backtest_request.py
new file mode 100644
index 0000000..905e025
--- /dev/null
+++ b/src/pydantic/backtest_request.py
@@ -0,0 +1,29 @@
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class BackRequest(BaseModel):
+ field_list: List[str] = Field(default_factory=list, description="字段列表,用于指定获取哪些数据字段")
+ stock_list: List[str] = Field(default_factory=list, description="股票列表,用于指定获取哪些股票的数据")
+ stock_code: str = Field(default="000300.SH", description="股票代码,用于指定获取哪些股票的数据")
+ period: str = Field(default='1d', description="数据周期,如 '1d' 表示日线数据")
+ start_time: Optional[str] = Field(default='', description="开始时间,格式为 'YYYY-MM-DD',默认为空字符串")
+ end_time: Optional[str] = Field(default='', description="结束时间,格式为 'YYYY-MM-DD',默认为空字符串")
+ count: int = Field(default=-1, description="数据条数,默认为 -1 表示获取所有数据")
+ dividend_type: str = Field(default='none', description="分红类型,默认值为 'none'")
+ fill_data: bool = Field(default=True, description="是否填充数据,默认为 True")
+ incrementally: bool = Field(default=True, description="是否增量下载")
+ callback: bool = Field(default=True, description="是否开启回调函数")
+ data_dir: str = Field(default="D:\\e海方舟-量化交易版\\userdata_mini\\datadir", description="数据存储路径")
+ sector_name:str = Field(default="沪深指数", description="板块名称")
+ iscomplete: bool = Field(default=False, description="是否获取全部字段")
+ ma_type: str = Field(default='SMA', description="移动平均线类型,如 'SMA' 表示简单移动平均线")
+ short_window: int = Field(default=50, description="短周期线长度")
+ long_window: int = Field(default=200, description="长周期线长度")
+ bollingerMA: int = Field(default=50, description="布林带中的移动平均线周期,决定计算均值的时间窗口长度")
+ std_dev: int = Field(default=200, description="布林带中用于计算上下轨的标准差倍数,影响带宽大小")
+ overbought: int = Field(default=70, description="超买区间的RSI阈值,表示价格处于相对高点,可能面临回调")
+ oversold: int = Field(default=30, description="超卖区间的RSI阈值,表示价格处于相对低点,可能面临反弹")
+ signal_window: int = Field(default=9, description="超卖区间的RSI阈值,表示价格处于相对低点,可能面临反弹")
+
diff --git a/src/pydantic/codelistrequest.py b/src/pydantic/codelistrequest.py
new file mode 100644
index 0000000..2b0de4e
--- /dev/null
+++ b/src/pydantic/codelistrequest.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class CodeListRequest(BaseModel):
+ code_list : list[str]
\ No newline at end of file
diff --git a/src/pydantic/factor_request.py b/src/pydantic/factor_request.py
new file mode 100644
index 0000000..09af0ab
--- /dev/null
+++ b/src/pydantic/factor_request.py
@@ -0,0 +1,140 @@
+from typing import List, Optional, Any, Dict
+
+from pydantic import BaseModel, Field
+
+
+# 定义查询条件模型,覆盖所有字段
+class StockQuery(BaseModel):
+ financial_asset_value: Optional[float] = None
+ financial_cash_flow: Optional[float] = None
+ financial_dividend: Optional[float] = None
+ financial_ex_gratia: Optional[float] = None
+ financial_reserve_per: Optional[float] = None
+ financial_undistributed_profit: Optional[float] = None
+ growth_attributable_rate: Optional[float] = None
+ growth_growth_rate: Optional[float] = None
+ growth_Income_rate: Optional[float] = None
+ growth_nonnet_profit: Optional[float] = None
+ id: Optional[int] = None
+ market_indicator: Optional[float] = None
+ market_sector: Optional[str] = None
+ profit_asset_value: Optional[float] = None
+ profit_business_increase: Optional[float] = None
+ profit_dividend_rate: Optional[float] = None
+ profit_gross_rate: Optional[float] = None
+ profit_sale_ratio: Optional[float] = None
+ stock_code: Optional[str] = None
+ stock_name: Optional[str] = None
+ stock_sector: Optional[List[str]] = None
+ stock_type: Optional[List[str]] = None
+ time_start: Optional[str] = None
+ time_end: Optional[str] = None
+ time_expire: Optional[str] = None
+ valuation_market_percentile: Optional[float] = None
+ valuation_market_TTM: Optional[float] = None
+ valuation_PB_percentile: Optional[float] = None
+ valuation_PB_TTM: Optional[float] = None
+ valuation_PEG_percentile: Optional[float] = None
+ valuation_PEGTTM_ratio: Optional[float] = None
+ valuation_PTS_percentile: Optional[float] = None
+ valuation_PTS_TTM: Optional[float] = None
+ gt: Optional[Dict[str, float]] = None # 大于
+ lt: Optional[Dict[str, float]] = None # 小于
+ gte: Optional[Dict[str, float]] = None # 大于等于
+ lte: Optional[Dict[str, float]] = None # 小于等于
+ between: Optional[Dict[str, Dict[str, float]]] = None # 在某两个值之间
+
+"""
+# 这里是因子pydantic的定义
+
+class FactorRequest(BaseModel):
+ financial_asset_value: Optional[Dict[str, Any]] = Field(None, description='每股净资产')
+ financial_cash_flow: Optional[Dict[str, Any]] = Field(None, description='每股现金流')
+ financial_dividend: Optional[Dict[str, Any]] = Field(None, description='分红率')
+ financial_ex_gratia: Optional[Dict[str, Any]] = Field(None, description='扣非后每股收益')
+ financial_reserve_per: Optional[Dict[str, Any]] = Field(None, description='每股资本公积金')
+ financial_undistributed_profit: Optional[Dict[str, Any]] = Field(None, description='每股未分配利润')
+ growth_attributable_rate: Optional[Dict[str, Any]] = Field(None, description='归母净利润同比增长率')
+ growth_growth_rate: Optional[Dict[str, Any]] = Field(None, description='营业利润同比增长率')
+ growth_Income_rate: Optional[Dict[str, Any]] = Field(None, description='营业总收入同比增长率')
+ growth_nonnet_profit: Optional[Dict[str, Any]] = Field(None, description='扣非净利润增长率')
+ market_indicator: Optional[Dict[str, Any]] = Field(None, description='行情指标')
+ market_sector: Optional[str] = Field(None, description='所属市场')
+ profit_asset_value: Optional[Dict[str, Any]] = Field(None, description='净资产收益率')
+ profit_business_increase: Optional[Dict[str, Any]] = Field(None, description='营业收入增长率')
+ profit_dividend_rate: Optional[Dict[str, Any]] = Field(None, description='股息率')
+ profit_gross_rate: Optional[Dict[str, Any]] = Field(None, description='销售毛利率')
+ profit_sale_ratio: Optional[Dict[str, Any]] = Field(None, description='盈利销售净利率')
+ stock_code: Optional[str] = Field(None, description='股票代码')
+ stock_name: Optional[str] = Field(None, description='股票名称')
+ stock_sector: Optional[List[str]] = Field(None, description='股票板块')
+ stock_type: Optional[dict] = Field(None, description='股票类型')
+ time_end: Optional[str] = Field(None, description='上一次回测结束时间')
+ time_expire: Optional[str] = Field(None, description='退市时间')
+ time_start: Optional[Dict[str, Any]] = Field(None, description='上市时间')
+ valuation_market_percentile: Optional[Dict[str, Any]] = Field(None, description='市现率百分比')
+ valuation_market_TTM: Optional[Dict[str, Any]] = Field(None, description='市现率TTM')
+ valuation_PB_percentile: Optional[Dict[str, Any]] = Field(None, description='市净率百分位')
+ valuation_PB_TTM: Optional[Dict[str, Any]] = Field(None, description='市净率TTM')
+ valuation_PEG_percentile: Optional[Dict[str, Any]] = Field(None, description='市盈率百分位')
+ valuation_PEGTTM_ratio: Optional[Dict[str, Any]] = Field(None, description='市盈率TTM')
+ valuation_PTS_percentile: Optional[Dict[str, Any]] = Field(None, description='市销率百分位')
+ valuation_PTS_TTM: Optional[Dict[str, Any]] = Field(None, description='市销率TTM')
+"""
+
+"""
+使用示例:
+ factor_request = FactorRequest(
+ {
+ "financial_asset_value": 10.7169,
+ "financial_cash_flow": -0.42982,
+ "gt": {
+ "profit_gross_rate": 4.0,
+ "valuation_PB_TTM": 0.8
+ },
+ "lt": {
+ "market_indicator": -0.6
+ },
+ "gte": {
+ "valuation_market_percentile":5.38793
+ },
+ "lte": {
+ "valuation_PEGTTM_ratio": 15.2773
+ },
+ "between": {
+ "financial_dividend": {
+ "min": 1.0,
+ "max": 2.0
+ },
+ "growth_Income_rate": {
+ "min": -0.6,
+ "max": 1.0
+ }
+ },
+ "stock_code": "600051.SH",
+ "stock_name": "宁波联合",
+ "stock_sector": ["Shanghai A-shares", "Shenzhen A-shares"],
+ "stock_type": ["stock"],
+ "time_start": "0",
+ "time_end": "20240911",
+ "valuation_PEG_percentile": 15.2222,
+ "valuation_PTS_TTM": 2.3077
+ }
+ )
+
+ result = await WanceDataStock.dynamic_query(**factor_request.dict(exclude_none=True))
+ print(result)
+
+
+比较大小字段说明:
+ 'gt' 对应 __gt(大于)
+ 'lt' 对应 __lt(小于)
+ 'gte' 对应 __gte(大于等于)
+ 'lte' 对应 __lte(小于等于)
+ 'between' 对应在某个值之间的条件,使用 __gt 和 __lt 组合实现
+
+时间字段说明:
+ time_start={'recent_years': 2} 近2年
+ time_start={'recent_months': 2} 近2月
+
+"""
diff --git a/src/pydantic/request_data.py b/src/pydantic/request_data.py
new file mode 100644
index 0000000..381048c
--- /dev/null
+++ b/src/pydantic/request_data.py
@@ -0,0 +1,21 @@
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class DataRequest(BaseModel):
+ field_list: List[str] = Field(default_factory=list, description="字段列表,用于指定获取哪些数据字段")
+ stock_list: List[str] = Field(default_factory=list, description="股票列表,用于指定获取哪些股票的数据")
+ stock_code: str = Field(default="000300.SH", description="股票代码,用于指定获取哪些股票的数据")
+ period: str = Field(default='1d', description="数据周期,如 '1d' 表示日线数据")
+ start_time: Optional[str] = Field(default='', description="开始时间,格式为 'YYYY-MM-DD',默认为空字符串")
+ end_time: Optional[str] = Field(default='', description="结束时间,格式为 'YYYY-MM-DD',默认为空字符串")
+ count: int = Field(default=-1, description="数据条数,默认为 -1 表示获取所有数据")
+ dividend_type: str = Field(default='none', description="分红类型,默认值为 'none'")
+ fill_data: bool = Field(default=True, description="是否填充数据,默认为 True")
+ incrementally: bool = Field(default=True, description="是否增量下载")
+ callback: bool = Field(default=True, description="是否开启回调函数")
+ data_dir: str = Field(default="D:\\e海方舟-量化交易版\\userdata_mini\\datadir", description="数据存储路径")
+ sector_name:str = Field(default="沪深指数", description="板块名称")
+ iscomplete: bool = Field(default=False, description="是否获取全部字段")
+
diff --git a/src/responses.py b/src/responses.py
new file mode 100644
index 0000000..b607ca6
--- /dev/null
+++ b/src/responses.py
@@ -0,0 +1,68 @@
+import typing
+from typing import Generic, Sequence, TypeVar
+
+from fastapi import status
+from fastapi.encoders import jsonable_encoder
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+
+T = TypeVar("T")
+
+
+class EntityResponse(BaseModel, Generic[T]):
+ """实体数据"""
+ status_code: int = 200
+ message: str = "Success"
+ data: T
+
+
+class ListResponse(BaseModel, Generic[T]):
+ """列表数据"""
+ status_code: int = 200
+ message: str = "Success"
+ data: Sequence[T]
+
+
+# 包装响应的 Pydantic 模型
+# class EntityPageResponse(BaseModel, Generic[T]):
+# status_code: int
+# message: str
+# data: PageData
+
+
+def response_entity_response(data, status_code=200, message="Success") -> EntityResponse:
+ """普通实体类"""
+ return EntityResponse(data=data, status_code=status_code, message=message)
+
+
+# def response_page_response(data, status_code=200, message="Success") -> EntityPageResponse:
+# """普通分页类"""
+# return EntityPageResponse(data=data, status_code=status_code, message=message)
+
+
+def response_list_response(data, status_code=200, message="Success") -> ListResponse:
+ """普通列表数据"""
+ return ListResponse(data=data, status_code=status_code, message=message)
+
+
+def response_success(message: str = 'Success',
+ headers: typing.Optional[typing.Dict[str, str]] = None) -> JSONResponse:
+ """成功返回"""
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ headers=headers,
+ content=jsonable_encoder({
+ "status_code": status.HTTP_200_OK,
+ "message": message,
+ }))
+
+
+def response_unauthorized() -> JSONResponse:
+ """未登录"""
+ return JSONResponse(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ headers={"WWW-Authenticate": "Bearer"},
+ content=jsonable_encoder({
+ "status_code": status.HTTP_401_UNAUTHORIZED,
+ "message": '用户认证失败',
+ }))
diff --git a/src/settings/__init__.py b/src/settings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/settings/__pycache__/__init__.cpython-311.pyc b/src/settings/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..c285bd8
Binary files /dev/null and b/src/settings/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/settings/__pycache__/config.cpython-311.pyc b/src/settings/__pycache__/config.cpython-311.pyc
new file mode 100644
index 0000000..b205db1
Binary files /dev/null and b/src/settings/__pycache__/config.cpython-311.pyc differ
diff --git a/src/settings/config.py b/src/settings/config.py
new file mode 100644
index 0000000..9e04aec
--- /dev/null
+++ b/src/settings/config.py
@@ -0,0 +1,52 @@
+from typing import Any
+
+from pydantic import RedisDsn, model_validator
+from pydantic_settings import BaseSettings
+
+from src.constants import Environment
+
+
+class Config(BaseSettings):
+ DATABASE_PREFIX: str = ""
+ DATABASE_URL: str
+ DATABASE_CREATE_URL: str
+ REDIS_URL: RedisDsn
+
+ SITE_DOMAIN: str = "myapp.com"
+
+ ENVIRONMENT: Environment = Environment.PRODUCTION
+
+ SENTRY_DSN: str | None = None
+
+ CORS_ORIGINS: list[str]
+ CORS_ORIGINS_REGEX: str | None = None
+ CORS_HEADERS: list[str]
+
+ APP_VERSION: str = "1"
+ APP_ROUTER_PREFIX: str = "/api/v1"
+
+ SHOULD_SEND_SMS: bool = False
+
+ @model_validator(mode="after")
+ def validate_sentry_non_local(self) -> "Config":
+ if self.ENVIRONMENT.is_deployed and not self.SENTRY_DSN:
+ raise ValueError("Sentry is not set")
+
+ return self
+
+
+settings = Config()
+
+# fastapi/applications.py
+app_configs: dict[str, Any] = {
+ "title": "Wance QMT API",
+ "root_path": settings.APP_ROUTER_PREFIX,
+ "docs_url": "/api/docs",
+}
+
+# app_configs['debug'] = True
+# if settings.ENVIRONMENT.is_deployed:
+# app_configs["root_path"] = f"/v{settings.APP_VERSION}"
+#
+# if not settings.ENVIRONMENT.is_debug:
+# app_configs["openapi_url"] = None # hide docs
diff --git a/src/tests/__init__.py b/src/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/tests/stock_query.py b/src/tests/stock_query.py
new file mode 100644
index 0000000..49a6b1c
--- /dev/null
+++ b/src/tests/stock_query.py
@@ -0,0 +1,37 @@
+import asyncio
+import json
+
+from src.models import wance_data_stock
+from src.tortoises_orm_config import init_tortoise
+
+
+async def init_test():
+ await init_tortoise()
+
+ await wance_data_stock.WanceDataStock.create(
+ stock_code='600051.SH',
+ stock_name='宁波联合',
+ stock_sector=['stockA'], # 确保这是一个列表格式
+ stock_type=['stock'],
+ time_start='2010-01-01',
+ time_end='20240911',
+ time_expire='2024-09-11',
+ market_sector='SH'
+ )
+ filters = {
+ 'stock_sector__contains': ["上证A股"] # 列表形式用于 JSON 数组匹配
+ }
+ # filters = {
+ # 'stock_sector__contains': ['stockA'] # 列表形式用于 JSON 数组匹配
+ # }
+ # stocks = await wance_data_stock.WanceDataStock.filter(stock_sector__contains=["上证A股"]).all()
+ # stocks = await wance_data_stock.WanceDataStock.filter(stock_sector__contains=["上证A股", "沪深A股", "TGN共同富裕示范区", "TGN旅游概念", "TGN煤炭概念", "TGN物业管理", "TGN跨境电商", "THY1综合", "THY2综合", "THY3综合", "DY1浙江省", "DY2浙江省宁波市", "上证指数", "上证收益", "上证流通", "中型综指", "中证全指", "中证民企", "中证流通", "国证A指", "新综指", "民企成长", "浙企综指", "浙江民企", "综合指数", "A股指数", "GN共同富裕示范区", "GN婚庆", "GN小盘", "GN微盘股", "GN房地产", "GN批发业", "GN旅游", "GN浙江", "GN煤炭", "GN物业管理", "GN环杭州湾大湾区", "GN电力改革", "GN破净股", "GN舟山自贸区", "GN跨境电商", "GN进口博览会", "SW1综合", "SW1综合加权", "SW2综合", "SW2综合加权", "SW3综合", "SW3综合加权", "A股非科创等权", "上海A股等权", "上海主板等权", "上海全市场等权", "上证指数等权", "上证收益等权", "上证流通等权", "中型综指等权", "中证全指等权", "国证A指等权", "新综指等权", "沪深全市场等权", "综合指数等权", "A股指数等权", "CSRC1批发和零售业", "CSRC2批发业"]).all()
+ # stocks = await wance_data_stock.WanceDataStock.filter(stock_sector__contains=["stock"]).all()
+ stocks = await wance_data_stock.WanceDataStock.filter(**filters).all()
+
+
+ print(stocks)
+
+
+if __name__ == '__main__':
+ asyncio.run(init_test())
diff --git a/src/tests/xtquan_data_test.py b/src/tests/xtquan_data_test.py
new file mode 100644
index 0000000..bd0c0fa
--- /dev/null
+++ b/src/tests/xtquan_data_test.py
@@ -0,0 +1,78 @@
+# 用前须知
+
+## xtdata提供和MiniQmt的交互接口,本质是和MiniQmt建立连接,由MiniQmt处理行情数据请求,再把结果回传返回到python层。使用的行情服务器以及能获取到的行情数据和MiniQmt是一致的,要检查数据或者切换连接时直接操作MiniQmt即可。
+
+## 对于数据获取接口,使用时需要先确保MiniQmt已有所需要的数据,如果不足可以通过补充数据接口补充,再调用数据获取接口获取。
+
+## 对于订阅接口,直接设置数据回调,数据到来时会由回调返回。订阅接收到的数据一般会保存下来,同种数据不需要再单独补充。
+
+# 代码讲解
+
+# 从本地python导入xtquant库,如果出现报错则说明安装失败
+from xtquant import xtdata
+import time
+
+# 设定一个标的列表
+code_list = ["000001.SZ"]
+# 设定获取数据的周期
+period = "1d"
+
+# 下载标的行情数据
+if 1:
+ ## 为了方便用户进行数据管理,xtquant的大部分历史数据都是以压缩形式存储在本地的
+ ## 比如行情数据,需要通过download_history_data下载,财务数据需要通过
+ ## 所以在取历史数据之前,我们需要调用数据下载接口,将数据下载到本地
+ for i in code_list:
+ xtdata.download_history_data(i, period=period, incrementally=True) # 增量下载行情数据(开高低收,等等)到本地
+
+ xtdata.download_financial_data(code_list) # 下载财务数据到本地
+ xtdata.download_sector_data() # 下载板块数据到本地
+ # 更多数据的下载方式可以通过数据字典查询
+
+# 读取本地历史行情数据
+history_data = xtdata.get_market_data_ex([], code_list, period=period, count=-1)
+print(history_data)
+print("=" * 20)
+
+# 如果需要盘中的实时行情,需要向服务器进行订阅后才能获取
+# 订阅后,get_market_data函数于get_market_data_ex函数将会自动拼接本地历史行情与服务器实时行情
+
+# 向服务器订阅数据
+for i in code_list:
+ xtdata.subscribe_quote(i, period=period, count=-1) # 设置count = -1来取到当天所有实时行情
+
+# 等待订阅完成
+time.sleep(1)
+
+# 获取订阅后的行情
+kline_data = xtdata.get_market_data_ex([], code_list, period=period)
+print(kline_data)
+
+# 获取订阅后的行情,并以固定间隔进行刷新,预期会循环打印10次
+for i in range(10):
+ # 这边做演示,就用for来循环了,实际使用中可以用while True
+ kline_data = xtdata.get_market_data_ex([], code_list, period=period)
+ print(kline_data)
+ time.sleep(3) # 三秒后再次获取行情
+
+
+# 如果不想用固定间隔触发,可以以用订阅后的回调来执行
+# 这种模式下当订阅的callback回调函数将会异步的执行,每当订阅的标的tick发生变化更新,callback回调函数就会被调用一次
+# 本地已有的数据不会触发callback
+
+# 定义的回测函数
+## 回调函数中,data是本次触发回调的数据,只有一条
+def f(data):
+ # print(data)
+
+ code_list = list(data.keys()) # 获取到本次触发的标的代码
+
+ kline_in_callabck = xtdata.get_market_data_ex([], code_list, period=period) # 在回调中获取klines数据
+ print(kline_in_callabck)
+
+
+for i in code_list:
+ xtdata.subscribe_quote(i, period=period, count=-1, callback=f) # 订阅时设定回调函数
+
+# 使用回调时,必须要同时使用xtdata.run()来阻塞程序,否则程序运行到最后一行就直接结束退出了。
+xtdata.run()
\ No newline at end of file
diff --git a/src/tortoises.py b/src/tortoises.py
new file mode 100644
index 0000000..a9db0fe
--- /dev/null
+++ b/src/tortoises.py
@@ -0,0 +1,185 @@
+import logging
+from fastapi import FastAPI
+from tortoise.contrib.fastapi import register_tortoise
+from src.settings.config import settings
+
+DATABASE_URL = str(settings.DATABASE_URL)
+DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
+
+
+# 定义不同日志级别的颜色
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ "DEBUG": "\033[94m", # 蓝色
+ "INFO": "\033[92m", # 绿色
+ "WARNING": "\033[93m", # 黄色
+ "ERROR": "\033[91m", # 红色
+ "CRITICAL": "\033[41m", # 红色背景
+ }
+ RESET = "\033[0m" # 重置颜色
+
+ def format(self, record):
+ color = self.COLORS.get(record.levelname, self.RESET)
+ log_message = super().format(record)
+ return f"{color}{log_message}{self.RESET}"
+
+
+# 配置日志
+logger_db_client = logging.getLogger("tortoise.db_client")
+logger_db_client.setLevel(logging.DEBUG)
+
+# 创建一个控制台处理程序
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
+
+# 创建并设置格式化器
+formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ch.setFormatter(formatter)
+
+# 将处理程序添加到日志记录器中
+logger_db_client.addHandler(ch)
+
+# 定义模型
+models = [
+ "src.models.test_table",
+ "src.models.stock",
+ "src.models.security_account",
+ "src.models.order",
+ "src.models.snowball",
+ "src.models.backtest",
+ "src.models.strategy",
+ "src.models.stock_details",
+ "src.models.transaction",
+ "src.models.position",
+ "src.models.trand_info",
+ "src.models.tran_observer_data",
+ "src.models.tran_orders",
+ "src.models.tran_return",
+ "src.models.tran_trade_info",
+ "src.models.back_observed_data",
+ "src.models.back_observed_data_detail",
+ "src.models.back_position",
+ "src.models.back_result_indicator",
+ "src.models.back_trand_info",
+ "src.models.tran_position",
+ "src.models.stock_bt_history",
+ "src.models.stock_history",
+ "src.models.stock_data_processing",
+ "src.models.wance_data_stock",
+ "src.models.wance_data_storage_backtest"
+
+]
+
+aerich_models = models
+aerich_models.append("aerich.models")
+
+# Tortoise ORM 配置
+TORTOISE_ORM = {
+ "connections": {"default": DATABASE_CREATE_URL,
+ },
+ "apps": {
+ "models": {
+ "models": aerich_models,
+ "default_connection": "default",
+ },
+ },
+}
+
+config = {
+ 'connections': {
+ 'default': DATABASE_URL
+ },
+ 'apps': {
+ 'models': {
+ 'models': models,
+ 'default_connection': 'default',
+ }
+ },
+ 'use_tz': False,
+ 'timezone': 'Asia/Shanghai'
+}
+
+
+def register_tortoise_orm(app: FastAPI):
+ # 注册 Tortoise ORM 配置
+ register_tortoise(
+ app,
+ config=config,
+ generate_schemas=False,
+ add_exception_handlers=True,
+ )
+
+
+"""
+**********
+以下内容对原有无改色配置文件进行注释,上述内容为修改后日志为蓝色
+**********
+"""
+
+# import logging
+#
+# from fastapi import FastAPI
+# from tortoise.contrib.fastapi import register_tortoise
+#
+# from src.settings.config import settings
+#
+# DATABASE_URL = str(settings.DATABASE_URL)
+# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
+#
+# # will print debug sql
+# logging.basicConfig()
+# logger_db_client = logging.getLogger("tortoise.db_client")
+# logger_db_client.setLevel(logging.DEBUG)
+#
+# models = [
+# "src.models.test_table",
+# "src.models.stock",
+# "src.models.security_account",
+# "src.models.order",
+# "src.models.snowball",
+# "src.models.backtest",
+# "src.models.strategy",
+# "src.models.stock_details",
+# "src.models.transaction",
+# ]
+#
+# aerich_models = models
+# aerich_models.append("aerich.models")
+#
+# TORTOISE_ORM = {
+#
+# "connections": {"default": DATABASE_CREATE_URL},
+# "apps": {
+# "models": {
+# "models": aerich_models,
+# "default_connection": "default",
+# },
+# },
+# }
+#
+# config = {
+# 'connections': {
+# # Using a DB_URL string
+# 'default': DATABASE_URL
+# },
+# 'apps': {
+# 'models': {
+# 'models': models,
+# 'default_connection': 'default',
+# }
+# },
+# 'use_tz': False,
+# 'timezone': 'Asia/Shanghai'
+# }
+#
+#
+# def register_tortoise_orm(app: FastAPI):
+# register_tortoise(
+# app,
+# config=config,
+# # db_url=DATABASE_URL,
+# # db_url="mysql://adams:adams@127.0.0.1:3306/adams",
+# # modules={"models": models},
+# generate_schemas=False,
+# add_exception_handlers=True,
+# )
diff --git a/src/tortoises_orm_config.py b/src/tortoises_orm_config.py
new file mode 100644
index 0000000..b09b70c
--- /dev/null
+++ b/src/tortoises_orm_config.py
@@ -0,0 +1,30 @@
+from tortoise import Tortoise
+from tortoise.exceptions import OperationalError
+from src.tortoises import TORTOISE_ORM
+import asyncio
+
+
+async def init_tortoise():
+ """
+ 初始化 Tortoise-ORM 数据库连接
+ """
+ try:
+ await Tortoise.init(config=TORTOISE_ORM)
+ # print("Tortoise-ORM initialized successfully.")
+ except OperationalError as e:
+ print(f"OperationalError during initialization: {e}")
+ except Exception as e:
+ print(f"Error during initialization: {e}")
+
+
+async def close_tortoise():
+ """
+ 关闭 Tortoise-ORM 数据库连接
+ """
+ try:
+ await Tortoise.close_connections()
+ # print("Tortoise-ORM connections closed successfully.")
+ except Exception as e:
+ print(f"Error during closing connections: {e}")
+
+
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..679558b
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1,37 @@
+import sqlite3
+from collections import deque
+
+import redis.asyncio as aioredis
+from fastapi import FastAPI
+
+from src.settings.config import settings
+from src.utils import redis
+
+
+import akshare
+
+from pypinyin import lazy_pinyin
+
+from src.models import StockType
+from src.models.stock import Stock, StockResponse
+from src.utils.paginations import PaginationPydantic, pagination, Params
+
+
+
+def register_redis(app: FastAPI):
+ @app.on_event("startup")
+ async def startup_redis():
+ pool = aioredis.ConnectionPool.from_url(
+ str(settings.REDIS_URL), max_connections=10, decode_responses=True
+ )
+ redis.redis_client = aioredis.Redis(connection_pool=pool)
+
+
+
+
+ @app.on_event("shutdown")
+ async def shutdown_redis():
+ await redis.redis_client.connection_pool.disconnect()
+ await redis.redis_client.shutdown()
+
+
diff --git a/src/utils/__pycache__/__init__.cpython-311.pyc b/src/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..3b87d3d
Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/utils/__pycache__/helpers.cpython-311.pyc b/src/utils/__pycache__/helpers.cpython-311.pyc
new file mode 100644
index 0000000..2d83290
Binary files /dev/null and b/src/utils/__pycache__/helpers.cpython-311.pyc differ
diff --git a/src/utils/__pycache__/history_data_processing_utils.cpython-311.pyc b/src/utils/__pycache__/history_data_processing_utils.cpython-311.pyc
new file mode 100644
index 0000000..6527cca
Binary files /dev/null and b/src/utils/__pycache__/history_data_processing_utils.cpython-311.pyc differ
diff --git a/src/utils/__pycache__/models.cpython-311.pyc b/src/utils/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000..3fa21bd
Binary files /dev/null and b/src/utils/__pycache__/models.cpython-311.pyc differ
diff --git a/src/utils/__pycache__/paginations.cpython-311.pyc b/src/utils/__pycache__/paginations.cpython-311.pyc
new file mode 100644
index 0000000..8ddbb1d
Binary files /dev/null and b/src/utils/__pycache__/paginations.cpython-311.pyc differ
diff --git a/src/utils/__pycache__/redis.cpython-311.pyc b/src/utils/__pycache__/redis.cpython-311.pyc
new file mode 100644
index 0000000..dd1277a
Binary files /dev/null and b/src/utils/__pycache__/redis.cpython-311.pyc differ
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
new file mode 100644
index 0000000..d42497c
--- /dev/null
+++ b/src/utils/helpers.py
@@ -0,0 +1,55 @@
+import numpy as np
+
+
+# 英文连接的多个参数转成list
+def comma_string_to_array(value):
+ if not value:
+ return []
+
+ return np.array(value.split(','))
+
+
+def first(iterable, default=None, condition=lambda x: True):
+ """
+ Returns the first item in the `iterable` that
+ satisfies the `condition`.
+
+ If the condition is not given, returns the first item of
+ the iterable.
+
+ If the `default` argument is given and the iterable is empty,
+ or if it has no items matching the condition, the `default` argument
+ is returned if it matches the condition.
+
+ The `default` argument being None is the same as it not being given.
+
+ Raises `StopIteration` if no item satisfying the condition is found
+ and default is not given or doesn't satisfy the condition.
+
+ >>> first( (1,2,3), condition=lambda x: x % 2 == 0)
+ 2
+ >>> first(range(3, 100))
+ 3
+ >>> first( () )
+ Traceback (most recent call last):
+ ...
+ StopIteration
+ >>> first([], default=1)
+ 1
+ >>> first([], default=1, condition=lambda x: x % 2 == 0)
+ Traceback (most recent call last):
+ ...
+ StopIteration
+ >>> first([1,3,5], default=1, condition=lambda x: x % 2 == 0)
+ Traceback (most recent call last):
+ ...
+ StopIteration
+ """
+
+ try:
+ return next(x for x in iterable if condition(x))
+ except StopIteration:
+ if default is not None and condition(default):
+ return default
+ else:
+ raise
diff --git a/src/utils/history_data_processing_utils.py b/src/utils/history_data_processing_utils.py
new file mode 100644
index 0000000..c155b23
--- /dev/null
+++ b/src/utils/history_data_processing_utils.py
@@ -0,0 +1,249 @@
+import difflib
+import re
+
+import pandas as pd
+
+
+def arrays_to_dict(Chinese: list, English: list) -> dict:
+ # 使用 zip 将两个数组配对,然后转换为字典
+ return dict(zip(Chinese, English))
+
+
+# 中文数组
+Chinese = [
+ "上证A股", "深证A股", "沪深300", "中证500", "中证1000", "国证2000", "农林牧渔基础", "化工钢铁", "有色金属",
+ "电子", "汽车", "家用电器", "食品饮料", "纺织服饰", "轻工制造", "医药生物", "公用事业", "交通运输",
+ "房地产", "商贸零售", "社会服务", "银行", "非银金融", "综合", "建筑材料", "建筑装饰", "电力设备",
+ "机械设备", "国防军工", "计算机", "传媒", "通信", "煤炭", "石油石化", "环保", "美容护理", "中字头股票",
+ "中特估100", "人形机器人", "低空经济", "数字货币", "工业互联网", "Web3.0", "网络直播", "跨境电商",
+ "消费电子概念", "网络游戏", "国产软件", "抖音概念", "手机游戏", "元宇宙", "区块链", "虚拟现实", "量子科技",
+ "智能穿戴", "机器视觉", "ChatGPT概念", "脑机接口", "人工智能", "机器人概念", "数字孪生", "算力租赁", "碳中和",
+ "风电", "光热发电", "光伏概念", "中俄贸易概念", "露营经济", "柔性屏", "宠物经济", "免税店", "预制菜",
+ "在线旅游", "减肥药", "体育产业", "无人机", "华为海思概念股", "先进封装(Chiplet)", "第三代半导体", "中芯国际概念",
+ "光刻机", "氟化工概念", "光刻胶", "汽车芯片", "传感器", "英伟达概念", "边缘计算", "存储芯片", "6G概念",
+ "石墨烯", "芯片概念", "特斯拉", "充电桩", "抽水蓄能", "固态电池", "一体化压铸", "动力电池回收", "新能源汽车",
+ "钠离子电池", "储能", "钙钛矿电池", "汽车电子", "稀土永磁", "华为汽车", "毫米波雷达", "锂电池", "阿尔茨海默概念",
+ "创新药", "仿制药一致性评价", "重组蛋白", "毛发医疗", "生物医药", "医美概念辅助生殖", "生物疫苗", "细胞免疫治疗",
+ "基因测序", "集成电路概念", "黄金概念", "新疆振兴", "一带一路", "养老概念", "新材料概念", "职业教育", "工业4.0",
+ "碳纤维", "军工", "航运概念", "独角兽概念", "国家大基金持股", "海南自贸区", "专精特新", "智能制造", "超超临界发电"
+]
+
+# 英文数组(仅为示例,可以根据具体需求替换成正确的英文)
+English = [
+ "Shanghai A-shares", "Shenzhen A-shares", "CSI 300", "CSI 500", "CSI 1000", "CSI 2000", "Agriculture and Forestry",
+ "Chemical and Steel", "Non-ferrous Metals", "Electronics", "Automobile", "Household Appliances",
+ "Food and Beverage",
+ "Textiles and Apparel", "Light Industry Manufacturing", "Pharmaceuticals", "Public Utilities", "Transportation",
+ "Real Estate", "Retail", "Social Services", "Banking", "Non-bank Finance", "Comprehensive", "Building Materials",
+ "Building Decoration", "Electric Power Equipment", "Machinery", "Defense", "Computer", "Media", "Telecom", "Coal",
+ "Petroleum", "Environmental Protection", "Beauty Care", "State-owned Enterprises", "Special Valuation 100",
+ "Humanoid Robot",
+ "Low-altitude Economy", "Digital Currency", "Industrial Internet", "Web3.0", "Live Streaming",
+ "Cross-border E-commerce",
+ "Consumer Electronics", "Online Games", "Domestic Software", "Douyin Concept", "Mobile Games", "Metaverse",
+ "Blockchain",
+ "Virtual Reality", "Quantum Technology", "Wearable Devices", "Machine Vision", "ChatGPT Concept",
+ "Brain-computer Interface",
+ "Artificial Intelligence", "Robotics", "Digital Twin", "Computing Power Leasing", "Carbon Neutrality", "Wind Power",
+ "Solar Thermal", "Photovoltaic", "China-Russia Trade", "Camping Economy", "Flexible Screen", "Pet Economy",
+ "Duty-free",
+ "Prepared Foods", "Online Travel", "Weight Loss Drugs", "Sports Industry", "Drone", "Huawei HiSilicon", "Chiplet",
+ "Third-generation Semiconductors", "SMIC Concept", "Lithography Machine", "Fluorine Chemicals", "Photoresist",
+ "Automotive Chips", "Sensors", "NVIDIA Concept", "Edge Computing", "Memory Chips", "6G Concept", "Graphene",
+ "Chip Concept",
+ "Tesla", "Charging Pile", "Pumped Storage", "Solid-state Battery", "Integrated Die-casting", "Battery Recycling",
+ "New Energy Vehicles", "Sodium-ion Battery", "Energy Storage", "Perovskite Battery", "Automotive Electronics",
+ "Rare Earth",
+ "Huawei Vehicles", "Millimeter Wave Radar", "Lithium Battery", "Alzheimer's Concept", "Innovative Drugs",
+ "Generic Drugs",
+ "Recombinant Protein", "Hair Medical", "Biomedical", "Aesthetic Medicine", "Reproductive Assistance", "Biovaccine",
+ "Cell Immunotherapy", "Gene Sequencing", "IC Concept", "Gold Concept", "Xinjiang Development", "Belt and Road",
+ "Elderly Care Concept", "New Materials", "Vocational Education", "Industry 4.0", "Carbon Fiber", "Military",
+ "Shipping",
+ "Unicorn Concept", "National Fund Holdings", "Hainan Free Trade Zone", "Specialized and New", "Smart Manufacturing",
+ "Ultra-supercritical Power"
+]
+
+# 调用方法
+result_dict = arrays_to_dict(Chinese, English)
+
+
+translation_dict = {
+ "上证A股": "Shanghai A-shares",
+ "深证A股": "Shenzhen A-shares",
+ "沪深300": "CSI 300",
+ "中证500": "CSI 500",
+ "中证1000": "CSI 1000",
+ "国证2000": "CSI 2000",
+ "农林牧渔基础": "Agriculture and Forestry",
+ "化工钢铁": "Chemical and Steel",
+ "有色金属": "Non-ferrous Metals",
+ "电子": "Electronics",
+ "汽车": "Automobile",
+ "家用电器": "Household Appliances",
+ "食品饮料": "Food and Beverage",
+ "纺织服饰": "Textiles and Apparel",
+ "轻工制造": "Light Industry Manufacturing",
+ "医药生物": "Pharmaceuticals",
+ "公用事业": "Public Utilities",
+ "交通运输": "Transportation",
+ "房地产": "Real Estate",
+ "商贸零售": "Retail",
+ "社会服务": "Social Services",
+ "银行": "Banking",
+ "非银金融": "Non-bank Finance",
+ "综合": "Comprehensive",
+ "建筑材料": "Building Materials",
+ "建筑装饰": "Building Decoration",
+ "电力设备": "Electric Power Equipment",
+ "机械设备": "Machinery",
+ "国防军工": "Defense",
+ "计算机": "Computer",
+ "传媒": "Media",
+ "通信": "Telecom",
+ "煤炭": "Coal",
+ "石油石化": "Petroleum",
+ "环保": "Environmental Protection",
+ "美容护理": "Beauty Care",
+ "中字头股票": "State-owned Enterprises",
+ "中特估100": "Special Valuation 100",
+ "人形机器人": "Humanoid Robot",
+ "低空经济": "Low-altitude Economy",
+ "数字货币": "Digital Currency",
+ "工业互联网": "Industrial Internet",
+ "Web3.0": "Web3.0",
+ "网络直播": "Live Streaming",
+ "跨境电商": "Cross-border E-commerce",
+ "消费电子概念": "Consumer Electronics",
+ "网络游戏": "Online Games",
+ "国产软件": "Domestic Software",
+ "抖音概念": "Douyin Concept",
+ "手机游戏": "Mobile Games",
+ "元宇宙": "Metaverse",
+ "区块链": "Blockchain",
+ "虚拟现实": "Virtual Reality",
+ "量子科技": "Quantum Technology",
+ "智能穿戴": "Wearable Devices",
+ "机器视觉": "Machine Vision",
+ "ChatGPT概念": "ChatGPT Concept",
+ "脑机接口": "Brain-computer Interface",
+ "人工智能": "Artificial Intelligence",
+ "机器人概念": "Robotics",
+ "数字孪生": "Digital Twin",
+ "算力租赁": "Computing Power Leasing",
+ "碳中和": "Carbon Neutrality",
+ "风电": "Wind Power",
+ "光热发电": "Solar Thermal",
+ "光伏概念": "Photovoltaic",
+ "中俄贸易概念": "China-Russia Trade",
+ "露营经济": "Camping Economy",
+ "柔性屏": "Flexible Screen",
+ "宠物经济": "Pet Economy",
+ "免税店": "Duty-free",
+ "预制菜": "Prepared Foods",
+ "在线旅游": "Online Travel",
+ "减肥药": "Weight Loss Drugs",
+ "体育产业": "Sports Industry",
+ "无人机": "Drone",
+ "华为海思概念股": "Huawei HiSilicon",
+ "先进封装(Chiplet)": "Chiplet",
+ "第三代半导体": "Third-generation Semiconductors",
+ "中芯国际概念": "SMIC Concept",
+ "光刻机": "Lithography Machine",
+ "氟化工概念": "Fluorine Chemicals",
+ "光刻胶": "Photoresist",
+ "汽车芯片": "Automotive Chips",
+ "传感器": "Sensors",
+ "英伟达概念": "NVIDIA Concept",
+ "边缘计算": "Edge Computing",
+ "存储芯片": "Memory Chips",
+ "6G概念": "6G Concept",
+ "石墨烯": "Graphene",
+ "芯片概念": "Chip Concept",
+ "特斯拉": "Tesla",
+ "充电桩": "Charging Pile",
+ "抽水蓄能": "Pumped Storage",
+ "固态电池": "Solid-state Battery",
+ "一体化压铸": "Integrated Die-casting",
+ "动力电池回收": "Battery Recycling",
+ "新能源汽车": "New Energy Vehicles",
+ "钠离子电池": "Sodium-ion Battery",
+ "储能": "Energy Storage",
+ "钙钛矿电池": "Perovskite Battery",
+ "汽车电子": "Automotive Electronics",
+ "稀土永磁": "Rare Earth",
+ "华为汽车": "Huawei Vehicles",
+ "毫米波雷达": "Millimeter Wave Radar",
+ "锂电池": "Lithium Battery",
+ "阿尔茨海默概念": "Alzheimer's Concept",
+ "创新药": "Innovative Drugs",
+ "仿制药一致性评价": "Generic Drugs",
+ "重组蛋白": "Recombinant Protein",
+ "毛发医疗": "Hair Medical",
+ "生物医药": "Biomedical",
+ "医美概念辅助生殖": "Aesthetic Medicine",
+ "生物疫苗": "Biovaccine",
+ "细胞免疫治疗": "Cell Immunotherapy",
+ "基因测序": "Gene Sequencing",
+ "集成电路概念": "IC Concept",
+ "黄金概念": "Gold Concept",
+ "新疆振兴": "Xinjiang Development",
+ "一带一路": "Belt and Road",
+ "养老概念": "Elderly Care Concept",
+ "新材料概念": "New Materials",
+ "职业教育": "Vocational Education",
+ "工业4.0": "Industry 4.0",
+ "碳纤维": "Carbon Fiber",
+ "军工": "Military",
+ "航运概念": "Shipping",
+ "独角兽概念": "Unicorn Concept",
+ "国家大基金持股": "National Fund Holdings",
+ "海南自贸区": "Hainan Free Trade Zone",
+ "专精特新": "Specialized and New",
+ "智能制造": "Smart Manufacturing",
+ "超超临界发电": "Ultra-supercritical Power"
+}
+
+
+def chinese_to_english(chinese_term):
+ return translation_dict.get(chinese_term, "Translation not found")
+
+
+def english_to_chinese(english_term):
+ # Invert the dictionary for reverse lookup
+ inverted_dict = {v: k for k, v in translation_dict.items()}
+ return inverted_dict.get(english_term, "Translation not found")
+
+
+def fuzzy_search(pattern, string_list):
+ """
+ 使用正则表达式在字符串列表中进行模糊搜索。
+
+ :param pattern: 要搜索的模式(字符串)
+ :param string_list: 字符串列表
+ :return: 匹配的字符串列表
+ """
+ regex = re.compile(pattern)
+ return [s for s in string_list if regex.search(s)]
+
+
+def get_best_match(sector, translation_dict):
+ """
+ 模糊匹配获取最佳匹配的中文板块。
+ @param sector: 需要匹配的中文板块名
+ @param translation_dict: 中文板块到英文的映射字典
+ @return: 最佳匹配的中文板块名(如果找到),否则返回None
+ """
+ matches = difflib.get_close_matches(sector, translation_dict.keys(), n=1, cutoff=0.6)
+ return matches[0] if matches else None
+
+
+def on_progress(data):
+ print("这里是回执数据", data)
+
+
+def safe_get_value(value):
+ """检查值是否为 None、NaN 或 False,如果是,则返回 0,否则返回值本身"""
+ if value is None or pd.isna(value) or value is False:
+ return 0
+ return value
diff --git a/src/utils/history_stock.py b/src/utils/history_stock.py
new file mode 100644
index 0000000..fc9b596
--- /dev/null
+++ b/src/utils/history_stock.py
@@ -0,0 +1,40 @@
+from datetime import datetime
+
+import akshare
+from pypinyin import lazy_pinyin
+
+from src.models import StockType
+from src.models.snowball import Snowball
+from src.models.stock_details import StockDetails
+from src.tortoises_orm_config import init_tortoise, close_tortoise
+
+
+async def job():
+ await init_tortoise()
+ print(f"执行定时任务: 当前时间是 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+
+ await StockDetails.all().delete()
+ stock_name_df = akshare.stock_sh_a_spot_em()
+ stock_list = stock_name_df[['代码', '名称', '最新价', '涨跌幅']].values.tolist()
+
+ for stock in stock_list:
+ # 将文字转化为拼音列表
+ pinyin_list = lazy_pinyin(stock[1])
+
+ # 取每个拼音的首字母并连接成字符串
+ initial_pinyin = ''.join([pinyin[0] for pinyin in pinyin_list])
+ await StockDetails.create(stock_code=stock[0], stock_name=stock[1], type=StockType.SH,
+ stock_pinyin=initial_pinyin, latest_price=stock[2], rise_fall=stock[3])
+ stock_name_df = akshare.stock_sz_a_spot_em()
+ stock_list = stock_name_df[['代码', '名称', '最新价', '涨跌幅']].values.tolist()
+
+ for stock in stock_list:
+ # 将文字转化为拼音列表
+ pinyin_list = lazy_pinyin(stock[1])
+
+ # 取每个拼音的首字母并连接成字符串
+ initial_pinyin = ''.join([pinyin[0] for pinyin in pinyin_list])
+ await StockDetails.create(stock_code=stock[0], stock_name=stock[1], type=StockType.SZ,
+ stock_pinyin=initial_pinyin, latest_price=stock[2], rise_fall=stock[3])
+
+ await close_tortoise()
\ No newline at end of file
diff --git a/src/utils/jsonformatter.py b/src/utils/jsonformatter.py
new file mode 100644
index 0000000..32ae936
--- /dev/null
+++ b/src/utils/jsonformatter.py
@@ -0,0 +1,15 @@
+import json
+import logging
+
+
+class JsonFormatter(logging.Formatter):
+ """日志文件转为json格式"""
+
+ def format(self, record):
+ obj = {
+ 'timestamp': self.formatTime(record, self.datefmt),
+ 'name': record.name,
+ 'level': record.levelname,
+ 'message': record.getMessage()
+ }
+ return json.dumps(obj)
diff --git a/src/utils/models.py b/src/utils/models.py
new file mode 100644
index 0000000..587efbc
--- /dev/null
+++ b/src/utils/models.py
@@ -0,0 +1,37 @@
+from datetime import datetime
+from typing import Any
+from zoneinfo import ZoneInfo
+
+from fastapi.encoders import jsonable_encoder
+from pydantic import BaseModel, ConfigDict, model_validator
+
+
+def convert_datetime_to_gmt(dt: datetime) -> str:
+ if not dt.tzinfo:
+ dt = dt.replace(tzinfo=ZoneInfo("UTC"))
+
+ return dt.strftime("%Y-%m-%dT%H:%M:%S%z")
+
+
+class CustomModel(BaseModel):
+ model_config = ConfigDict(
+ json_encoders={datetime: convert_datetime_to_gmt},
+ populate_by_name=True,
+ )
+
+ @model_validator(mode="before")
+ @classmethod
+ def set_null_microseconds(cls, data: dict[str, Any]) -> dict[str, Any]:
+ datetime_fields = {
+ k: v.replace(microsecond=0)
+ for k, v in data.items()
+ if isinstance(k, datetime)
+ }
+
+ return {**data, **datetime_fields}
+
+ def serializable_dict(self, **kwargs):
+ """Return a dict which contains only serializable fields."""
+ default_dict = self.model_dump()
+
+ return jsonable_encoder(default_dict)
diff --git a/src/utils/paginations.py b/src/utils/paginations.py
new file mode 100644
index 0000000..21cf264
--- /dev/null
+++ b/src/utils/paginations.py
@@ -0,0 +1,74 @@
+from typing import Generic, Optional, Sequence, TypeVar
+
+import math
+from fastapi import Query
+from pydantic import BaseModel
+from tortoise.queryset import QuerySet
+
+T = TypeVar("T")
+
+
+class PaginationPydantic(BaseModel, Generic[T]):
+ """分页模型"""
+ status_code: int = 200
+ message: str = "Success"
+ total: int
+ page: int
+ size: int
+ total_pages: int
+ data: Sequence[T]
+
+
+class Params(BaseModel):
+ """传参"""
+ # 设置默认值为1,不能够小于1
+ page: int = Query(1, ge=1, description="Page number")
+
+ # 设置默认值为10,最大为100
+ size: int = Query(10, gt=0, le=200, description="Page size")
+
+ # 默认值None表示选传
+ order_by: Optional[str] = Query(None, max_length=32, description="Sort key")
+
+
+async def pagination(pydantic_model, query_set: QuerySet, params: Params, callback=None):
+ """分页响应"""
+ """
+ pydantic_model: Pydantic model
+ query_set: QuerySet
+ params: Params
+ callback: if you want to do something for query_set,it will be useful.
+ """
+ page: int = params.page
+ size: int = params.size
+ order_by: str = params.order_by
+ total = await query_set.count()
+
+ # 通过总数和每页数量计算出总页数
+ total_pages = math.ceil(total / size)
+
+ if page > total_pages and total: # 排除查询集为空时报错,即total=0时
+ raise ValueError("页数输入有误")
+
+ # 排序后分页
+ if order_by:
+ order_by = order_by.split(',')
+ query_set = query_set.order_by(*order_by)
+
+ # 分页
+ query_set = query_set.offset((page - 1) * size) # 页数 * 页面大小=偏移量
+ query_set = query_set.limit(size)
+
+ if callback:
+ """对查询集操作"""
+ query_set = await callback(query_set)
+
+ # query_set = await query_set
+ data = await pydantic_model.from_queryset(query_set)
+ return PaginationPydantic(**{
+ "total": total,
+ "page": page,
+ "size": size,
+ "total_pages": total_pages,
+ "data": data,
+ })
diff --git a/src/utils/redis.py b/src/utils/redis.py
new file mode 100644
index 0000000..0e7f139
--- /dev/null
+++ b/src/utils/redis.py
@@ -0,0 +1,31 @@
+from datetime import timedelta
+from typing import Optional
+
+import redis.asyncio as aioredis
+
+from src.utils.models import CustomModel
+
+redis_client: aioredis = None # type: ignore
+
+
+class RedisData(CustomModel):
+ key: bytes | str
+ value: bytes | str
+ ttl: Optional[int | timedelta] = None
+
+
+async def set_redis_key(redis_data: RedisData, *, is_transaction: bool = False) -> None:
+ async with redis_client.pipeline(transaction=is_transaction) as pipe:
+ await pipe.set(redis_data.key, redis_data.value)
+ if redis_data.ttl:
+ await pipe.expire(redis_data.key, redis_data.ttl)
+
+ await pipe.execute()
+
+
+async def get_by_key(key: str) -> Optional[str]:
+ return await redis_client.get(key)
+
+
+async def delete_by_key(key: str) -> None:
+ return await redis_client.delete(key)
diff --git a/src/utils/remove_duplicates_databases.py b/src/utils/remove_duplicates_databases.py
new file mode 100644
index 0000000..972fb04
--- /dev/null
+++ b/src/utils/remove_duplicates_databases.py
@@ -0,0 +1,40 @@
+from tortoise.models import Model
+from tortoise import fields
+from tortoise.functions import Count
+
+# 调用异步方法
+import asyncio
+
+from src.models.wance_data_stock import WanceDataStock
+from src.tortoises_orm_config import init_tortoise
+
+
+async def remove_duplicates():
+ await init_tortoise()
+
+ # 获取重复的股票代码和时间
+ duplicates = await WanceDataStock.all().group_by('stock_code', 'time_end').annotate(count=Count('id')).filter(
+ count__gt=1)
+
+ if not duplicates:
+ print("No duplicates found.")
+ return
+
+ print(f"Found {len(duplicates)} duplicates.")
+
+ for record in duplicates:
+ # 查询相同 stock_code 和 time_end 的所有记录,并按 id 逆序排序
+ duplicate_records = await WanceDataStock.filter(stock_code=record.stock_code,
+ time_end=record.time_end).order_by('-id')
+
+ # 输出当前处理的记录
+ print(
+ f"Processing {len(duplicate_records)} records for stock_code={record.stock_code} and time_end={record.time_end}.")
+
+ if duplicate_records:
+ # 删除最新的一条记录,保留其余的
+ await WanceDataStock.filter(id=duplicate_records[0].id).delete()
+ print(f"Deleted record with id={duplicate_records[0].id}")
+
+
+asyncio.run(remove_duplicates())
diff --git a/src/utils/split_stock_utils.py b/src/utils/split_stock_utils.py
new file mode 100644
index 0000000..945aa1f
--- /dev/null
+++ b/src/utils/split_stock_utils.py
@@ -0,0 +1,41 @@
+"""
+
+通用工具区域
+
+"""
+
+
+def split_stock_code(stock_code):
+ """
+ 将股票代码按点(".")切割成两部分。
+
+ Args:
+ stock_code (str): 原始股票代码(如 "000300.SH")。
+
+ Returns:
+ tuple: 返回切割后的两部分(如 ('000300', 'SH'))。
+ """
+ return stock_code.split('.')
+
+
+def join_stock_code(parts):
+ """
+ 将切割后的股票代码部分按指定格式拼接。
+
+ Args:
+ parts (tuple): 切割后的股票代码部分(如 ('000300', 'SH'))。
+
+ Returns:
+ str: 返回拼接后的字符串(如 "SH000300")。
+ """
+ part1, part2 = parts # 解构切割后的结果
+ return f"{part2}{part1}"
+
+
+def percent_to_float(percent_string):
+ if isinstance(percent_string, int):
+ return float(percent_string)
+ elif isinstance(percent_string, str):
+ return float(percent_string.strip('%'))
+ else:
+ raise ValueError(f"Unsupported type: {type(percent_string)}")
\ No newline at end of file
diff --git a/src/xtdata/__init__.py b/src/xtdata/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/xtdata/__pycache__/__init__.cpython-311.pyc b/src/xtdata/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..8098298
Binary files /dev/null and b/src/xtdata/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/xtdata/__pycache__/router.cpython-311.pyc b/src/xtdata/__pycache__/router.cpython-311.pyc
new file mode 100644
index 0000000..006c2d6
Binary files /dev/null and b/src/xtdata/__pycache__/router.cpython-311.pyc differ
diff --git a/src/xtdata/__pycache__/service.cpython-311.pyc b/src/xtdata/__pycache__/service.cpython-311.pyc
new file mode 100644
index 0000000..54dc4dc
Binary files /dev/null and b/src/xtdata/__pycache__/service.cpython-311.pyc differ
diff --git a/src/xtdata/router.py b/src/xtdata/router.py
new file mode 100644
index 0000000..af619b7
--- /dev/null
+++ b/src/xtdata/router.py
@@ -0,0 +1,202 @@
+import json # 导入 json 库
+
+from fastapi import APIRouter, HTTPException # 从 FastAPI 中导入 APIRouter,用于创建 API 路由器
+
+import src.xtdata.service as service # 导入服务模块,该模块包含各种服务的实现
+from src.pydantic.codelistrequest import CodeListRequest # 导入 Pydantic 模型,用于验证请求参数
+from src.pydantic.factor_request import StockQuery
+from src.pydantic.request_data import DataRequest # 导入 Pydantic 模型,用于验证请求参数
+from src.responses import response_entity_response, response_list_response # 导入自定义响应格式化函数
+
+router = APIRouter() # 创建一个 FastAPI 路由器实例
+
+
+# 获取完整的逐笔成交键
+@router.get("/get_full_tick_keys")
+async def get_full_tick_keys(request: CodeListRequest):
+ """
+ 获取完整的逐笔成交键。
+ """
+ result = await service.get_full_tick_keys_service(request.code_list) # 调用服务获取逐笔成交键
+ return result # 返回获取的结果
+
+
+# 获取完整的逐笔成交数据
+@router.get("/get_full_tick")
+async def get_full_tick(request: CodeListRequest):
+ """
+ 获取完整的逐笔成交数据。
+ """
+ result = await service.get_full_tick_service(request.code_list) # 调用服务获取逐笔成交数据
+ return result # 返回获取的结果
+
+
+# 获取市场数据
+@router.get("/get_market_data")
+async def get_market_data(request: DataRequest):
+ """
+ 获取市场数据。
+ """
+ # 调用服务获取市场数据
+ market_data = await service.get_market_data_service(field_list=request.field_list,
+ stock_list=request.stock_list,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ count=request.count,
+ dividend_type=request.dividend_type,
+ fill_data=request.fill_data)
+ return market_data # 返回市场数据
+
+
+# 获取完整的 K 线数据
+@router.get("/get_full_kline")
+async def get_full_kline(request: DataRequest):
+ """
+ 获取完整的 K 线数据。
+ """
+ # 调用服务获取完整的 K 线数据
+ result = await service.get_full_kline_service(field_list=request.field_list,
+ stock_list=request.stock_list,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ count=request.count,
+ dividend_type=request.dividend_type,
+ fill_data=request.fill_data)
+ return result # 返回 K 线数据
+
+
+# 获取合约详情
+@router.get("/get_instrument_detail")
+async def get_instrument_detail(request: DataRequest):
+ """
+ 获取合约的详细信息。
+ """
+ # 调用服务获取合约详情数据
+ result = await service.get_instrument_detail_service(stock_code=request.stock_code,
+ iscomplete=request.iscomplete)
+ return result # 返回合约详情
+
+
+@router.get("/get_stock_factor")
+async def get_stock_factor(query_params: StockQuery):
+ stocks = await service.get_stock_factor_service(query_params)
+ return stocks
+
+
+# 获取本地数据
+@router.get("/get_local_data")
+async def get_local_data(request: DataRequest):
+ """
+ 获取本地存储的股票数据。
+ """
+ # 调用服务获取本地数据
+ result = await service.get_local_data_service(field_list=request.field_list, stock_list=request.stock_list,
+ period=request.period, start_time=request.start_time,
+ end_time=request.end_time, count=request.count,
+ dividend_type=request.dividend_type, fill_data=request.fill_data,
+ data_dir=request.data_dir)
+ return result
+
+
+# 获取板块列表
+@router.post("/get_sector_list")
+async def get_sector_list():
+ """
+ 获取所有板块的列表。
+ """
+ # 调用服务获取板块列表
+ result = await service.get_sector_list_service()
+ return result # 返回板块列表
+
+
+# 获取某板块中的股票列表
+@router.get("/get_stock_list_in_sector")
+async def get_stock_list_in_sector(request: DataRequest):
+ """
+ 获取指定板块中的股票列表。
+ """
+ # 调用服务获取板块中的股票列表
+ result = await service.get_stock_list_in_sector_service(sector_name=request.sector_name)
+ return result # 返回股票列表
+
+
+# 下载板块数据
+@router.post("/download_sector_data")
+async def download_sector_data():
+ """
+ 下载板块的详细数据。
+ """
+ # 调用服务下载板块数据
+ result = await service.download_sector_data_service()
+ return result # 返回下载结果
+
+
+# 订阅股票报价
+@router.post("/subscribe_quote")
+async def subscribe_quote(request: DataRequest):
+ """
+ 订阅单只股票的报价。
+ """
+ # 调用服务订阅股票报价
+ result = await service.subscribe_quote_service(stock_code=request.stock_code,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ count=request.count,
+ callback=request.callback)
+ return response_list_response(data=result, status_code=200, message="Success") # 返回响应
+
+
+# 订阅所有股票报价
+@router.post("/subscribe_whole_quote")
+async def subscribe_whole_quote(request: DataRequest):
+ """
+ 订阅所有股票的报价。
+ """
+ # 调用服务订阅所有股票的报价
+ result = await service.subscribe_whole_quote_service(code_list=request.code_list,
+ callback=request.callback)
+ return response_list_response(data=result, status_code=200, message="Success") # 返回响应
+
+
+# 下载历史数据
+@router.post("/download_history_data")
+async def download_history_data(request: DataRequest):
+ """
+ 下载指定股票的历史数据。
+ """
+ # 调用服务下载指定股票的历史数据
+ await service.download_history_data_service(stock_code=request.stock_code,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ incrementally=request.incrementally)
+ return response_list_response(data=[], status_code=200, message="Success") # 返回响应
+
+
+# 批量下载历史数据
+@router.post("/batch_download_history_data")
+async def download_history_data(request: DataRequest):
+ """
+ 批量下载多个股票的历史数据。
+ """
+ # 调用服务批量下载历史数据
+ await service.download_history_data2_service(stock_list=request.stock_list,
+ period=request.period,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ callback=request.callback)
+ return response_list_response(data=[], status_code=200, message="Success") # 返回响应
+
+
+# 批量下载历史数据
+@router.post("/update_stock_data")
+async def update_stock_data():
+ """
+ 批量下载多个股票的历史数据。
+ """
+ # 调用服务批量下载历史数据
+ await service.update_stock_data_service()
+ return response_list_response(data=[], status_code=200, message="Success") # 返回响应
diff --git a/src/xtdata/service.py b/src/xtdata/service.py
new file mode 100644
index 0000000..76708be
--- /dev/null
+++ b/src/xtdata/service.py
@@ -0,0 +1,363 @@
+import asyncio
+import json
+from datetime import datetime, timedelta
+
+import numpy as np # 导入 numpy 库
+from tortoise.expressions import Q
+from xtquant import xtdata # 导入 xtquant 库的 xtdata 模块
+
+from src.backtest.until import convert_pandas_to_json_serializable
+from src.models.wance_data_stock import WanceDataStock
+from src.pydantic.factor_request import StockQuery
+from src.utils.history_data_processing_utils import translation_dict
+
+
+# 获取股票池中的所有股票的股票代码
+async def get_full_tick_keys_service(code_list: list):
+ """
+ 获取所有股票的逐笔成交键。若未提供股票列表,则默认使用 ['SH', 'SZ']。
+ """
+ if len(code_list) == 0: # 如果股票代码列表为空
+ code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
+ result = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 调用 xtdata 库的函数获取逐笔成交数据
+ return list(result.keys()) # 返回所有股票代码的列表
+
+
+# 获取股票池中的所有股票的逐笔成交数据
+async def get_full_tick_service(code_list: list):
+ """
+ 获取所有股票的逐笔成交数据。若未提供股票列表,则默认使用 ['SH', 'SZ']。
+ """
+ if len(code_list) == 0: # 如果股票代码列表为空
+ code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
+ result = xtdata.get_full_tick(code_list=code_list) # 获取逐笔成交数据
+ return result # 返回逐笔成交数据
+
+
+# 获取板块信息
+async def get_sector_list_service():
+ """
+ 获取所有板块的列表。
+ """
+ result = xtdata.get_sector_list() # 调用 xtdata 库的函数获取板块信息
+ return result # 返回板块列表
+
+
+# 通过板块名称获取股票列表
+async def get_stock_list_in_sector_service(sector_name):
+ """
+ 获取指定板块中的所有股票名称。
+ """
+ result = xtdata.get_stock_list_in_sector(sector_name=sector_name) # 根据板块名称获取股票列表
+ return result # 返回股票列表
+
+
+# 下载板块信息
+async def download_sector_data_service():
+ """
+ 下载板块的详细数据。
+ """
+ result = xtdata.download_sector_data() # 调用 xtdata 库的函数下载板块数据
+ return result # 返回下载结果
+
+
+# 获取股票的详细信息
+async def get_instrument_detail_service(stock_code: str, iscomplete: bool):
+ """
+ 获取股票的详细信息。
+ """
+ result = xtdata.get_instrument_detail(stock_code=stock_code, iscomplete=iscomplete) # 获取股票的详细信息
+ return result # 返回股票的详细信息
+
+
+# 获取市场行情数据
+async def get_market_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int, dividend_type: str, fill_data: bool):
+ """
+ 获取指定条件下的市场行情数据。
+ """
+ result = xtdata.get_market_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type,
+ fill_data=fill_data) # 获取市场数据
+ return result # 返回市场数据
+
+
+# 获取最新的 K 线数据
+async def get_full_kline_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int,
+ dividend_type: str, fill_data: bool):
+ """
+ 获取指定条件下的完整 K 线数据。
+ """
+ result = xtdata.get_full_kline(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type,
+ fill_data=fill_data) # 获取完整的 K 线数据
+ return result # 返回 K 线数据
+
+
+# 下载历史数据
+async def download_history_data_service(stock_code: str, period: str, start_time: str, end_time: str,
+ incrementally: bool):
+ """
+ 下载指定股票的历史数据。
+ """
+ xtdata.download_history_data(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
+ incrementally=incrementally) # 下载股票历史数据
+
+
+# 获取本地数据
+async def get_local_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
+ count: int, dividend_type: str, fill_data: bool, data_dir: str):
+ """
+ @param field_list:
+ @type field_list:
+ @param stock_list:
+ @type stock_list:
+ @param period:
+ @type period:
+ @param start_time:
+ @type start_time:
+ @param end_time:
+ @type end_time:
+ @param count:
+ @type count:
+ @param dividend_type:
+ @type dividend_type:
+ @param fill_data:
+ @type fill_data:
+ @param data_dir:
+ @type data_dir:
+ @return:
+ @rtype:
+ """
+ """
+ 获取本地存储的股票数据。
+ """
+ return_list = []
+ result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
+ end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
+ data_dir=data_dir) # 获取本地数据
+ for i in stock_list:
+ return_list.append({
+ i : convert_pandas_to_json_serializable(result.get(i))
+ })
+ return return_list
+
+
+# 订阅全市场行情推送
+async def subscribe_whole_quote_service(code_list, callback=None):
+ """
+ 订阅全市场行情推送。
+ """
+ if callback:
+ result = xtdata.subscribe_whole_quote(code_list, callback=on_data) # 有回调函数时调用
+ else:
+ result = xtdata.subscribe_whole_quote(code_list, callback=None) # 无回调函数时调用
+ return result # 返回订阅结果
+
+
+# 订阅单只股票行情推送
+async def subscribe_quote_service(stock_code: str, period: str, start_time: str, end_time: str, count: int,
+ callback: bool):
+ """
+ 订阅单只股票的行情推送。
+ """
+ if callback:
+ result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
+ count=count, callback=on_data) # 有回调函数时调用
+ else:
+ result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
+ count=count, callback=None) # 无回调函数时调用
+ return result # 返回订阅结果
+
+
+def on_data(datas):
+ """
+ 行情数据推送的回调函数,处理并打印接收到的行情数据。
+ """
+ for stock_code in datas:
+ print(stock_code, datas[stock_code]) # 打印接收到的行情数据
+
+
+# 批量增量下载或批量下载
+async def download_history_data2_service(stock_list: list, period: str, start_time: str, end_time: str, callback: bool):
+ """
+ 批量下载股票历史数据,支持进度回调。
+ """
+ if callback:
+ xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
+ callback=on_progress) # 有回调函数时调用
+ else:
+ xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
+ callback=None) # 无回调函数时调用
+
+
+def on_progress(data):
+ """
+ 数据下载进度的回调函数,实时显示进度。
+ """
+ print(data) # 打印下载进度
+
+
+# 单只股票下载示例
+def download_data():
+ """
+ 测试函数:下载单只股票的历史数据并获取相关信息。
+ """
+ xtdata.download_history_data("000300.SH", "1d", start_time='', end_time='', incrementally=None) # 下载历史数据
+ a = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 获取逐笔成交数据
+ print(list(a.keys())) # 打印获取的股票代码
+
+
+async def update_stock_data_service():
+ """
+ # 用于本地xtdata数据更新
+ @return:
+ @rtype:
+ """
+ stock_list = await get_full_tick_keys_service(['SH', 'SZ'])
+ for stock in stock_list:
+ await download_history_data_service(stock_code=stock,
+ period='1d',
+ start_time='',
+ end_time='',
+ incrementally=True)
+
+
+async def get_stock_factor_service(query_params: StockQuery):
+ """
+ 动态查询方法,根据传入的字段和值进行条件构造
+ :param kwargs: 查询条件,键为字段名,值为查询值
+ :return: 符合条件的记录列表
+ """
+ """
+ 使用示例:
+ {
+ "financial_asset_value": 10.7169,
+ "financial_cash_flow": -0.42982,
+ "gt": {
+ "profit_gross_rate": 4.0,
+ "valuation_PB_TTM": 0.8
+ },
+ "lt": {
+ "market_indicator": -0.6
+ },
+ "gte": {
+ "valuation_market_percentile":5.38793
+ },
+ "lte": {
+ "valuation_PEGTTM_ratio": 15.2773
+ },
+ "between": {
+ "financial_dividend": {
+ "min": 1.0,
+ "max": 2.0
+ },
+ "growth_Income_rate": {
+ "min": -0.6,
+ "max": 1.0
+ }
+ },
+ "stock_code": "600051.SH",
+ "stock_name": "宁波联合",
+ "stock_sector": ["上证A股", "沪深A股"],
+ "stock_type": ["stock"],
+ "time_start": "0",
+ "time_end": "20240911",
+ "valuation_PEG_percentile": 15.2222,
+ "valuation_PTS_TTM": 2.3077
+ }
+
+
+ financial_asset_value 和 financial_cash_flow 字段为普通等于条件。
+ gt 字段用于指定字段大于某个值的条件,例如 profit_gross_rate 大于 15.0。
+ lt 字段用于指定字段小于某个值的条件,例如 market_indicator 小于 30.0。
+ gte 字段用于指定字段大于等于某个值的条件,例如 valuation_market_percentile 大于等于 0.5。
+ lte 字段用于指定字段小于等于某个值的条件,例如 valuation_PEGTTM_ratio 小于等于 2.0。
+ between 字段用于指定字段在两个值之间的范围,例如 financial_dividend 在 1.0 和 5.0 之间,growth_Income_rate 在 10.0 和 25.0 之间。
+ stock_code 和 stock_name 字段为普通等于条件。
+ stock_sector 和 stock_type 字段为列表类型的条件,可能会进行包含匹配。
+ time_start 和 time_end 字段为普通等于条件。
+ valuation_PEG_percentile 和 valuation_PTS_TTM 字段为普通等于条件。
+
+ 比较大小字段说明:
+ 'gt' 对应 __gt(大于)
+ 'lt' 对应 __lt(小于)
+ 'gte' 对应 __gte(大于等于)
+ 'lte' 对应 __lte(小于等于)
+ 'between' 对应在某个值之间的条件,使用 __gt 和 __lt 组合实现
+
+ 时间字段说明:
+ time_start={'recent_years': 2} 近2年
+ time_start={'recent_months': 2} 近2月
+ """
+
+ filters = {}
+
+ # 处理大于、小于等操作
+ if query_params.gt:
+ for field, value in query_params.gt.items():
+ filters[f"{field}__gt"] = value
+ if query_params.lt:
+ for field, value in query_params.lt.items():
+ filters[f"{field}__lt"] = value
+ if query_params.gte:
+ for field, value in query_params.gte.items():
+ filters[f"{field}__gte"] = value
+ if query_params.lte:
+ for field, value in query_params.lte.items():
+ filters[f"{field}__lte"] = value
+
+ # 处理范围查询
+ if query_params.between:
+ for field, range_values in query_params.between.items():
+ if 'min' in range_values and 'max' in range_values:
+ filters[f"{field}__gte"] = range_values['min']
+ filters[f"{field}__lte"] = range_values['max']
+
+ # 处理普通字段
+ for field, value in query_params.model_dump(exclude_unset=True).items():
+ if field not in ['gt', 'lt', 'gte', 'lte', 'between']:
+ if isinstance(value, float):
+ filters[f"{field}__gte"] = value
+ elif isinstance(value, str):
+ filters[f"{field}__icontains"] = value
+ elif isinstance(value, list):
+ if field == "stock_sector":
+ # 转换 `stock_sector` 中的中文板块名到英文
+ translated_sectors = [translation_dict.get(sector, sector) for sector in value]
+ filters[f"{field}__contains"] = translated_sectors
+ else:
+ filters[f"{field}__contains"] = value
+
+ try:
+ stocks = await WanceDataStock.filter(**filters).all()
+ return stocks
+ except Exception as e:
+ print(f"Error occurred when querying stocks: {e}")
+ return []
+
+
+# 将 numpy 类型转换为原生 Python 类型
+async def convert_numpy_to_native(obj):
+ """
+ 递归地将 numpy 类型转换为原生 Python 类型。
+ """
+ if isinstance(obj, np.generic): # 如果是 numpy 类型
+ return obj.item() # 转换为 Python 原生类型
+ elif isinstance(obj, dict): # 如果是字典类型
+ return {k: convert_numpy_to_native(v) for k, v in obj.items()} # 递归转换字典中的值
+ elif isinstance(obj, list): # 如果是列表类型
+ return [convert_numpy_to_native(i) for i in obj] # 递归转换列表中的值
+ elif isinstance(obj, tuple): # 如果是元组类型
+ return tuple(convert_numpy_to_native(i) for i in obj) # 递归转换元组中的值
+ else:
+ return obj # 如果是其他类型,直接返回
+
+
+if __name__ == '__main__':
+ # result = asyncio.run(get_instrument_detail_service("000300.SH",False))
+ # print(result)
+ # resulta = asyncio.run(get_sector_list_service())
+ # print(resulta)
+ pass