Compare commits

..

No commits in common. "023cc64d20035367f4ea54c5b9d3391480a7277c" and "1d8b78593ebc3f22d422bb837bb164f5ca3b92c6" have entirely different histories.

96 changed files with 508 additions and 1229 deletions

44
.env

@ -1,23 +1,23 @@
# 使用本地docker数据库就将这段代码解开如何将IP换成本机IP或者localhost
# DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade"
# DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade"
DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
DATABASE_CREATE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380"
SITE_DOMAIN=127.0.0.1
SECURE_COOKIES=false
ENVIRONMENT=LOCAL
CORS_HEADERS=["*"]
CORS_ORIGINS=["http://localhost:3000"]
# postgres variables, must be the same as in DATABASE_URL
POSTGRES_USER=app
POSTGRES_PASSWORD=app
POSTGRES_HOST=app_db
POSTGRES_PORT=5432
;使用本地docker数据库就将这段代码解开如何将IP换成本机IP或者localhost
; DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade"
; DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade"
DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
DATABASE_CREATE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380"
SITE_DOMAIN=127.0.0.1
SECURE_COOKIES=false
ENVIRONMENT=LOCAL
CORS_HEADERS=["*"]
CORS_ORIGINS=["http://localhost:3000"]
# postgres variables, must be the same as in DATABASE_URL
POSTGRES_USER=app
POSTGRES_PASSWORD=app
POSTGRES_HOST=app_db
POSTGRES_PORT=5432
POSTGRES_DB=app

6
.gitignore vendored

@ -1,6 +0,0 @@
# Ignore compiled Python files
*.pyc
__pycache__/
# Ignore virtual environment folder
venv/

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.7" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (wance_data)" project-jdk-type="Python SDK" />
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.7" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="ftrade_back" project-jdk-type="Python SDK" />
</project>

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@ -1,17 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
<option name="requirementsPath" value="" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="EPYTEXT" />
<option name="myDocStringFormat" value="Epytext" />
</component>
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="ftrade_back" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
<option name="requirementsPath" value="" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="EPYTEXT" />
<option name="myDocStringFormat" value="Epytext" />
</component>
</module>

17
.vscode/launch.json vendored

@ -1,17 +0,0 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: FastAPI",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": [
"src.main:app",
"--reload",
"--port", "8012"
],
"jinja": true
}
]
}

Binary file not shown.

Binary file not shown.

@ -1,32 +1,32 @@
# NOTE(review): this span is a git-diff web dump — Python indentation was
# stripped by extraction and both sides of the diff (identical here) appear
# back to back. Code tokens are preserved byte-for-byte.
import json
from fastapi import APIRouter
from src.backtest.service import start_backtest_service, stock_chart_service
from src.pydantic.backtest_request import BackRequest
router = APIRouter() # create the FastAPI router instance for backtest endpoints
# GET /start_backtest — run a moving-average backtest; every tuning knob is
# forwarded from the BackRequest model. NOTE(review): a GET endpoint taking a
# request-body model is unusual in FastAPI — confirm clients really send a body.
@router.get("/start_backtest")
async def start_backtest(request: BackRequest):
result = await start_backtest_service(field_list=['close', 'time'],
stock_list=request.stock_list,
period=request.period,
start_time=request.start_time,
end_time=request.end_time,
count=request.count,
dividend_type=request.dividend_type,
fill_data=request.fill_data,
ma_type=request.ma_type,
short_window=request.short_window,
long_window=request.long_window
)
return result
# GET /stock_chart — chart data for one stock against a benchmark index.
@router.get('/stock_chart')
async def stock_chart(request: BackRequest):
result = await stock_chart_service(stock_code=request.stock_code,
benchmark_code=request.benchmark_code)
return result
# --- second copy (other side of the diff); content identical to the above ---
import json
from fastapi import APIRouter
from src.backtest.service import start_backtest_service, stock_chart_service
from src.pydantic.backtest_request import BackRequest
router = APIRouter() # create the FastAPI router instance for backtest endpoints
@router.get("/start_backtest")
async def start_backtest(request: BackRequest):
result = await start_backtest_service(field_list=['close', 'time'],
stock_list=request.stock_list,
period=request.period,
start_time=request.start_time,
end_time=request.end_time,
count=request.count,
dividend_type=request.dividend_type,
fill_data=request.fill_data,
ma_type=request.ma_type,
short_window=request.short_window,
long_window=request.long_window
)
return result
@router.get('/stock_chart')
async def stock_chart(request: BackRequest):
result = await stock_chart_service(stock_code=request.stock_code,
benchmark_code=request.benchmark_code)
return result

@ -1,17 +0,0 @@
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from src.responses import response_list_response
from src.composite.service import *
composite_router = APIRouter()
# GET /query-composite — list the saved multi-factor composites of one user.
@composite_router.get("/query-composite")
async def composite_router_query_composite(user_id: int )-> JSONResponse:
"""
Query the user's existing multi-factor composites.
"""
result = await composite_router_query_composite_service(user_id)
return response_list_response(data=result, message="多因子组合查询成功")

@ -1 +0,0 @@

@ -1,10 +0,0 @@
from src.models.user_strategy import *


async def composite_router_query_composite_service(user_id: int) -> list:
    """Return every stored multi-factor strategy row for *user_id*.

    Each matching ``UserStrategy`` record is materialised as a plain dict
    via ``.values()``, ready for JSON serialisation by the router layer.
    """
    return await UserStrategy.filter(user_id=user_id).values()

@ -1,33 +0,0 @@
"""Lookup tables used by the financial-report filters.

A selector value missing from its table means "do not filter on this field".
"""

# Accepted report years (enum-style validation of the `year` selector).
year_mapping = {2024: 2024}

# Stock-pool selector: identity mapping over the supported pool codes.
stock_pool_mapping = {code: code for code in (
    "000300.SH",   # CSI 300 constituents
    "000905.SH",   # CSI 500 constituents
    "399303.SZ",   # CNI 2000 constituents
    "large_mv",    # large market cap bucket
    "medium_mv",   # medium market cap bucket
    "small_mv",    # small market cap bucket
)}

# Reporting-period code -> human-readable period name (stored in the DB).
period_mapping = {
    1: "一季度报告",       # Q1 report
    2: "半年度报告",       # half-year report
    3: "三季度报告",       # Q3 report
    4: "正式年报",         # official annual report
    9: "一季度业绩预告",   # Q1 earnings forecast
    10: "半年度业绩预告",  # half-year earnings forecast
    11: "三季度业绩预告",  # Q3 earnings forecast
    12: "年度业绩预告",    # annual earnings forecast
    5: "一季度业绩快报",   # Q1 earnings flash report
    6: "半年度业绩快报",   # half-year earnings flash report
    7: "三季度业绩快报",   # Q3 earnings flash report
}

@ -1,27 +0,0 @@
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from src.financial_reports.schemas import *
from src.financial_reports.service import *
from src.responses import response_list_response

financial_reports_router = APIRouter()


# POST /query — filtered, paginated financial-report search.
# Fix: handler names previously misspelled "financial_repoets_*"; the names
# are internal (routing uses the decorator path), so renaming is safe.
@financial_reports_router.post("/query")
async def financial_reports_query(request: FinancialReportQuery) -> JSONResponse:
    """Search endpoint: delegate to the query service and return its payload."""
    result = await financial_reports_query_service(request)
    return result


