加入了一个基础的时到财报爬虫待补全更新逻辑,加入了万策财报后端的接口待补全模糊查询逻辑,加入了沪深300,中证500,国证2000,沪深京A,财报表的model对象

This commit is contained in:
mxwj 2024-11-07 17:55:31 +08:00
parent b9cf4c8a64
commit cfa9ba98d7
22 changed files with 929 additions and 362 deletions

44
.env

@ -1,23 +1,23 @@
;使用本地docker数据库就将这段代码解开如何将IP换成本机IP或者localhost # 使用本地docker数据库就将这段代码解开如何将IP换成本机IP或者localhost
; DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade" # DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade"
; DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade" # DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade"
DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
DATABASE_CREATE_URL ="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" DATABASE_CREATE_URL ="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche"
REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380" REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380"
SITE_DOMAIN=127.0.0.1 SITE_DOMAIN=127.0.0.1
SECURE_COOKIES=false SECURE_COOKIES=false
ENVIRONMENT=LOCAL ENVIRONMENT=LOCAL
CORS_HEADERS=["*"] CORS_HEADERS=["*"]
CORS_ORIGINS=["http://localhost:3000"] CORS_ORIGINS=["http://localhost:3000"]
# postgres variables, must be the same as in DATABASE_URL # postgres variables, must be the same as in DATABASE_URL
POSTGRES_USER=app POSTGRES_USER=app
POSTGRES_PASSWORD=app POSTGRES_PASSWORD=app
POSTGRES_HOST=app_db POSTGRES_HOST=app_db
POSTGRES_PORT=5432 POSTGRES_PORT=5432
POSTGRES_DB=app POSTGRES_DB=app

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="Black"> <component name="Black">
<option name="sdkName" value="Python 3.7" /> <option name="sdkName" value="Python 3.7" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="ftrade_back" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (wance_data)" project-jdk-type="Python SDK" />
</project> </project>

6
.idea/vcs.xml Normal file

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@ -1,15 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$">
<orderEntry type="jdk" jdkName="ftrade_back" jdkType="Python SDK" /> <excludeFolder url="file://$MODULE_DIR$/venv" />
<orderEntry type="sourceFolder" forTests="false" /> </content>
</component> <orderEntry type="inheritedJdk" />
<component name="PackageRequirementsSettings"> <orderEntry type="sourceFolder" forTests="false" />
<option name="requirementsPath" value="" /> </component>
</component> <component name="PackageRequirementsSettings">
<component name="PyDocumentationSettings"> <option name="requirementsPath" value="" />
<option name="format" value="EPYTEXT" /> </component>
<option name="myDocStringFormat" value="Epytext" /> <component name="PyDocumentationSettings">
</component> <option name="format" value="EPYTEXT" />
<option name="myDocStringFormat" value="Epytext" />
</component>
</module> </module>

17
.vscode/launch.json vendored Normal file

@ -0,0 +1,17 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: FastAPI",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": [
"src.main:app",
"--reload",
"--port", "8012"
],
"jinja": true
}
]
}

Binary file not shown.

@ -0,0 +1,36 @@
"""
映射表 为空则返回none表示不筛选此字段
"""
# 时间选择映射表
year_mapping = {
"""
为空则返回none
"""
"2024": 2024,
}
# 股票池映射表
stock_pool_mapping = {
"000300.SH": "000300.SH",
"000905.SH": "000905.SH",
"399303.SZ": "399303.SZ",
"large_mv": "large_mv",
"medium_mv": "medium_mv",
"small_mv": "small_mv",
}
# 财报周期映射
period_mapping = {
1: "一季度报告", # 一季度报告
2: "半年度报告", # 半年度报告
3: "三季度报告", # 三季度报告
4: "正式年报", # 正式年报
9: "一季度业绩预告", # 一季度业绩预告
10: "半年度业绩预告", # 半年度业绩预告
11: "三季度业绩预告", # 三季度业绩预告
12: "年度业绩预告", # 年度业绩预告
5: "一季度业绩快报", # 一季度业绩快报
6: "半年度业绩快报", # 半年度业绩快报
7: "三季度业绩快报" # 三季度业绩快报
}

@ -0,0 +1,13 @@
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from src.financial_reports.schemas import *
from src.financial_reports.service import *
financial_reports_router = APIRouter()
@financial_reports_router.post("/query")
async def financial_repoets_query(request: FinancialReportQuery )-> JSONResponse:
result = await financial_reports_query_service(request)
return result

@ -0,0 +1,16 @@
from typing import List, Optional, Any, Dict
from pydantic import BaseModel, Field
class FinancialReportQuery(BaseModel):
pageNo: int = Field(..., ge=1, le=1000, description="页码必须大于等于1且小于等于1000")
pageSize: int = Field(..., ge=1, le=50, description="每页条数必须大于等于1且小于等于50")
# 可选参数
year: Optional[int] = Field(None, ge=1, le=2100, description="时间选择必须为正整数或特定年份且小于等于2100或留空")
stockPoolCode: Optional[str] = Field(None, max_length=20, min_length=1, description="股票池限制最大长度为20限制最小长度为1或留空")
period: Optional[int] = Field(None, ge=1, le=50, description="财报周期必须大于等于1且小于等于50或留空")
revenueYoyType: Optional[int] = Field(None, ge=1, le=10, description="营业收入同比必须大于等于1且小于等于10或留空")
nIncomeYoyType: Optional[int] = Field(None, ge=1, le=10, description="净利润同比必须大于等于1且小于等于10或留空")
revenueYoyTop10: Optional[int] = Field(None, ge=1, le=10, description="营业收入行业前10%必须大于等于1且小于等于10或留空")

@ -0,0 +1,134 @@
import json
import adata
from math import ceil
from typing import Any, Dict
from fastapi import Request, HTTPException, BackgroundTasks
from tortoise.expressions import Subquery,Q
from tortoise.contrib.pydantic import pydantic_model_creator
from src.models.financial_reports import *
from src.models.stock_hu_shen300 import *
from src.models.stock_zhong_zheng_500 import *
from src.models.stock_guo_zheng_2000 import *
from src.models.stock_hu_shen_jing_a import *
from src.financial_reports.schemas import *
from src.utils.paginations import *
from src.financial_reports.mappings import *
# 创建一个不包含 "created_at" 字段的 Pydantic 模型用于响应
FinancialReport_Pydantic = pydantic_model_creator(FinancialReport, exclude=("created_at",))
async def financial_reports_query_service(
request: FinancialReportQuery
) -> PaginationPydantic:
"""
根据选择器和关键词查询财报并进行分页
"""
# 构建查询集,这里假设财报数据模型为 FinancialReport
query_set: QuerySet = FinancialReport.all()
try:
# 年度映射
if request.year:
mapped_year = year_mapping.get(request.year)
if mapped_year is not None:
query_set = query_set.filter(year=mapped_year)
else:
raise ValueError("无效的年份参数")
# 股票池映射
if request.stockPoolCode is None:
# 未提供股票池代码,返回所有记录
pass
else:
# 获取映射值,如果不存在则返回 None
mapped_stock_pool = stock_pool_mapping.get(request.stockPoolCode)
# 检查 mapped_stock_pool 是否为 None即不在映射表中
if mapped_stock_pool is None:
# 如果 stockPoolCode 不在映射表中,抛出无效股票池参数错误
raise ValueError("无效的股票池参数")
# 如果存在有效的映射值,执行相应的过滤
elif mapped_stock_pool == "000300.SH":
subquery = await StockHuShen300.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool == "000905.SH":
subquery = await StockZhongZheng500.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool == "399303.SZ":
subquery = await StockGuoZheng2000.filter().values_list('code', flat=True)
query_set = query_set.filter(stock_code__in=subquery)
elif mapped_stock_pool in ["large_mv", "medium_mv", "small_mv"]:
# 先获取所有有总市值的关联的股票代码
subquery = await StockHuShenJingA.filter(total_market_value__isnull=False).values_list('code', 'total_market_value')
# 转换为 DataFrame 或者直接进行排序
stock_list = sorted(subquery, key=lambda x: x[1], reverse=True) # 按 total_market_value 降序排序
total_count = len(stock_list)
if mapped_stock_pool == "large_mv":
# 获取前 30% 的数据
limit_count = ceil(total_count * 0.3)
selected_stocks = [stock[0] for stock in stock_list[:limit_count]]
elif mapped_stock_pool == "medium_mv":
# 获取中间 40% 的数据
start_offset = ceil(total_count * 0.3)
limit_count = ceil(total_count * 0.4)
selected_stocks = [stock[0] for stock in stock_list[start_offset:start_offset + limit_count]]
elif mapped_stock_pool == "small_mv":
# 获取后 30% 的数据
start_offset = ceil(total_count * 0.7)
selected_stocks = [stock[0] for stock in stock_list[start_offset:]]
# 对 FinancialReport 表进行筛选
query_set = query_set.filter(stock_code__in=selected_stocks)
# 财报周期映射
if request.period is not None:
mapped_period = period_mapping.get(request.period)
if mapped_period is not None:
# 如果映射到有效的周期,则进行过滤
query_set = query_set.filter(period=mapped_period)
else:
# 如果找不到有效的映射,抛出错误
raise ValueError("无效的财报周期参数")
# 检查是否所有筛选条件都为空
if not request.revenueYoyType and not request.nIncomeYoyType and not request.revenueYoyTop10:
# 如果全部为空,则按 date 和 created_at 降序排序
query_set = query_set.order_by("-date", "-created_at")
# 筛选 revenueYoyType如果是 1 则筛选 revenue_yoy 大于 10.0 的记录
if request.revenueYoyType == 1:
query_set = query_set.filter(revenue_yoy__gt=10.0).order_by("-revenue_yoy")
# 筛选 nIncomeYoyType如果是 1 则筛选 nincome_yoy 大于 10.0 的记录
elif request.nIncomeYoyType == 1:
query_set = query_set.filter(nincome_yoy__gt=10.0).order_by("-nincome_yoy")
# 如果 revenueYoyTop10 为 1则筛选 revenue_yoy 前 10% 的记录
elif request.revenueYoyTop10 == 1:
# 计算前 10% 的数量
total_count = await FinancialReport.all().count() # 获取总记录数
limit_count = ceil(total_count * 0.1)
# 按 revenue_yoy 降序排列,获取前 10% 记录
query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy")
# 调用分页函数进行分页处理
params = Params(page=request.pageNo, size=request.pageSize) #parms作用是传页码和页面大小
result = await pagination(pydantic_model=FinancialReport_Pydantic,
query_set=query_set,
params=params)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail="内部服务器错误")