# POST /keyword-search — fuzzy match for keyword autocompletion.
@financial_reports_router.post("/keyword-search")
async def financial_reports_query_keyword_search(query: SearchKeywordQuery) -> list:
    """Fuzzy-search endpoint: match stocks by code / pinyin / Chinese name."""
    result = await user_keyword_search_service(query.searchKeyword)
    # Wrap the list in the shared response envelope.
    return response_list_response(data=result, message="联想词查询成功")

@ -1,19 +0,0 @@
from typing import List, Optional, Any, Dict
from pydantic import BaseModel, Field
# Request body for the financial-report search endpoint. Every optional field
# left as None means "do not filter on it" (see mappings / service layer).
class FinancialReportQuery(BaseModel):
pageNo: int = Field(..., ge=1, le=1000, description="页码必须大于等于1且小于等于1000")
pageSize: int = Field(..., ge=1, le=50, description="每页条数必须大于等于1且小于等于50")
# Optional filter parameters
year: Optional[int] = Field(None, ge=1, le=2100, description="时间选择必须为正整数或特定年份且小于等于2100或留空")
stockPoolCode: Optional[str] = Field(None, max_length=20, min_length=1, description="股票池限制最大长度为20限制最小长度为1或留空")
period: Optional[int] = Field(None, ge=1, le=50, description="财报周期必须大于等于1且小于等于50或留空")
revenueYoyType: Optional[int] = Field(None, ge=1, le=10, description="营业收入同比必须大于等于1且小于等于10或留空")
nIncomeYoyType: Optional[int] = Field(None, ge=1, le=10, description="净利润同比必须大于等于1且小于等于10或留空")
revenueYoyTop10: Optional[int] = Field(None, ge=1, le=10, description="营业收入行业前10%必须大于等于1且小于等于10或留空")
searchKeyword: Optional[str] = Field(None, min_length=1, max_length=30, description="搜索关键词可选长度在1到30字符之间")
# Request body for the keyword-autocomplete endpoint.
class SearchKeywordQuery(BaseModel):
searchKeyword: Optional[str] = Field(None, min_length=1, max_length=30, description="搜索关键词可选长度在1到30字符之间")

@ -1,200 +0,0 @@
from math import ceil
from fastapi import Request, HTTPException, BackgroundTasks
from tortoise.expressions import Subquery,Q
from tortoise.contrib.pydantic import pydantic_model_creator
from src.models.financial_reports import *
from src.models.stock_hu_shen300 import *
from src.models.stock_zhong_zheng_500 import *
from src.models.stock_guo_zheng_2000 import *
from src.models.stock_hu_shen_jing_a import *
from src.models.stock import *
from src.financial_reports.schemas import *
from src.utils.paginations import *
from src.financial_reports.mappings import *
from src.utils.identify_keyword_type import * # helper that classifies a keyword as Chinese / pinyin / digits
# Pydantic response model excluding the "created_at" field
FinancialReport_Pydantic = pydantic_model_creator(FinancialReport, exclude=("created_at",))
async def financial_reports_query_service(
request: FinancialReportQuery
) -> PaginationPydantic:
"""
Query financial reports using the selector fields and keyword, paginated.
"""
# Build the base query set over the FinancialReport model
query_set: QuerySet = FinancialReport.all()
try:
# If searchKeyword is empty, skip keyword filtering entirely
if not request.searchKeyword:
pass
else:
# Choose the lookup column based on the keyword's detected type
search_keyword = request.searchKeyword
keyword_type = identify_keyword_type_simple(search_keyword)
# Dispatch on the keyword type
if keyword_type == "pinyin":
# Reuse user_keyword_search_service to resolve pinyin to stocks
matching_stocks = await user_keyword_search_service(search_keyword)
# Extract the stock codes from the matches
matching_stock_codes = [stock["stock_code"] for stock in matching_stocks]
if matching_stock_codes:
# Filter reports by the matched stock codes
query_set = query_set.filter(stock_code__in=matching_stock_codes)
else:
# No matching codes -> force an empty result set
query_set = query_set.none()
elif keyword_type == "code":
# Fuzzy match on the stock code column
query_set = query_set.filter(stock_code__icontains=search_keyword)
elif keyword_type == "chinese":
# Fuzzy match on the stock name column
query_set = query_set.filter(stock_name__icontains=search_keyword)
else:
raise ValueError("无效的关键词类型")
# Year selector: must be present in year_mapping
if request.year:
mapped_year = year_mapping.get(request.year)
if mapped_year is not None:
query_set = query_set.filter(year=mapped_year)
else:
raise ValueError("无效的年份参数")
# Stock-pool selector
if request.stockPoolCode is None:
# No pool code supplied -> keep all records
pass
else:
# Look up the mapped value; None means the code is unknown
mapped_stock_pool = stock_pool_mapping.get(request.stockPoolCode)
# Unknown stockPoolCode -> reject the request
if mapped_stock_pool is None:
# Invalid stock-pool parameter
raise ValueError("无效的股票池参数")
# Otherwise filter by the corresponding index constituents
elif mapped_stock_pool == "000300.SH":
subquery = await StockHuShen300.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool == "000905.SH":
subquery = await StockZhongZheng500.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool == "399303.SZ":
subquery = await StockGuoZheng2000.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool in ["large_mv", "medium_mv", "small_mv"]:
# Fetch every stock code that has a total market value
subquery = await StockHuShenJingA.filter(total_market_value__isnull=False).values_list('code', 'total_market_value')
# Sort in memory rather than via a DataFrame
stock_list = sorted(subquery, key=lambda x: x[1], reverse=True) # descending by total_market_value
total_count = len(stock_list)
if mapped_stock_pool == "large_mv":
# Top 30% by market value
limit_count = ceil(total_count * 0.3)
selected_stocks = [stock[0] for stock in stock_list[:limit_count]]
elif mapped_stock_pool == "medium_mv":
# Middle 40% by market value
start_offset = ceil(total_count * 0.3)
limit_count = ceil(total_count * 0.4)
selected_stocks = [stock[0] for stock in stock_list[start_offset:start_offset + limit_count]]
elif mapped_stock_pool == "small_mv":
# Bottom 30% by market value
start_offset = ceil(total_count * 0.7)
selected_stocks = [stock[0] for stock in stock_list[start_offset:]]
# Filter the FinancialReport table by the selected codes
query_set = query_set.filter(stock_code__in=selected_stocks)
# Reporting-period selector
if request.period is not None:
mapped_period = period_mapping.get(request.period)
if mapped_period is not None:
# Valid period -> filter by its mapped name
query_set = query_set.filter(period=mapped_period)
else:
# No mapping found -> invalid period parameter
raise ValueError("无效的财报周期参数")
# If every growth filter is empty, sort by date then created_at, descending
if not request.revenueYoyType and not request.nIncomeYoyType and not request.revenueYoyTop10:
query_set = query_set.order_by("-date", "-created_at")
# revenueYoyType == 1 -> keep rows with revenue YoY growth above 10.0
if request.revenueYoyType == 1:
query_set = query_set.filter(revenue_yoy__gt=10.0).order_by("-revenue_yoy")
# nIncomeYoyType == 1 -> keep rows with net-income YoY growth above 10.0
elif request.nIncomeYoyType == 1:
query_set = query_set.filter(nincome_yoy__gt=10.0).order_by("-nincome_yoy")
# revenueYoyTop10 == 1 -> keep the top 10% of rows by revenue YoY growth
elif request.revenueYoyTop10 == 1:
# Compute how many rows make up the top 10%
total_count = await FinancialReport.all().count() # total record count
limit_count = ceil(total_count * 0.1)
# Order by revenue_yoy descending and keep the top slice
query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy")
# Paginate the result
params = Params(page=request.pageNo, size=request.pageSize) # Params carries the page number and page size
result = await pagination(pydantic_model=FinancialReport_Pydantic,
query_set=query_set,
params=params)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail="内部服务器错误")
async def user_keyword_search_service(searchKeyword: Optional[str] = None) -> list:
"""
Fuzzy-match stock info in real time from the user's keyword; supports stock
codes, pinyin initials and Chinese names.
:param searchKeyword: user keyword (stock code, pinyin initials or Chinese)
:return: list of matching stocks
"""
if not searchKeyword:
# Empty keyword -> empty result
return []
# Classify the keyword (pinyin initials, stock code or Chinese characters)
from src.utils.identify_keyword_type import identify_keyword_type_simple
keyword_type = identify_keyword_type_simple(searchKeyword)
# Build the query set over the Stock model
query_set: QuerySet = Stock.all()
if keyword_type == "pinyin":
# Pinyin initials -> match stock_pinyin
query_set = query_set.filter(stock_pinyin__icontains=searchKeyword)
elif keyword_type == "code":
# Stock code -> match stock_code
query_set = query_set.filter(stock_code__icontains=searchKeyword)
elif keyword_type == "chinese":
# Chinese characters -> match stock_name
query_set = query_set.filter(stock_name__icontains=searchKeyword)
else:
# Unrecognised keyword type -> empty result
return []
# Fetch the matching stocks. NOTE(review): the original comment promised a
# cap (~20 rows) but no .limit() is applied — confirm whether one is wanted.
matching_stocks = await query_set.values("stock_name","stock_code")
return matching_stocks

@ -1,105 +1,54 @@
# NOTE(review): diff dump — the old revision of main.py (with the financial
# report fetcher) is followed by the new, slimmed-down revision. Indentation
# was stripped by extraction; code tokens preserved byte-for-byte.
import sys
import os
# Add the project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import sentry_sdk
import uvicorn
import pandas as pd
import asyncio
from fastapi import FastAPI
from datetime import datetime
from starlette.middleware.cors import CORSMiddleware
from src.exceptions import register_exception_handler
from src.tortoises import register_tortoise_orm
from src.xtdata.router import router as xtdata_router
from src.backtest.router import router as backtest_router
from src.combination.router import router as combine_router
from src.models.financial_reports import FinancialReport
from src.utils.update_financial_reports_spider import combined_search_and_list
from src.financial_reports.router import financial_reports_router
from src.utils.generate_pinyin_abbreviation import generate_pinyin_abbreviation
from src.composite.router import composite_router
from xtquant import xtdata
from src.settings.config import app_configs, settings
import adata
import akshare as ak
app = FastAPI(**app_configs)
register_tortoise_orm(app)
register_exception_handler(app)
app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"])
app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"])
app.include_router(combine_router, prefix="/combine", tags=["组合接口"])
app.include_router(financial_reports_router, prefix="/financial-reports", tags=["财报接口"])
app.include_router(composite_router, prefix="/composite", tags=["vacode组合接口"])
if settings.ENVIRONMENT.is_deployed:
sentry_sdk.init(
dsn=settings.SENTRY_DSN,
environment=settings.ENVIRONMENT,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
)
@app.get("/")
async def root():
return {"message": "Hello, FastAPI!"}
# Periodic check-and-fetch task for financial reports
async def run_data_fetcher():
print("财报抓取启动")
while True:
# Get the timestamp of the newest record in the database
latest_record = await FinancialReport .all().order_by("-created_at").first()
latest_db_date = latest_record.created_at if latest_record else pd.Timestamp("1970-01-01")
# Strip the timezone from the latest DB date so the comparison works
latest_db_date = latest_db_date.replace(tzinfo=None)
# Current time (also timezone-naive)
current_time = pd.Timestamp(datetime.now()).replace(tzinfo=None)
# Fetch only if the newest record is more than 12 hours old
if (current_time - latest_db_date).total_seconds() > 43200:
print("启动财报数据抓取...")
await combined_search_and_list()
else:
print("未满足条件,跳过抓取任务。")
# Sleep 12 hours (43200 s), then check again
await asyncio.sleep(43200)
async def test_liang_hua_ku():
print("量化库测试函数启动")
# Kick off the background tasks immediately on startup
@app.on_event("startup")
async def lifespan():
# Schedule tasks in the background so startup is not blocked
asyncio.create_task(test_liang_hua_ku())
asyncio.create_task(run_data_fetcher())
if __name__ == "__main__":
uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True)
# --- new revision (other side of the diff): fetcher and extra routers removed ---
import sentry_sdk
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from src.exceptions import register_exception_handler
from src.tortoises import register_tortoise_orm
from src.xtdata.router import router as xtdata_router
from src.backtest.router import router as backtest_router
from src.combination.router import router as combine_router
from xtquant import xtdata
from src.settings.config import app_configs, settings
app = FastAPI(**app_configs)
register_tortoise_orm(app)
register_exception_handler(app)
app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"])
app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"])
app.include_router(combine_router, prefix="/combine", tags=["组合接口"])
if settings.ENVIRONMENT.is_deployed:
sentry_sdk.init(
dsn=settings.SENTRY_DSN,
environment=settings.ENVIRONMENT,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
)
@app.get("/")
async def root():
return {"message": "Hello, FastAPI!"}
if __name__ == "__main__":
uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True)

@ -1,32 +0,0 @@
from tortoise import fields, Model
class FinancialReport (Model):
"""
Financial-report row: one announcement plus price reactions around it.
"""
id = fields.IntField(pk=True, description="自增主键") # auto-increment primary key
date = fields.DateField(description="日期") # announcement date
period = fields.CharField(max_length=50, description="报告期") # reporting period (human-readable, see period_mapping)
year = fields.IntField(description="年份") # fiscal year
stock_code = fields.CharField(max_length=10, description="股票代码") # stock code
stock_name = fields.CharField(max_length=50, description="股票名称") # stock name
industry_code = fields.CharField(max_length=10, null=True, description="行业代码") # industry code
industry_name = fields.CharField(max_length=50, null=True, description="行业名称") # industry name
pdf_name = fields.CharField(max_length=100, null=True, description="PDF 报告名称") # PDF report name
pdf_url = fields.CharField(max_length=255, null=True, description="PDF 报告链接") # PDF report URL
revenue = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="收入") # revenue
revenue_yoy = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="收入同比增长") # revenue YoY growth
pr_of_today = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="今日涨跌幅") # price change on announcement day
pr_of_after_1_day = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="次日涨跌幅") # price change next day
pr_of_after_1_week = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一周后涨跌幅") # price change one week later
pr_of_after_1_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一月后涨跌幅") # price change one month later
pr_of_after_6_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="六个月后涨跌幅") # price change six months later
pr_of_that_year = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="年初至今涨跌幅") # year-to-date price change
nincome_yoy_lower_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长下限") # net-income YoY growth, lower bound
nincome_yoy_upper_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长上限") # net-income YoY growth, upper bound
nincome = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润") # net income
nincome_yoy = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润同比增长率") # net-income YoY growth rate
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间") # row creation time, set automatically
class Meta:
table = "financial_reports"