@ -1,60 +1,103 @@
import sys import sys
import os import os
# 添加项目的根目录到 sys.path # 添加项目的根目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import sentry_sdk import sentry_sdk
import uvicorn import uvicorn
from fastapi import FastAPI import pandas as pd
from starlette.middleware.cors import CORSMiddleware import asyncio
from fastapi import FastAPI
from src.exceptions import register_exception_handler from datetime import datetime
from src.tortoises import register_tortoise_orm
from src.xtdata.router import router as xtdata_router
from src.backtest.router import router as backtest_router from starlette.middleware.cors import CORSMiddleware
from src.combination.router import router as combine_router
from src.exceptions import register_exception_handler
from xtquant import xtdata from src.tortoises import register_tortoise_orm
from src.settings.config import app_configs, settings from src.xtdata.router import router as xtdata_router
from src.backtest.router import router as backtest_router
from src.combination.router import router as combine_router
from src.models.financial_reports import FinancialReport
from src.utils.update_financial_reports_spider import combined_search_and_list
app = FastAPI(**app_configs) from src.financial_reports.router import financial_reports_router
register_tortoise_orm(app) from xtquant import xtdata
from src.settings.config import app_configs, settings
register_exception_handler(app)
import adata
app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"]) import akshare as ak
app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"])
app.include_router(combine_router, prefix="/combine", tags=["组合接口"])
app = FastAPI(**app_configs)
if settings.ENVIRONMENT.is_deployed: register_tortoise_orm(app)
sentry_sdk.init(
dsn=settings.SENTRY_DSN, register_exception_handler(app)
environment=settings.ENVIRONMENT,
) app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"])
app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"])
app.add_middleware( app.include_router(combine_router, prefix="/combine", tags=["组合接口"])
CORSMiddleware, app.include_router(financial_reports_router, prefix="/financial-reports", tags=["财报接口"])
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX, if settings.ENVIRONMENT.is_deployed:
allow_credentials=True, sentry_sdk.init(
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), dsn=settings.SENTRY_DSN,
allow_headers=settings.CORS_HEADERS, environment=settings.ENVIRONMENT,
) )
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
@app.get("/") allow_origin_regex=settings.CORS_ORIGINS_REGEX,
async def root(): allow_credentials=True,
return {"message": "Hello, FastAPI!"} allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
)
if __name__ == "__main__":
uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True) @app.get("/")
async def root():
return {"message": "Hello, FastAPI!"}
# 定时检查和数据抓取函数
async def run_data_fetcher():
print("财报抓取启动")
while True:
# 获取数据库中最新记录的时间
latest_record = await FinancialReport .all().order_by("-created_at").first()
latest_db_date = latest_record.created_at if latest_record else pd.Timestamp("1970-01-01")
# 将最新数据库日期设为无时区,以便比较
latest_db_date = latest_db_date.replace(tzinfo=None)
# 当前时间
current_time = pd.Timestamp(datetime.now()).replace(tzinfo=None)
# 检查当前时间是否超过数据库最新记录时间 12 个小时
if (current_time - latest_db_date).total_seconds() > 43200:
print("启动财报数据抓取...")
await combined_search_and_list()
else:
print("未满足条件,跳过抓取任务。")
# 休眠 12 小时43200 秒),然后再次检查
await asyncio.sleep(43200)
async def test_liang_hua_ku():
print("量化库测试函数启动")
# 启动时立即运行检查任务
@app.on_event("startup")
async def lifespan():
# 启动时将任务放入后台,不阻塞启动流程
asyncio.create_task(test_liang_hua_ku())
asyncio.create_task(run_data_fetcher())
if __name__ == "__main__":
uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True)