@ -1,12 +0,0 @@
from tortoise import fields, Model
class StockGuoZheng2000(Model):
"""
CNI 2000 (国证2000) index constituent data.
"""
code = fields.CharField(max_length=10, description="股票代码") # stock code
name = fields.CharField(max_length=50, description="股票名称") # stock name
included_date = fields.CharField(max_length=20, description="纳入日期") # date the stock joined the index
class Meta:
table = "stock_guo_zheng_2000"

@ -1,12 +0,0 @@
from tortoise import fields, Model
class StockHuShen300(Model):
"""
CSI 300 (沪深300) index constituent data.
"""
symbol = fields.CharField(max_length=10, description="股票代码") # stock symbol; NOTE(review): appears to duplicate `code` — confirm which the data feed fills
code = fields.CharField(max_length=10, description="股票代码") # stock code
name = fields.CharField(max_length=50, description="股票名称") # stock name
class Meta:
table = "stock_hu_shen_300"

@ -1,36 +0,0 @@
from tortoise import fields, Model
"""
TODO: the daily-refresh spider for this table is still to be written.
"""
class StockHuShenJingA(Model):
"""
Real-time quote data for Shanghai/Shenzhen/Beijing A shares.
"""
seq = fields.IntField(description="序号") # row sequence number
code = fields.CharField(max_length=10, description="股票代码") # stock code
name = fields.CharField(max_length=50, description="股票名称") # stock name
latest_price = fields.FloatField(null=True, description="最新价格") # latest price
change_percent = fields.FloatField(null=True, description="涨跌幅") # price change percent
change_amount = fields.FloatField(null=True, description="涨跌额") # price change amount
volume = fields.FloatField(null=True, description="成交量") # trading volume
turnover = fields.FloatField(null=True, description="成交额") # turnover value
amplitude = fields.FloatField(null=True, description="振幅") # amplitude
high = fields.FloatField(null=True, description="最高价") # day high
low = fields.FloatField(null=True, description="最低价") # day low
open = fields.FloatField(null=True, description="今开价") # today's open
previous_close = fields.FloatField(null=True, description="昨收价") # previous close
volume_ratio = fields.FloatField(null=True, description="量比") # volume ratio
turnover_rate = fields.FloatField(null=True, description="换手率") # turnover rate
pe_dynamic = fields.FloatField(null=True, description="市盈率-动态") # P/E ratio (dynamic)
pb = fields.FloatField(null=True, description="市净率") # P/B ratio
total_market_value = fields.FloatField(null=True, description="总市值") # total market value (used for the large/medium/small_mv buckets)
circulating_market_value = fields.FloatField(null=True, description="流通市值") # circulating market value
rise_speed = fields.FloatField(null=True, description="涨速") # rise speed
five_minute_change = fields.FloatField(null=True, description="5分钟涨跌幅") # 5-minute change
sixty_day_change = fields.FloatField(null=True, description="60日涨跌幅") # 60-day change
year_to_date_change = fields.FloatField(null=True, description="年初至今涨跌幅") # year-to-date change
class Meta:
table = "stock_hu_shen_jing_a"

@ -1,12 +0,0 @@
from tortoise import fields, Model
class StockZhongZheng500(Model):
"""
CSI 500 (中证500) index constituent data.
"""
symbol = fields.CharField(max_length=10, description="股票代码") # stock symbol; NOTE(review): appears to duplicate `code` — confirm which the data feed fills
code = fields.CharField(max_length=10, description="股票代码") # stock code
name = fields.CharField(max_length=50, description="股票名称") # stock name
class Meta:
table = "stock_zhong_zheng_500"

@ -1,16 +0,0 @@
from tortoise import fields
from tortoise.models import Model
class UserStrategy(Model):
"""
Per-user multi-factor strategy combinations.
"""
id = fields.IntField(pk=True) # primary key
create_time = fields.DatetimeField(null=True) # creation time
user_id = fields.IntField(null=True) # owning user's id
strategy_request = fields.JSONField(null=False) # the strategy's request payload, stored as JSON
strategy_name = fields.CharField(max_length=30, null=True) # display name
deleted_stock = fields.BinaryField(null=True) # removed stocks; stored as a binary blob — format not visible here, confirm with the writer
class Meta:
table = "user_strategy"

Binary file not shown.

@ -1,46 +1,68 @@
# NOTE(review): diff dump — the generic BaseResponse implementation (old
# revision) is followed by the EntityResponse/ListResponse implementation
# (new revision). The two define response helpers with the same names but
# different return types (JSONResponse vs pydantic models).
import typing
from typing import Generic, Sequence, TypeVar
from fastapi import status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
T = TypeVar("T")
class BaseResponse(BaseModel, Generic[T]):
"""Generic response wrapper (status code, message, optional payload)."""
status_code: int
message: str
data: typing.Optional[T] = None
def json_response(self) -> JSONResponse:
"""Render this wrapper as a JSONResponse."""
return JSONResponse(
status_code=self.status_code,
content=jsonable_encoder(self.dict())
)
def create_response(data: T = None, status_code: int = 200, message: str = "Success") -> JSONResponse:
"""Generic response factory used by all helpers below."""
response = BaseResponse(status_code=status_code, message=message, data=data)
return response.json_response()
# Convenience helpers
def response_success(message: str = 'Success') -> JSONResponse:
return create_response(status_code=status.HTTP_200_OK, message=message)
def response_unauthorized(message: str = "用户认证失败") -> JSONResponse:
return create_response(status_code=status.HTTP_401_UNAUTHORIZED, message=message)
def response_entity_response(data: T, message: str = "Success") -> JSONResponse:
"""Single-entity response."""
return create_response(data=data, status_code=200, message=message)
def response_list_response(data: Sequence[T], message: str = "Success") -> JSONResponse:
"""List response."""
return create_response(data=data, status_code=200, message=message)
# --- other side of the diff: model-based variant ---
import typing
from typing import Generic, Sequence, TypeVar
from fastapi import status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
T = TypeVar("T")
class EntityResponse(BaseModel, Generic[T]):
"""Single-entity payload."""
status_code: int = 200
message: str = "Success"
data: T
class ListResponse(BaseModel, Generic[T]):
"""List payload."""
status_code: int = 200
message: str = "Success"
data: Sequence[T]
# Wrapper response pydantic models (paging variant kept for reference)
# class EntityPageResponse(BaseModel, Generic[T]):
# status_code: int
# message: str
# data: PageData
def response_entity_response(data, status_code=200, message="Success") -> EntityResponse:
"""Plain entity response."""
return EntityResponse(data=data, status_code=status_code, message=message)
# def response_page_response(data, status_code=200, message="Success") -> EntityPageResponse:
# """Plain paged response."""
# return EntityPageResponse(data=data, status_code=status_code, message=message)
def response_list_response(data, status_code=200, message="Success") -> ListResponse:
"""Plain list response."""
return ListResponse(data=data, status_code=status_code, message=message)
def response_success(message: str = 'Success',
headers: typing.Optional[typing.Dict[str, str]] = None) -> JSONResponse:
"""Success response with optional extra headers."""
return JSONResponse(
status_code=status.HTTP_200_OK,
headers=headers,
content=jsonable_encoder({
"status_code": status.HTTP_200_OK,
"message": message,
}))
def response_unauthorized() -> JSONResponse:
"""Not-authenticated response (401 with a Bearer challenge header)."""
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=jsonable_encoder({
"status_code": status.HTTP_401_UNAUTHORIZED,
"message": '用户认证失败',
}))