@ -0,0 +1,32 @@
from tortoise import fields, Model
class FinancialReport (Model):
"""
财报数据
"""
id = fields.IntField(pk=True, description="自增主键") # 自增主键
date = fields.DateField(description="日期") # 日期
period = fields.CharField(max_length=50, description="报告期") # 报告期
year = fields.IntField(description="年份") # 年份
stock_code = fields.CharField(max_length=10, description="股票代码") # 股票代码
stock_name = fields.CharField(max_length=50, description="股票名称") # 股票名称
industry_code = fields.CharField(max_length=10, null=True, description="行业代码") # 行业代码
industry_name = fields.CharField(max_length=50, null=True, description="行业名称") # 行业名称
pdf_name = fields.CharField(max_length=100, null=True, description="PDF 报告名称") # PDF 报告名称
pdf_url = fields.CharField(max_length=255, null=True, description="PDF 报告链接") # PDF 报告链接
revenue = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="收入") # 收入
revenue_yoy = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="收入同比增长") # 收入同比增长
pr_of_today = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="今日涨跌幅") # 今日涨跌幅
pr_of_after_1_day = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="次日涨跌幅") # 次日涨跌幅
pr_of_after_1_week = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一周后涨跌幅") # 一周后涨跌幅
pr_of_after_1_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一月后涨跌幅") # 一月后涨跌幅
pr_of_after_6_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="六个月后涨跌幅") # 六个月后涨跌幅
pr_of_that_year = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="年初至今涨跌幅") # 年初至今涨跌幅
nincome_yoy_lower_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长下限") # 净利润同比增长下限
nincome_yoy_upper_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长上限") # 净利润同比增长上限
nincome = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润") # 净利润
nincome_yoy = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润同比增长率") # 净利润同比增长率
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间") # 创建时间,自动生成当前时间
class Meta:
table = "financial_reports"

@ -0,0 +1,12 @@
from tortoise import fields, Model
class StockGuoZheng2000(Model):
"""
国证2000成分股数据
"""
code = fields.CharField(max_length=10, description="股票代码") # 股票代码
name = fields.CharField(max_length=50, description="股票名称") # 股票名称
included_date = fields.CharField(max_length=20, description="纳入日期") # 纳入日期
class Meta:
table = "stock_guo_zheng_2000"

@ -0,0 +1,12 @@
from tortoise import fields, Model
class StockHuShen300(Model):
"""
沪深300成分股数据
"""
symbol = fields.CharField(max_length=10, description="股票代码") # 股票代码
code = fields.CharField(max_length=10, description="股票代码") # 股票代码
name = fields.CharField(max_length=50, description="股票名称") # 股票名称
class Meta:
table = "stock_hu_shen_300"

@ -0,0 +1,32 @@
from tortoise import fields, Model
class StockHuShenJingA(Model):
"""
沪深京 A 股实时行情数据
"""
seq = fields.IntField(description="序号") # 序号
code = fields.CharField(max_length=10, description="股票代码") # 股票代码
name = fields.CharField(max_length=50, description="股票名称") # 股票名称
latest_price = fields.FloatField(null=True, description="最新价格") # 最新价格
change_percent = fields.FloatField(null=True, description="涨跌幅") # 涨跌幅
change_amount = fields.FloatField(null=True, description="涨跌额") # 涨跌额
volume = fields.FloatField(null=True, description="成交量") # 成交量
turnover = fields.FloatField(null=True, description="成交额") # 成交额
amplitude = fields.FloatField(null=True, description="振幅") # 振幅
high = fields.FloatField(null=True, description="最高价") # 最高价
low = fields.FloatField(null=True, description="最低价") # 最低价
open = fields.FloatField(null=True, description="今开价") # 今开价
previous_close = fields.FloatField(null=True, description="昨收价") # 昨收价
volume_ratio = fields.FloatField(null=True, description="量比") # 量比
turnover_rate = fields.FloatField(null=True, description="换手率") # 换手率
pe_dynamic = fields.FloatField(null=True, description="市盈率-动态") # 市盈率-动态
pb = fields.FloatField(null=True, description="市净率") # 市净率
total_market_value = fields.FloatField(null=True, description="总市值") # 总市值
circulating_market_value = fields.FloatField(null=True, description="流通市值") # 流通市值
rise_speed = fields.FloatField(null=True, description="涨速") # 涨速
five_minute_change = fields.FloatField(null=True, description="5分钟涨跌幅") # 5分钟涨跌幅
sixty_day_change = fields.FloatField(null=True, description="60日涨跌幅") # 60日涨跌幅
year_to_date_change = fields.FloatField(null=True, description="年初至今涨跌幅") # 年初至今涨跌幅
class Meta:
table = "stock_hu_shen_jing_a"

@ -0,0 +1,12 @@
from tortoise import fields, Model
class StockZhongZheng500(Model):
"""
中证500成分股数据
"""
symbol = fields.CharField(max_length=10, description="股票代码") # 股票代码
code = fields.CharField(max_length=10, description="股票代码") # 股票代码
name = fields.CharField(max_length=50, description="股票名称") # 股票名称
class Meta:
table = "stock_zhong_zheng_500"