@ -1,52 +1,52 @@
# NOTE(review): diff dump — this settings module appears twice (both sides of
# the diff are identical). Settings are loaded from the environment / .env
# via pydantic-settings.
from typing import Any
from pydantic import RedisDsn, model_validator
from pydantic_settings import BaseSettings
from src.constants import Environment
class Config(BaseSettings):
# Database / cache connection strings
DATABASE_PREFIX: str = ""
DATABASE_URL: str
DATABASE_CREATE_URL: str
REDIS_URL: RedisDsn
SITE_DOMAIN: str = "myapp.com"
ENVIRONMENT: Environment = Environment.PRODUCTION
SENTRY_DSN: str | None = None
# CORS configuration
CORS_ORIGINS: list[str]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str]
APP_VERSION: str = "1"
APP_ROUTER_PREFIX: str = "/api/v1"
SHOULD_SEND_SMS: bool = False
@model_validator(mode="after")
def validate_sentry_non_local(self) -> "Config":
# Deployed environments must have Sentry configured
if self.ENVIRONMENT.is_deployed and not self.SENTRY_DSN:
raise ValueError("Sentry is not set")
return self
settings = Config()
# fastapi/applications.py
app_configs: dict[str, Any] = {
"title": "Wance QMT API",
"root_path": settings.APP_ROUTER_PREFIX,
"docs_url": "/api/docs",
}
# app_configs['debug'] = True
# if settings.ENVIRONMENT.is_deployed:
# app_configs["root_path"] = f"/v{settings.APP_VERSION}"
#
# if not settings.ENVIRONMENT.is_debug:
# app_configs["openapi_url"] = None # hide docs
# --- second copy (other side of the diff); content identical to the above ---
from typing import Any
from pydantic import RedisDsn, model_validator
from pydantic_settings import BaseSettings
from src.constants import Environment
class Config(BaseSettings):
DATABASE_PREFIX: str = ""
DATABASE_URL: str
DATABASE_CREATE_URL: str
REDIS_URL: RedisDsn
SITE_DOMAIN: str = "myapp.com"
ENVIRONMENT: Environment = Environment.PRODUCTION
SENTRY_DSN: str | None = None
CORS_ORIGINS: list[str]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str]
APP_VERSION: str = "1"
APP_ROUTER_PREFIX: str = "/api/v1"
SHOULD_SEND_SMS: bool = False
@model_validator(mode="after")
def validate_sentry_non_local(self) -> "Config":
if self.ENVIRONMENT.is_deployed and not self.SENTRY_DSN:
raise ValueError("Sentry is not set")
return self
settings = Config()
# fastapi/applications.py
app_configs: dict[str, Any] = {
"title": "Wance QMT API",
"root_path": settings.APP_ROUTER_PREFIX,
"docs_url": "/api/docs",
}
# app_configs['debug'] = True
# if settings.ENVIRONMENT.is_deployed:
# app_configs["root_path"] = f"/v{settings.APP_VERSION}"
#
# if not settings.ENVIRONMENT.is_debug:
# app_configs["openapi_url"] = None # hide docs

@ -1,193 +1,186 @@
import logging
from fastapi import FastAPI
from tortoise.contrib.fastapi import register_tortoise
from src.settings.config import settings
DATABASE_URL = str(settings.DATABASE_URL)
DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
# 定义不同日志级别的颜色
class ColoredFormatter(logging.Formatter):
COLORS = {
"DEBUG": "\033[94m", # 蓝色
"INFO": "\033[92m", # 绿色
"WARNING": "\033[93m", # 黄色
"ERROR": "\033[91m", # 红色
"CRITICAL": "\033[41m", # 红色背景
}
RESET = "\033[0m" # 重置颜色
def format(self, record):
color = self.COLORS.get(record.levelname, self.RESET)
log_message = super().format(record)
return f"{color}{log_message}{self.RESET}"
# 配置日志
logger_db_client = logging.getLogger("tortoise.db_client")
logger_db_client.setLevel(logging.DEBUG)
# 创建一个控制台处理程序
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# 创建并设置格式化器
formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
# 将处理程序添加到日志记录器中
logger_db_client.addHandler(ch)
# 定义模型
models = [
"src.models.test_table",
"src.models.stock",
"src.models.security_account",
"src.models.order",
"src.models.snowball",
"src.models.backtest",
"src.models.strategy",
"src.models.stock_details",
"src.models.transaction",
"src.models.position",
"src.models.trand_info",
"src.models.tran_observer_data",
"src.models.tran_orders",
"src.models.tran_return",
"src.models.tran_trade_info",
"src.models.back_observed_data",
"src.models.back_observed_data_detail",
"src.models.back_position",
"src.models.back_result_indicator",
"src.models.back_trand_info",
"src.models.tran_position",
"src.models.stock_bt_history",
"src.models.stock_history",
"src.models.stock_data_processing",
"src.models.wance_data_stock",
"src.models.wance_data_storage_backtest",
"src.models.user_combination_history",
"src.models.financial_reports",
"src.models.stock_hu_shen300",
"src.models.stock_zhong_zheng_500",
"src.models.stock_guo_zheng_2000",
"src.models.stock_hu_shen_jing_a",
"src.models.user_strategy"
]
aerich_models = models
aerich_models.append("aerich.models")
# Tortoise ORM 配置
TORTOISE_ORM = {
"connections": {"default": DATABASE_CREATE_URL,
},
"apps": {
"models": {
"models": aerich_models,
"default_connection": "default",
},
},
}
config = {
'connections': {
'default': DATABASE_URL
},
'apps': {
'models': {
'models': models,
'default_connection': 'default',
}
},
'use_tz': False,
'timezone': 'Asia/Shanghai'
}
def register_tortoise_orm(app: FastAPI):
    """Attach Tortoise ORM to *app* using the module-level ``config``.

    Schema generation is left off (migrations are handled by aerich) and
    the ORM's DB exception handlers are installed on the application.
    """
    orm_options = {
        "config": config,
        "generate_schemas": False,
        "add_exception_handlers": True,
    }
    register_tortoise(app, **orm_options)
"""
**********
以下内容对原有无改色配置文件进行注释上述内容为修改后日志为蓝色
**********
"""
# import logging
#
# from fastapi import FastAPI
# from tortoise.contrib.fastapi import register_tortoise
#
# from src.settings.config import settings
#
# DATABASE_URL = str(settings.DATABASE_URL)
# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
#
# # will print debug sql
# logging.basicConfig()
# logger_db_client = logging.getLogger("tortoise.db_client")
# logger_db_client.setLevel(logging.DEBUG)
#
# models = [
# "src.models.test_table",
# "src.models.stock",
# "src.models.security_account",
# "src.models.order",
# "src.models.snowball",
# "src.models.backtest",
# "src.models.strategy",
# "src.models.stock_details",
# "src.models.transaction",
# ]
#
# aerich_models = models
# aerich_models.append("aerich.models")
#
# TORTOISE_ORM = {
#
# "connections": {"default": DATABASE_CREATE_URL},
# "apps": {
# "models": {
# "models": aerich_models,
# "default_connection": "default",
# },
# },
# }
#
# config = {
# 'connections': {
# # Using a DB_URL string
# 'default': DATABASE_URL
# },
# 'apps': {
# 'models': {
# 'models': models,
# 'default_connection': 'default',
# }
# },
# 'use_tz': False,
# 'timezone': 'Asia/Shanghai'
# }
#
#
# def register_tortoise_orm(app: FastAPI):
# register_tortoise(
# app,
# config=config,
# # db_url=DATABASE_URL,
# # db_url="mysql://adams:adams@127.0.0.1:3306/adams",
# # modules={"models": models},
# generate_schemas=False,
# add_exception_handlers=True,
# )
import logging
from fastapi import FastAPI
from tortoise.contrib.fastapi import register_tortoise
from src.settings.config import settings
# Connection URLs from settings: DATABASE_URL for runtime queries,
# DATABASE_CREATE_URL for the DDL-capable connection used by migrations.
DATABASE_URL = str(settings.DATABASE_URL)
DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
# Per-level ANSI colors for console log output.
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps each rendered record in an ANSI color escape
    chosen by the record's level name, resetting the color afterwards."""

    # Level name -> ANSI escape sequence.
    COLORS = {
        "DEBUG": "\033[94m",     # blue
        "INFO": "\033[92m",      # green
        "WARNING": "\033[93m",   # yellow
        "ERROR": "\033[91m",     # red
        "CRITICAL": "\033[41m",  # red background
    }
    RESET = "\033[0m"  # return terminal to default color

    def format(self, record):
        """Render *record* with the base formatter, wrapped in color codes.

        Unknown level names fall back to RESET (i.e. no visible color).
        """
        body = super().format(record)
        prefix = self.COLORS.get(record.levelname, self.RESET)
        return "".join((prefix, body, self.RESET))
# Logging setup: route Tortoise's "tortoise.db_client" logger (emits the
# generated SQL) to the console with colored output.
logger_db_client = logging.getLogger("tortoise.db_client")
logger_db_client.setLevel(logging.DEBUG)
# Console handler at DEBUG so every SQL statement is printed.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# Create and attach the color-aware formatter defined above.
formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
# Register the handler on the logger.
logger_db_client.addHandler(ch)
# Model modules registered with Tortoise ORM at runtime.
models = [
    "src.models.test_table",
    "src.models.stock",
    "src.models.security_account",
    "src.models.order",
    "src.models.snowball",
    "src.models.backtest",
    "src.models.strategy",
    "src.models.stock_details",
    "src.models.transaction",
    "src.models.position",
    "src.models.trand_info",
    "src.models.tran_observer_data",
    "src.models.tran_orders",
    "src.models.tran_return",
    "src.models.tran_trade_info",
    "src.models.back_observed_data",
    "src.models.back_observed_data_detail",
    "src.models.back_position",
    "src.models.back_result_indicator",
    "src.models.back_trand_info",
    "src.models.tran_position",
    "src.models.stock_bt_history",
    "src.models.stock_history",
    "src.models.stock_data_processing",
    "src.models.wance_data_stock",
    "src.models.wance_data_storage_backtest",
    "src.models.user_combination_history",
]
# Migration (aerich) config additionally needs aerich's own bookkeeping model.
# BUGFIX: build a NEW list here. The previous
#   aerich_models = models; aerich_models.append("aerich.models")
# aliased the same list object, so "aerich.models" leaked into the runtime
# `config` below as well.
aerich_models = [*models, "aerich.models"]
# Tortoise ORM configuration consumed by aerich for migrations: it includes
# the aerich bookkeeping model and uses the DDL-capable connection URL.
TORTOISE_ORM = {
    "connections": {"default": DATABASE_CREATE_URL,
                    },
    "apps": {
        "models": {
            "models": aerich_models,
            "default_connection": "default",
        },
    },
}
# Runtime configuration used by the FastAPI application itself.
config = {
    'connections': {
        'default': DATABASE_URL
    },
    'apps': {
        'models': {
            'models': models,
            'default_connection': 'default',
        }
    },
    # Store naive datetimes, interpreted in the Asia/Shanghai timezone.
    'use_tz': False,
    'timezone': 'Asia/Shanghai'
}
def register_tortoise_orm(app: FastAPI):
    """Attach Tortoise ORM to *app* using the module-level ``config``.

    Schema generation is left off (migrations are handled by aerich) and
    the ORM's DB exception handlers are installed on the application.
    """
    orm_options = {
        "config": config,
        "generate_schemas": False,
        "add_exception_handlers": True,
    }
    register_tortoise(app, **orm_options)
"""
**********
以下内容对原有无改色配置文件进行注释上述内容为修改后日志为蓝色
**********
"""
# import logging
#
# from fastapi import FastAPI
# from tortoise.contrib.fastapi import register_tortoise
#
# from src.settings.config import settings
#
# DATABASE_URL = str(settings.DATABASE_URL)
# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
#
# # will print debug sql
# logging.basicConfig()
# logger_db_client = logging.getLogger("tortoise.db_client")
# logger_db_client.setLevel(logging.DEBUG)
#
# models = [
# "src.models.test_table",
# "src.models.stock",
# "src.models.security_account",
# "src.models.order",
# "src.models.snowball",
# "src.models.backtest",
# "src.models.strategy",
# "src.models.stock_details",
# "src.models.transaction",
# ]
#
# aerich_models = models
# aerich_models.append("aerich.models")
#
# TORTOISE_ORM = {
#
# "connections": {"default": DATABASE_CREATE_URL},
# "apps": {
# "models": {
# "models": aerich_models,
# "default_connection": "default",
# },
# },
# }
#
# config = {
# 'connections': {
# # Using a DB_URL string
# 'default': DATABASE_URL
# },
# 'apps': {
# 'models': {
# 'models': models,
# 'default_connection': 'default',
# }
# },
# 'use_tz': False,
# 'timezone': 'Asia/Shanghai'
# }
#
#
# def register_tortoise_orm(app: FastAPI):
# register_tortoise(
# app,
# config=config,
# # db_url=DATABASE_URL,
# # db_url="mysql://adams:adams@127.0.0.1:3306/adams",
# # modules={"models": models},
# generate_schemas=False,
# add_exception_handlers=True,
# )

@ -1,14 +0,0 @@
from pypinyin import lazy_pinyin, Style
from src.models.stock import Stock
async def generate_pinyin_abbreviation():
    """Backfill the ``stock_pinyin`` column for stocks that lack it.

    For every stock whose pinyin field is NULL, build the lowercase
    first-letter pinyin abbreviation of its name and persist it.
    (Candidate for extraction into a generic utility.)
    """
    pending = await Stock.filter(stock_pinyin__isnull=True)
    for stock in pending:
        name = stock.stock_name
        if not name:  # skip records with no name to transliterate
            continue
        initials = lazy_pinyin(name, style=Style.FIRST_LETTER)
        stock.stock_pinyin = "".join(initials).lower()
        await stock.save()
    print("拼音更新完成")