Binary file not shown.

@ -1,186 +1,192 @@
import logging import logging
from fastapi import FastAPI from fastapi import FastAPI
from tortoise.contrib.fastapi import register_tortoise from tortoise.contrib.fastapi import register_tortoise
from src.settings.config import settings from src.settings.config import settings
DATABASE_URL = str(settings.DATABASE_URL) DATABASE_URL = str(settings.DATABASE_URL)
DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
# 定义不同日志级别的颜色 # 定义不同日志级别的颜色
class ColoredFormatter(logging.Formatter): class ColoredFormatter(logging.Formatter):
COLORS = { COLORS = {
"DEBUG": "\033[94m", # 蓝色 "DEBUG": "\033[94m", # 蓝色
"INFO": "\033[92m", # 绿色 "INFO": "\033[92m", # 绿色
"WARNING": "\033[93m", # 黄色 "WARNING": "\033[93m", # 黄色
"ERROR": "\033[91m", # 红色 "ERROR": "\033[91m", # 红色
"CRITICAL": "\033[41m", # 红色背景 "CRITICAL": "\033[41m", # 红色背景
} }
RESET = "\033[0m" # 重置颜色 RESET = "\033[0m" # 重置颜色
def format(self, record): def format(self, record):
color = self.COLORS.get(record.levelname, self.RESET) color = self.COLORS.get(record.levelname, self.RESET)
log_message = super().format(record) log_message = super().format(record)
return f"{color}{log_message}{self.RESET}" return f"{color}{log_message}{self.RESET}"
# 配置日志 # 配置日志
logger_db_client = logging.getLogger("tortoise.db_client") logger_db_client = logging.getLogger("tortoise.db_client")
logger_db_client.setLevel(logging.DEBUG) logger_db_client.setLevel(logging.DEBUG)
# 创建一个控制台处理程序 # 创建一个控制台处理程序
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG) ch.setLevel(logging.DEBUG)
# 创建并设置格式化器 # 创建并设置格式化器
formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter) ch.setFormatter(formatter)
# 将处理程序添加到日志记录器中 # 将处理程序添加到日志记录器中
logger_db_client.addHandler(ch) logger_db_client.addHandler(ch)
# 定义模型 # 定义模型
models = [ models = [
"src.models.test_table", "src.models.test_table",
"src.models.stock", "src.models.stock",
"src.models.security_account", "src.models.security_account",
"src.models.order", "src.models.order",
"src.models.snowball", "src.models.snowball",
"src.models.backtest", "src.models.backtest",
"src.models.strategy", "src.models.strategy",
"src.models.stock_details", "src.models.stock_details",
"src.models.transaction", "src.models.transaction",
"src.models.position", "src.models.position",
"src.models.trand_info", "src.models.trand_info",
"src.models.tran_observer_data", "src.models.tran_observer_data",
"src.models.tran_orders", "src.models.tran_orders",
"src.models.tran_return", "src.models.tran_return",
"src.models.tran_trade_info", "src.models.tran_trade_info",
"src.models.back_observed_data", "src.models.back_observed_data",
"src.models.back_observed_data_detail", "src.models.back_observed_data_detail",
"src.models.back_position", "src.models.back_position",
"src.models.back_result_indicator", "src.models.back_result_indicator",
"src.models.back_trand_info", "src.models.back_trand_info",
"src.models.tran_position", "src.models.tran_position",
"src.models.stock_bt_history", "src.models.stock_bt_history",
"src.models.stock_history", "src.models.stock_history",
"src.models.stock_data_processing", "src.models.stock_data_processing",
"src.models.wance_data_stock", "src.models.wance_data_stock",
"src.models.wance_data_storage_backtest", "src.models.wance_data_storage_backtest",
"src.models.user_combination_history" "src.models.user_combination_history",
"src.models.financial_reports",
] "src.models.stock_hu_shen300",
"src.models.stock_zhong_zheng_500",
aerich_models = models "src.models.stock_guo_zheng_2000",
aerich_models.append("aerich.models") "src.models.stock_hu_shen_jing_a"
# Tortoise ORM 配置
TORTOISE_ORM = { ]
"connections": {"default": DATABASE_CREATE_URL,
}, aerich_models = models
"apps": { aerich_models.append("aerich.models")
"models": {
"models": aerich_models, # Tortoise ORM 配置
"default_connection": "default", TORTOISE_ORM = {
}, "connections": {"default": DATABASE_CREATE_URL,
}, },
} "apps": {
"models": {
config = { "models": aerich_models,
'connections': { "default_connection": "default",
'default': DATABASE_URL },
}, },
'apps': { }
'models': {
'models': models, config = {
'default_connection': 'default', 'connections': {
} 'default': DATABASE_URL
}, },
'use_tz': False, 'apps': {
'timezone': 'Asia/Shanghai' 'models': {
} 'models': models,
'default_connection': 'default',
}
def register_tortoise_orm(app: FastAPI): },
# 注册 Tortoise ORM 配置 'use_tz': False,
register_tortoise( 'timezone': 'Asia/Shanghai'
app, }
config=config,
generate_schemas=False,
add_exception_handlers=True, def register_tortoise_orm(app: FastAPI):
) # 注册 Tortoise ORM 配置
register_tortoise(
app,
""" config=config,
********** generate_schemas=False,
以下内容对原有无改色配置文件进行注释上述内容为修改后日志为蓝色 add_exception_handlers=True,
********** )
"""
# import logging """
# **********
# from fastapi import FastAPI 以下内容对原有无改色配置文件进行注释上述内容为修改后日志为蓝色
# from tortoise.contrib.fastapi import register_tortoise **********
# """
# from src.settings.config import settings
# # import logging
# DATABASE_URL = str(settings.DATABASE_URL) #
# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) # from fastapi import FastAPI
# # from tortoise.contrib.fastapi import register_tortoise
# # will print debug sql #
# logging.basicConfig() # from src.settings.config import settings
# logger_db_client = logging.getLogger("tortoise.db_client") #
# logger_db_client.setLevel(logging.DEBUG) # DATABASE_URL = str(settings.DATABASE_URL)
# # DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL)
# models = [ #
# "src.models.test_table", # # will print debug sql
# "src.models.stock", # logging.basicConfig()
# "src.models.security_account", # logger_db_client = logging.getLogger("tortoise.db_client")
# "src.models.order", # logger_db_client.setLevel(logging.DEBUG)
# "src.models.snowball", #
# "src.models.backtest", # models = [
# "src.models.strategy", # "src.models.test_table",
# "src.models.stock_details", # "src.models.stock",
# "src.models.transaction", # "src.models.security_account",
# ] # "src.models.order",
# # "src.models.snowball",
# aerich_models = models # "src.models.backtest",
# aerich_models.append("aerich.models") # "src.models.strategy",
# # "src.models.stock_details",
# TORTOISE_ORM = { # "src.models.transaction",
# # ]
# "connections": {"default": DATABASE_CREATE_URL}, #
# "apps": { # aerich_models = models
# "models": { # aerich_models.append("aerich.models")
# "models": aerich_models, #
# "default_connection": "default", # TORTOISE_ORM = {
# }, #
# }, # "connections": {"default": DATABASE_CREATE_URL},
# } # "apps": {
# # "models": {
# config = { # "models": aerich_models,
# 'connections': { # "default_connection": "default",
# # Using a DB_URL string # },
# 'default': DATABASE_URL # },
# }, # }
# 'apps': { #
# 'models': { # config = {
# 'models': models, # 'connections': {
# 'default_connection': 'default', # # Using a DB_URL string
# } # 'default': DATABASE_URL
# }, # },
# 'use_tz': False, # 'apps': {
# 'timezone': 'Asia/Shanghai' # 'models': {
# } # 'models': models,
# # 'default_connection': 'default',
# # }
# def register_tortoise_orm(app: FastAPI): # },
# register_tortoise( # 'use_tz': False,
# app, # 'timezone': 'Asia/Shanghai'
# config=config, # }
# # db_url=DATABASE_URL, #
# # db_url="mysql://adams:adams@127.0.0.1:3306/adams", #
# # modules={"models": models}, # def register_tortoise_orm(app: FastAPI):
# generate_schemas=False, # register_tortoise(
# add_exception_handlers=True, # app,
# ) # config=config,
# # db_url=DATABASE_URL,
# # db_url="mysql://adams:adams@127.0.0.1:3306/adams",
# # modules={"models": models},
# generate_schemas=False,
# add_exception_handlers=True,
# )

@ -1,74 +1,66 @@
from typing import Generic, Optional, Sequence, TypeVar from typing import Type, TypeVar, Generic, Sequence, Optional, Callable, Awaitable
import math
import math from fastapi import Query
from fastapi import Query from pydantic import BaseModel
from pydantic import BaseModel from tortoise.queryset import QuerySet
from tortoise.queryset import QuerySet from tortoise.models import Model
T = TypeVar("T") T = TypeVar('T', bound=BaseModel)
M = TypeVar('M', bound=Model)
class PaginationPydantic(BaseModel, Generic[T]): class PaginationPydantic(BaseModel, Generic[T]):
"""分页模型""" """分页模型"""
status_code: int = 200 status_code: int = 200
message: str = "Success" message: str = "Success"
total: int total: int
page: int page: int
size: int size: int
total_pages: int total_pages: int
data: Sequence[T] data: Sequence[T]
class Params(BaseModel):
class Params(BaseModel): """传参"""
"""传参""" page: int = Query(1, ge=1, description="Page number")
# 设置默认值为1不能够小于1 size: int = Query(10, gt=0, le=200, description="Page size")
page: int = Query(1, ge=1, description="Page number") order_by: Optional[str] = Query(None, max_length=32, description="Sort key")
# 设置默认值为10最大为100 async def pagination(
size: int = Query(10, gt=0, le=200, description="Page size") pydantic_model: Type[T],
query_set: QuerySet[M],
# 默认值None表示选传 params: Params,
order_by: Optional[str] = Query(None, max_length=32, description="Sort key") callback: Optional[Callable[[QuerySet[M]], Awaitable[QuerySet[M]]]] = None
) -> PaginationPydantic[T]:
"""分页响应"""
async def pagination(pydantic_model, query_set: QuerySet, params: Params, callback=None): page: int = params.page
"""分页响应""" size: int = params.size
""" order_by: Optional[str] = params.order_by
pydantic_model: Pydantic model total = await query_set.count()
query_set: QuerySet
params: Params # 计算总页数
callback: if you want to do something for query_set,it will be useful. total_pages = math.ceil(total / size) if total > 0 else 1
"""
page: int = params.page if page > total_pages and total > 0:
size: int = params.size raise ValueError("页数输入有误")
order_by: str = params.order_by
total = await query_set.count() # 排序
if order_by:
# 通过总数和每页数量计算出总页数 order_by_fields = order_by.split(',')
total_pages = math.ceil(total / size) query_set = query_set.order_by(*order_by_fields)
if page > total_pages and total: # 排除查询集为空时报错即total=0时 # 分页
raise ValueError("页数输入有误") query_set = query_set.offset((page - 1) * size).limit(size)
# 排序后分页 # 回调处理
if order_by: if callback:
order_by = order_by.split(',') query_set = await callback(query_set)
query_set = query_set.order_by(*order_by)
# 获取数据
# 分页 data = await pydantic_model.from_queryset(query_set)
query_set = query_set.offset((page - 1) * size) # 页数 * 页面大小=偏移量
query_set = query_set.limit(size) return PaginationPydantic[T](
total=total,
if callback: page=page,
"""对查询集操作""" size=size,
query_set = await callback(query_set) total_pages=total_pages,
data=data
# query_set = await query_set )
data = await pydantic_model.from_queryset(query_set)
return PaginationPydantic(**{
"total": total,
"page": page,
"size": size,
"total_pages": total_pages,
"data": data,
})

@ -0,0 +1,202 @@
"""
TO DO待完善全体数据更新逻辑
"""
import time
import requests
import pandas as pd
from fastapi import APIRouter
from src.models.financial_reports import FinancialReport # 确保路径正确
async def combined_search_and_list():
# 分页参数
pageNo = 1
pageSize = 50
selector_url = "https://www.shidaotec.com/api/research/getFinaResearchPageList"
# 请求第一页获取总记录数
response = requests.get(selector_url, params={
"pageNo": pageNo,
"pageSize": pageSize,
"year": "2024",
"type": ""
})
response.raise_for_status()
selector_data = response.json()
# 获取总记录数和数据库已有记录数
total_records = selector_data.get("data", {}).get("total", 0)
db_record_count = await FinancialReport .all().count()
remaining_insert_count = total_records - db_record_count
print(f"总记录数: {total_records},数据库已有记录数: {db_record_count},需要插入: {remaining_insert_count}")
# 如果已有记录数与总记录数相同,则直接结束函数
if remaining_insert_count <= 0:
print("数据库已包含所有记录,无需插入新数据。")
return
# 获取数据库中最新的记录
latest_record = await FinancialReport .all().order_by("-date").first()
latest_db_date = latest_record.date if latest_record else None
latest_db_date = pd.Timestamp(latest_db_date)
# 缓存数据
all_data = []
# 开始分页请求,直到遇到比数据库最新日期更早的数据
while True:
try:
# 请求当前页数据
response = requests.get(selector_url, params={
"pageNo": pageNo,
"pageSize": pageSize,
"year": "2024",
"type": ""
})
response.raise_for_status()
selector_data = response.json()
fina_data_list = selector_data.get("data", {}).get("list", [])
# 如果页数据为空,停止循环
if not fina_data_list:
print("No more data available.")
break
# 缓存本页数据
for record in fina_data_list:
record.pop("seq", None) # 移除不必要的字段
record = {k: (None if v == "-" else v) for k, v in record.items()}
all_data.append(record)
# 检查本页第一条数据日期,决定是否继续请求
first_record_date = pd.to_datetime(fina_data_list[0]["date"])
if latest_db_date and first_record_date < latest_db_date:
print(f"Data on page {pageNo} is older than the latest entry in the database. Stopping requests.")
break # 停止请求
print(f"{pageNo} 页成功获取并缓存 {len(fina_data_list)} 条记录")
pageNo += 1
time.sleep(2) # 请求间隔
except requests.RequestException as e:
print(f"Error fetching data for page {pageNo}: {e}")
break
except Exception as e:
print(f"Unexpected error on page {pageNo}: {e}")
break
# 批量查询数据库中已存在的记录
existing_records = await FinancialReport .filter(
date__in=[record["date"] for record in all_data],
period__in=[record["period"] for record in all_data],
stock_code__in=[record["stockCode"] for record in all_data]
).all()
# 将已存在记录存入集合以便于后续检查
existing_data_set = {
(record.date, record.period, record.stock_code, record.stock_name,
record.industry_code, record.industry_name, record.pdf_name,
record.pdf_url, record.revenue, record.revenue_yoy,
record.pr_of_today, record.pr_of_after_1_day,
record.pr_of_after_1_week, record.pr_of_after_1_month,
record.pr_of_after_6_month, record.pr_of_that_year,
record.nincome_yoy_lower_limit, record.nincome_yoy_upper_limit,
record.nincome, record.nincome_yoy)
for record in existing_records
}
# 开始处理缓存的数据,插入数据库
success_count = 0
records_to_insert = []
for record in all_data[:]: # 使用切片拷贝以避免在迭代时修改原始列表
record_date = pd.to_datetime(record["date"])
# 检查该记录是否已存在于数据库中
if (record["date"], record["period"], record["stockCode"], record["stockName"],
record["industryCode"], record["industryName"], record["pdfName"], record["pdfUrl"],
record["revenue"], record["revenueYoy"], record["prOfToday"], record["prOfAfter1Day"],
record["prOfAfter1Week"], record["prOfAfter1Month"], record["prOfAfter6Month"],
record["prOfThatYear"], record["nincomeYoyLowerLimit"], record["nincomeYoyUpperLimit"],
record["nincome"], record["nincomeYoy"]) in existing_data_set:
print(f"Skipping existing record: {record}") # 打印已跳过的记录
continue # 跳过已存在的记录
# 打印即将插入的数据
print(f"Inserting record: {record}")
if record_date >= latest_db_date:
# 创建待插入的 FinancialReport 对象并添加到列表
try:
records_to_insert.append(
FinancialReport (
date=record["date"],
period=record["period"],
year=int(record["year"]),
stock_code=record["stockCode"],
stock_name=record["stockName"],
industry_code=record["industryCode"] if record["industryCode"] else None,
industry_name=record["industryName"] if record["industryName"] else None,
pdf_name=record["pdfName"] if record["pdfName"] else None,
pdf_url=record["pdfUrl"] if record["pdfUrl"] else None,
revenue=float(record["revenue"]) if record["revenue"] else None,
revenue_yoy=float(record["revenueYoy"]) if record["revenueYoy"] else None,
pr_of_today=float(record["prOfToday"]) if record["prOfToday"] else None,
pr_of_after_1_day=float(record["prOfAfter1Day"]) if record["prOfAfter1Day"] else None,
pr_of_after_1_week=float(record["prOfAfter1Week"]) if record["prOfAfter1Week"] else None,
pr_of_after_1_month=float(record["prOfAfter1Month"]) if record["prOfAfter1Month"] else None,
pr_of_after_6_month=float(record["prOfAfter6Month"]) if record["prOfAfter6Month"] else None,
pr_of_that_year=float(record["prOfThatYear"]) if record["prOfThatYear"] else None,
nincome_yoy_lower_limit=record["nincomeYoyLowerLimit"] if record["nincomeYoyLowerLimit"] else None,
nincome_yoy_upper_limit=record["nincomeYoyUpperLimit"] if record["nincomeYoyUpperLimit"] else None,
nincome=float(record["nincome"]) if record["nincome"] else None,
nincome_yoy=float(record["nincomeYoy"]) if record["nincomeYoy"] else None
)
)
success_count += 1
except Exception as e:
print(f"Error processing record {record}: {e}") # 输出处理错误
# 成功插入所需条数后停止
if success_count >= remaining_insert_count:
print("this riqi:" + str(latest_db_date))
break
# 批量插入到数据库
if records_to_insert:
try:
await FinancialReport .bulk_create(records_to_insert)
print(f"成功插入 {len(records_to_insert)} 条记录")
except Exception as e:
print(f"Error during bulk insert: {e}")
# 去重步骤
try:
# 查找基于 date、period 和 stock_code 的重复记录
duplicates = await FinancialReport.raw("""
SELECT id FROM financial_reports WHERE id NOT IN (
SELECT id FROM (
SELECT MAX(id) AS id
FROM financial_reports
GROUP BY date, period, stock_code
) AS temp
)
""")
duplicate_ids = [record.id for record in duplicates]
if duplicate_ids:
# 删除较旧的重复记录
await FinancialReport.filter(id__in=duplicate_ids).delete()
print(f"删除了 {len(duplicate_ids)} 条重复记录")
else:
print("未发现重复记录")
except Exception as e:
print(f"去重过程中出错: {e}")
return {"message": "数据获取、插入和去重已完成"}