@ -1,19 +0,0 @@
"""
判断关键词类的工具类选中第一个字符做判断
"""
def identify_keyword_type_simple(keyword: str) -> str:
    """Classify a search keyword by inspecting its first character.

    Returns:
        "chinese" - first char is a CJK Unified Ideograph,
        "pinyin"  - first char is an alphabetic letter,
        "code"    - first char is a digit (e.g. a stock code),
        "unknown" - empty string or any other character.
    """
    if not keyword:
        return "unknown"  # empty input cannot be classified
    first_char = keyword[0]
    # The CJK Unified Ideographs block is U+4E00..U+9FFF. The previous upper
    # bound U+9FA5 (Unicode 3.0) missed later additions (U+9FA6..U+9FFF),
    # which str.isalpha() would then misclassify as "pinyin".
    if '\u4e00' <= first_char <= '\u9fff':
        return "chinese"
    if first_char.isalpha():
        return "pinyin"
    if first_char.isdigit():
        return "code"
    return "unknown"

@ -1,66 +1,74 @@
from typing import Type, TypeVar, Generic, Sequence, Optional, Callable, Awaitable
import math
from fastapi import Query
from pydantic import BaseModel
from tortoise.queryset import QuerySet
from tortoise.models import Model
# T: the pydantic schema used to serialize rows; M: the Tortoise model queried.
T = TypeVar('T', bound=BaseModel)
M = TypeVar('M', bound=Model)
class PaginationPydantic(BaseModel, Generic[T]):
    """Pagination response model: one page of serialized rows plus metadata."""
    status_code: int = 200  # HTTP-style status echoed in the body
    message: str = "Success"
    total: int        # total rows matching the query
    page: int         # 1-based page number returned
    size: int         # requested page size
    total_pages: int  # ceil(total / size)
    data: Sequence[T] # serialized rows for this page
class Params(BaseModel):
    """Query parameters controlling pagination."""
    page: int = Query(1, ge=1, description="Page number")  # 1-based, default 1
    size: int = Query(10, gt=0, le=200, description="Page size")  # default 10, max 200
    order_by: Optional[str] = Query(None, max_length=32, description="Sort key")  # comma-separated field list, optional
async def pagination(
    pydantic_model: Type[T],
    query_set: QuerySet[M],
    params: Params,
    callback: Optional[Callable[[QuerySet[M]], Awaitable[QuerySet[M]]]] = None
) -> PaginationPydantic[T]:
    """Paginate *query_set* according to *params* and serialize the page.

    Counts the matching rows, validates the requested page number, applies
    optional ordering (comma-separated field names), slices the page, lets
    *callback* post-process the queryset if given, and returns the page
    wrapped in :class:`PaginationPydantic`.

    Raises:
        ValueError: when the requested page exceeds the total page count.
    """
    total = await query_set.count()
    # An empty result set still reports one (empty) page.
    total_pages = math.ceil(total / params.size) if total > 0 else 1
    if total > 0 and params.page > total_pages:
        raise ValueError("页数输入有误")
    qs = query_set
    # Ordering must be applied before the page is sliced.
    if params.order_by:
        qs = qs.order_by(*params.order_by.split(','))
    qs = qs.offset((params.page - 1) * params.size).limit(params.size)
    if callback:
        # Give the caller a chance to post-process the sliced queryset.
        qs = await callback(qs)
    data = await pydantic_model.from_queryset(qs)
    return PaginationPydantic[T](
        total=total,
        page=params.page,
        size=params.size,
        total_pages=total_pages,
        data=data,
    )
from typing import Generic, Optional, Sequence, TypeVar
import math
from fastapi import Query
from pydantic import BaseModel
from tortoise.queryset import QuerySet
# Row-schema type parameter for the generic pagination envelope below.
T = TypeVar("T")
class PaginationPydantic(BaseModel, Generic[T]):
    """Pagination response model: one page of serialized rows plus metadata."""
    status_code: int = 200  # HTTP-style status echoed in the body
    message: str = "Success"
    total: int        # total rows matching the query
    page: int         # 1-based page number returned
    size: int         # requested page size
    total_pages: int  # ceil(total / size)
    data: Sequence[T] # serialized rows for this page
class Params(BaseModel):
    """Query parameters controlling pagination."""
    # Default 1; must be >= 1.
    page: int = Query(1, ge=1, description="Page number")
    # Default 10; at most 200 per page.
    size: int = Query(10, gt=0, le=200, description="Page size")
    # Default None, i.e. sorting is optional.
    order_by: Optional[str] = Query(None, max_length=32, description="Sort key")
async def pagination(pydantic_model, query_set: QuerySet, params: Params, callback=None):
    """Paginated response."""
    """
    pydantic_model: Pydantic model
    query_set: QuerySet
    params: Params
    callback: if you want to do something for query_set,it will be useful.
    """
    page: int = params.page
    size: int = params.size
    order_by: Optional[str] = params.order_by  # None when no sort was requested
    total = await query_set.count()
    # Derive the total page count from the row count and page size.
    total_pages = math.ceil(total / size)
    if page > total_pages and total:  # skip the check for an empty query set (total == 0)
        raise ValueError("页数输入有误")
    # Apply ordering before slicing the page.
    if order_by:
        order_by = order_by.split(',')
        query_set = query_set.order_by(*order_by)
    # Pagination: offset = (page - 1) * page size.
    query_set = query_set.offset((page - 1) * size)
    query_set = query_set.limit(size)
    if callback:
        """对查询集操作"""
        # Give the caller a chance to post-process the sliced queryset.
        query_set = await callback(query_set)
        # query_set = await query_set
    data = await pydantic_model.from_queryset(query_set)
    return PaginationPydantic(**{
        "total": total,
        "page": page,
        "size": size,
        "total_pages": total_pages,
        "data": data,
    })

@ -1,202 +0,0 @@
"""
TO DO待完善全体数据更新逻辑
"""
import time
import requests
import pandas as pd
from fastapi import APIRouter
from src.models.financial_reports import FinancialReport # 确保路径正确
async def combined_search_and_list():
    """Incrementally sync financial-research reports from shidaotec.com.

    Pages through the remote API (newest first), caches rows until it reaches
    data older than the newest row already stored, bulk-inserts the missing
    rows into ``FinancialReport``, then removes duplicate
    (date, period, stock_code) rows via raw SQL.

    NOTE(review): the network calls use synchronous ``requests`` inside an
    async function, which blocks the event loop — confirm this is acceptable.
    """
    # Paging parameters for the remote API.
    pageNo = 1
    pageSize = 50
    selector_url = "https://www.shidaotec.com/api/research/getFinaResearchPageList"
    # Fetch the first page to learn the total record count upstream.
    response = requests.get(selector_url, params={
        "pageNo": pageNo,
        "pageSize": pageSize,
        "year": "2024",
        "type": ""
    })
    response.raise_for_status()
    selector_data = response.json()
    # Total records upstream vs. rows already stored locally.
    total_records = selector_data.get("data", {}).get("total", 0)
    db_record_count = await FinancialReport.all().count()
    remaining_insert_count = total_records - db_record_count
    print(f"总记录数: {total_records},数据库已有记录数: {db_record_count},需要插入: {remaining_insert_count}")
    # Nothing to do when the local table already holds every record.
    if remaining_insert_count <= 0:
        print("数据库已包含所有记录,无需插入新数据。")
        return
    # Newest record currently in the database (by date).
    latest_record = await FinancialReport.all().order_by("-date").first()
    latest_db_date = latest_record.date if latest_record else None
    # NOTE(review): pd.Timestamp(None) yields NaT; every later comparison with
    # NaT is False, so on an empty table nothing is inserted — confirm intended.
    latest_db_date = pd.Timestamp(latest_db_date)
    # Buffer for rows fetched from the API.
    all_data = []
    # Keep requesting pages until we hit data older than the local maximum.
    while True:
        try:
            # Fetch the current page.
            response = requests.get(selector_url, params={
                "pageNo": pageNo,
                "pageSize": pageSize,
                "year": "2024",
                "type": ""
            })
            response.raise_for_status()
            selector_data = response.json()
            fina_data_list = selector_data.get("data", {}).get("list", [])
            # Stop when the page comes back empty.
            if not fina_data_list:
                print("No more data available.")
                break
            # Cache this page's rows.
            for record in fina_data_list:
                record.pop("seq", None)  # drop the unused sequence field
                record = {k: (None if v == "-" else v) for k, v in record.items()}
                all_data.append(record)
            # Use the first row's date to decide whether to keep paging.
            first_record_date = pd.to_datetime(fina_data_list[0]["date"])
            if latest_db_date and first_record_date < latest_db_date:
                print(f"Data on page {pageNo} is older than the latest entry in the database. Stopping requests.")
                break  # everything from here on is already stored
            print(f"第 {pageNo} 页成功获取并缓存 {len(fina_data_list)} 条记录")
            pageNo += 1
            time.sleep(2)  # throttle between requests
        except requests.RequestException as e:
            print(f"Error fetching data for page {pageNo}: {e}")
            break
        except Exception as e:
            print(f"Unexpected error on page {pageNo}: {e}")
            break
    # Batch-load the DB rows that may collide with the fetched data.
    existing_records = await FinancialReport.filter(
        date__in=[record["date"] for record in all_data],
        period__in=[record["period"] for record in all_data],
        stock_code__in=[record["stockCode"] for record in all_data]
    ).all()
    # Put the existing rows in a set for O(1) duplicate checks below.
    # NOTE(review): this compares DB field values (e.g. date objects/floats)
    # against the API's raw strings, so matches may never occur — the raw-SQL
    # dedup at the end appears to compensate; verify.
    existing_data_set = {
        (record.date, record.period, record.stock_code, record.stock_name,
         record.industry_code, record.industry_name, record.pdf_name,
         record.pdf_url, record.revenue, record.revenue_yoy,
         record.pr_of_today, record.pr_of_after_1_day,
         record.pr_of_after_1_week, record.pr_of_after_1_month,
         record.pr_of_after_6_month, record.pr_of_that_year,
         record.nincome_yoy_lower_limit, record.nincome_yoy_upper_limit,
         record.nincome, record.nincome_yoy)
        for record in existing_records
    }
    # Walk the cached rows and build the list to insert.
    success_count = 0
    records_to_insert = []
    for record in all_data[:]:  # iterate a slice copy to avoid mutating the list mid-loop
        record_date = pd.to_datetime(record["date"])
        # Skip rows that already exist in the database.
        if (record["date"], record["period"], record["stockCode"], record["stockName"],
            record["industryCode"], record["industryName"], record["pdfName"], record["pdfUrl"],
            record["revenue"], record["revenueYoy"], record["prOfToday"], record["prOfAfter1Day"],
            record["prOfAfter1Week"], record["prOfAfter1Month"], record["prOfAfter6Month"],
            record["prOfThatYear"], record["nincomeYoyLowerLimit"], record["nincomeYoyUpperLimit"],
            record["nincome"], record["nincomeYoy"]) in existing_data_set:
            print(f"Skipping existing record: {record}")  # log the skipped row
            continue  # already stored
        # Log the row about to be inserted.
        print(f"Inserting record: {record}")
        if record_date >= latest_db_date:
            # Build the FinancialReport object and queue it for bulk insert.
            try:
                records_to_insert.append(
                    FinancialReport(
                        date=record["date"],
                        period=record["period"],
                        year=int(record["year"]),
                        stock_code=record["stockCode"],
                        stock_name=record["stockName"],
                        industry_code=record["industryCode"] if record["industryCode"] else None,
                        industry_name=record["industryName"] if record["industryName"] else None,
                        pdf_name=record["pdfName"] if record["pdfName"] else None,
                        pdf_url=record["pdfUrl"] if record["pdfUrl"] else None,
                        revenue=float(record["revenue"]) if record["revenue"] else None,
                        revenue_yoy=float(record["revenueYoy"]) if record["revenueYoy"] else None,
                        pr_of_today=float(record["prOfToday"]) if record["prOfToday"] else None,
                        pr_of_after_1_day=float(record["prOfAfter1Day"]) if record["prOfAfter1Day"] else None,
                        pr_of_after_1_week=float(record["prOfAfter1Week"]) if record["prOfAfter1Week"] else None,
                        pr_of_after_1_month=float(record["prOfAfter1Month"]) if record["prOfAfter1Month"] else None,
                        pr_of_after_6_month=float(record["prOfAfter6Month"]) if record["prOfAfter6Month"] else None,
                        pr_of_that_year=float(record["prOfThatYear"]) if record["prOfThatYear"] else None,
                        nincome_yoy_lower_limit=record["nincomeYoyLowerLimit"] if record["nincomeYoyLowerLimit"] else None,
                        nincome_yoy_upper_limit=record["nincomeYoyUpperLimit"] if record["nincomeYoyUpperLimit"] else None,
                        nincome=float(record["nincome"]) if record["nincome"] else None,
                        nincome_yoy=float(record["nincomeYoy"]) if record["nincomeYoy"] else None
                    )
                )
                success_count += 1
            except Exception as e:
                print(f"Error processing record {record}: {e}")  # row-level failure; keep going
        # Stop once the required number of rows has been queued.
        if success_count >= remaining_insert_count:
            print("this riqi:" + str(latest_db_date))
            break
    # Bulk insert the queued rows.
    if records_to_insert:
        try:
            await FinancialReport.bulk_create(records_to_insert)
            print(f"成功插入 {len(records_to_insert)} 条记录")
        except Exception as e:
            print(f"Error during bulk insert: {e}")
    # Deduplication pass.
    try:
        # Find duplicates keyed on (date, period, stock_code), keeping the max id.
        duplicates = await FinancialReport.raw("""
            SELECT id FROM financial_reports WHERE id NOT IN (
                SELECT id FROM (
                    SELECT MAX(id) AS id
                    FROM financial_reports
                    GROUP BY date, period, stock_code
                ) AS temp
            )
        """)
        duplicate_ids = [record.id for record in duplicates]
        if duplicate_ids:
            # Delete the older duplicate rows.
            await FinancialReport.filter(id__in=duplicate_ids).delete()
            print(f"删除了 {len(duplicate_ids)} 条重复记录")
        else:
            print("未发现重复记录")
    except Exception as e:
        print(f"去重过程中出错: {e}")
    return {"message": "数据获取、插入和去重已完成"}