diff --git a/.env b/.env index 42bf6bc..a222db3 100644 --- a/.env +++ b/.env @@ -1,23 +1,23 @@ -;使用本地docker数据库就将这段代码解开,如何将IP换成本机IP或者localhost -; DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade" -; DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade" - -DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" -DATABASE_CREATE_URL ="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" - -REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380" - -SITE_DOMAIN=127.0.0.1 -SECURE_COOKIES=false - -ENVIRONMENT=LOCAL - -CORS_HEADERS=["*"] -CORS_ORIGINS=["http://localhost:3000"] - -# postgres variables, must be the same as in DATABASE_URL -POSTGRES_USER=app -POSTGRES_PASSWORD=app -POSTGRES_HOST=app_db -POSTGRES_PORT=5432 +# 使用本地docker数据库就将这段代码解开,如何将IP换成本机IP或者localhost +# DATABASE_URL="mysql://root:123456@10.1.5.219:3306/ftrade" +# DATABASE_CREATE_URL ="mysql://root:123456@10.1.5.219:3306/ftrade" + +DATABASE_URL="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" +DATABASE_CREATE_URL ="mysql://wangche:fN7sXX8saiQKXWbG@cqxqg.tech:3308/wangche" + +REDIS_URL="redis://:redis_tQNjCH@cqxqg.tech:6380" + +SITE_DOMAIN=127.0.0.1 +SECURE_COOKIES=false + +ENVIRONMENT=LOCAL + +CORS_HEADERS=["*"] +CORS_ORIGINS=["http://localhost:3000"] + +# postgres variables, must be the same as in DATABASE_URL +POSTGRES_USER=app +POSTGRES_PASSWORD=app +POSTGRES_HOST=app_db +POSTGRES_PORT=5432 POSTGRES_DB=app \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 39831c4..74ef8a9 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,7 +1,7 @@ - - - - - + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..c8397c9 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/wance_data.iml b/.idea/wance_data.iml index d33acca..5ef22bc 100644 --- a/.idea/wance_data.iml +++ b/.idea/wance_data.iml @@ -1,15 +1,17 @@ - - - - - - - - - - 
- + + + + + + + + + + + + + \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..741c985 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python 调试程序: FastAPI", + "type": "debugpy", + "request": "launch", + "module": "uvicorn", + "args": [ + "src.main:app", + "--reload", + "--port", "8012" + ], + "jinja": true + } + ] +} diff --git a/src/__pycache__/main.cpython-310.pyc b/src/__pycache__/main.cpython-310.pyc index 724b5c5..2fd05b1 100644 Binary files a/src/__pycache__/main.cpython-310.pyc and b/src/__pycache__/main.cpython-310.pyc differ diff --git a/src/__pycache__/tortoises.cpython-310.pyc b/src/__pycache__/tortoises.cpython-310.pyc index 20f98cb..4e903d0 100644 Binary files a/src/__pycache__/tortoises.cpython-310.pyc and b/src/__pycache__/tortoises.cpython-310.pyc differ diff --git a/src/financial_reports/mappings.py b/src/financial_reports/mappings.py new file mode 100644 index 0000000..89ee695 --- /dev/null +++ b/src/financial_reports/mappings.py @@ -0,0 +1,36 @@ +""" + 映射表, 为空则返回none表示不筛选此字段 +""" + +# 时间选择映射表 +year_mapping = { + """ + 为空则返回none + """ + "2024": 2024, +} + +# 股票池映射表 +stock_pool_mapping = { + "000300.SH": "000300.SH", + "000905.SH": "000905.SH", + "399303.SZ": "399303.SZ", + "large_mv": "large_mv", + "medium_mv": "medium_mv", + "small_mv": "small_mv", +} + +# 财报周期映射 +period_mapping = { + 1: "一季度报告", # 一季度报告 + 2: "半年度报告", # 半年度报告 + 3: "三季度报告", # 三季度报告 + 4: "正式年报", # 正式年报 + 9: "一季度业绩预告", # 一季度业绩预告 + 10: "半年度业绩预告", # 半年度业绩预告 + 11: "三季度业绩预告", # 三季度业绩预告 + 12: "年度业绩预告", # 年度业绩预告 + 5: "一季度业绩快报", # 一季度业绩快报 + 6: "半年度业绩快报", # 半年度业绩快报 + 7: "三季度业绩快报" # 三季度业绩快报 +} \ No newline at end of file diff --git a/src/financial_reports/router.py b/src/financial_reports/router.py new file mode 100644 index 0000000..c11c515 --- /dev/null +++ b/src/financial_reports/router.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter +from 
fastapi.responses import JSONResponse + +from src.financial_reports.schemas import * +from src.financial_reports.service import * + +financial_reports_router = APIRouter() + +@financial_reports_router.post("/query") +async def financial_repoets_query(request: FinancialReportQuery )-> JSONResponse: + result = await financial_reports_query_service(request) + + return result \ No newline at end of file diff --git a/src/financial_reports/schemas.py b/src/financial_reports/schemas.py new file mode 100644 index 0000000..59dfe11 --- /dev/null +++ b/src/financial_reports/schemas.py @@ -0,0 +1,16 @@ +from typing import List, Optional, Any, Dict + +from pydantic import BaseModel, Field + +class FinancialReportQuery(BaseModel): + pageNo: int = Field(..., ge=1, le=1000, description="页码,必须大于等于1且小于等于1000") + pageSize: int = Field(..., ge=1, le=50, description="每页条数,必须大于等于1且小于等于50") + + # 可选参数 + year: Optional[int] = Field(None, ge=1, le=2100, description="时间选择,必须为正整数或特定年份且小于等于2100,或留空") + stockPoolCode: Optional[str] = Field(None, max_length=20, min_length=1, description="股票池,限制最大长度为20,限制最小长度为1,或留空") + period: Optional[int] = Field(None, ge=1, le=50, description="财报周期,必须大于等于1且小于等于50,或留空") + revenueYoyType: Optional[int] = Field(None, ge=1, le=10, description="营业收入同比,必须大于等于1且小于等于10,或留空") + nIncomeYoyType: Optional[int] = Field(None, ge=1, le=10, description="净利润同比,必须大于等于1且小于等于10,或留空") + revenueYoyTop10: Optional[int] = Field(None, ge=1, le=10, description="营业收入行业前10%,必须大于等于1且小于等于10,或留空") + diff --git a/src/financial_reports/service.py b/src/financial_reports/service.py new file mode 100644 index 0000000..46fbb53 --- /dev/null +++ b/src/financial_reports/service.py @@ -0,0 +1,134 @@ +import json +import adata + +from math import ceil +from typing import Any, Dict +from fastapi import Request, HTTPException, BackgroundTasks +from tortoise.expressions import Subquery,Q +from tortoise.contrib.pydantic import pydantic_model_creator + +from src.models.financial_reports import * +from 
src.models.stock_hu_shen300 import * +from src.models.stock_zhong_zheng_500 import * +from src.models.stock_guo_zheng_2000 import * +from src.models.stock_hu_shen_jing_a import * +from src.financial_reports.schemas import * +from src.utils.paginations import * +from src.financial_reports.mappings import * + + +# 创建一个不包含 "created_at" 字段的 Pydantic 模型用于响应 +FinancialReport_Pydantic = pydantic_model_creator(FinancialReport, exclude=("created_at",)) + +async def financial_reports_query_service( + request: FinancialReportQuery +) -> PaginationPydantic: + """ + 根据选择器和关键词查询财报,并进行分页 + """ + # 构建查询集,这里假设财报数据模型为 FinancialReport + query_set: QuerySet = FinancialReport.all() + + try: + # 年度映射 + if request.year: + mapped_year = year_mapping.get(request.year) + if mapped_year is not None: + query_set = query_set.filter(year=mapped_year) + else: + raise ValueError("无效的年份参数") + + # 股票池映射 + if request.stockPoolCode is None: + # 未提供股票池代码,返回所有记录 + pass + else: + # 获取映射值,如果不存在则返回 None + mapped_stock_pool = stock_pool_mapping.get(request.stockPoolCode) + + # 检查 mapped_stock_pool 是否为 None(即不在映射表中) + if mapped_stock_pool is None: + # 如果 stockPoolCode 不在映射表中,抛出无效股票池参数错误 + raise ValueError("无效的股票池参数") + + # 如果存在有效的映射值,执行相应的过滤 + elif mapped_stock_pool == "000300.SH": + subquery = await StockHuShen300.filter().values_list('code', flat=True) + query_set = query_set.filter(stock_code__in=subquery) + elif mapped_stock_pool == "000905.SH": + subquery = await StockZhongZheng500.filter().values_list('code', flat=True) + query_set = query_set.filter(stock_code__in=subquery) + elif mapped_stock_pool == "399303.SZ": + subquery = await StockGuoZheng2000.filter().values_list('code', flat=True) + query_set = query_set.filter(stock_code__in=subquery) + elif mapped_stock_pool in ["large_mv", "medium_mv", "small_mv"]: + # 先获取所有有总市值的关联的股票代码 + subquery = await StockHuShenJingA.filter(total_market_value__isnull=False).values_list('code', 'total_market_value') + + # 转换为 DataFrame 或者直接进行排序 + stock_list = 
sorted(subquery, key=lambda x: x[1], reverse=True) # 按 total_market_value 降序排序 + + total_count = len(stock_list) + + if mapped_stock_pool == "large_mv": + # 获取前 30% 的数据 + limit_count = ceil(total_count * 0.3) + selected_stocks = [stock[0] for stock in stock_list[:limit_count]] + elif mapped_stock_pool == "medium_mv": + # 获取中间 40% 的数据 + start_offset = ceil(total_count * 0.3) + limit_count = ceil(total_count * 0.4) + selected_stocks = [stock[0] for stock in stock_list[start_offset:start_offset + limit_count]] + elif mapped_stock_pool == "small_mv": + # 获取后 30% 的数据 + start_offset = ceil(total_count * 0.7) + selected_stocks = [stock[0] for stock in stock_list[start_offset:]] + + # 对 FinancialReport 表进行筛选 + query_set = query_set.filter(stock_code__in=selected_stocks) + + # 财报周期映射 + if request.period is not None: + mapped_period = period_mapping.get(request.period) + if mapped_period is not None: + # 如果映射到有效的周期,则进行过滤 + query_set = query_set.filter(period=mapped_period) + else: + # 如果找不到有效的映射,抛出错误 + raise ValueError("无效的财报周期参数") + + # 检查是否所有筛选条件都为空 + if not request.revenueYoyType and not request.nIncomeYoyType and not request.revenueYoyTop10: + # 如果全部为空,则按 date 和 created_at 降序排序 + query_set = query_set.order_by("-date", "-created_at") + + # 筛选 revenueYoyType,如果是 1 则筛选 revenue_yoy 大于 10.0 的记录 + if request.revenueYoyType == 1: + query_set = query_set.filter(revenue_yoy__gt=10.0).order_by("-revenue_yoy") + + # 筛选 nIncomeYoyType,如果是 1 则筛选 nincome_yoy 大于 10.0 的记录 + elif request.nIncomeYoyType == 1: + query_set = query_set.filter(nincome_yoy__gt=10.0).order_by("-nincome_yoy") + + # 如果 revenueYoyTop10 为 1,则筛选 revenue_yoy 前 10% 的记录 + elif request.revenueYoyTop10 == 1: + # 计算前 10% 的数量 + total_count = await FinancialReport.all().count() # 获取总记录数 + limit_count = ceil(total_count * 0.1) + + # 按 revenue_yoy 降序排列,获取前 10% 记录 + query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy") + + + # 调用分页函数进行分页处理 + params = Params(page=request.pageNo, 
size=request.pageSize) #parms作用是传页码和页面大小 + result = await pagination(pydantic_model=FinancialReport_Pydantic, + query_set=query_set, + params=params) + + return result + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail="内部服务器错误") \ No newline at end of file diff --git a/src/main.py b/src/main.py index 5ce721d..8a7edc1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,60 +1,103 @@ -import sys -import os - -# 添加项目的根目录到 sys.path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -import sentry_sdk -import uvicorn -from fastapi import FastAPI -from starlette.middleware.cors import CORSMiddleware - -from src.exceptions import register_exception_handler -from src.tortoises import register_tortoise_orm -from src.xtdata.router import router as xtdata_router -from src.backtest.router import router as backtest_router -from src.combination.router import router as combine_router - -from xtquant import xtdata -from src.settings.config import app_configs, settings - - - - -app = FastAPI(**app_configs) - -register_tortoise_orm(app) - -register_exception_handler(app) - -app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"]) -app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"]) -app.include_router(combine_router, prefix="/combine", tags=["组合接口"]) - - -if settings.ENVIRONMENT.is_deployed: - sentry_sdk.init( - dsn=settings.SENTRY_DSN, - environment=settings.ENVIRONMENT, - ) - -app.add_middleware( - CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, - allow_origin_regex=settings.CORS_ORIGINS_REGEX, - allow_credentials=True, - allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), - allow_headers=settings.CORS_HEADERS, -) - - - - -@app.get("/") -async def root(): - return {"message": "Hello, FastAPI!"} - - - -if __name__ == "__main__": - uvicorn.run('src.main:app', host="0.0.0.0", port=8012, 
reload=True) +import sys +import os + +# 添加项目的根目录到 sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +import sentry_sdk +import uvicorn +import pandas as pd +import asyncio +from fastapi import FastAPI +from datetime import datetime + + +from starlette.middleware.cors import CORSMiddleware + +from src.exceptions import register_exception_handler +from src.tortoises import register_tortoise_orm +from src.xtdata.router import router as xtdata_router +from src.backtest.router import router as backtest_router +from src.combination.router import router as combine_router +from src.models.financial_reports import FinancialReport +from src.utils.update_financial_reports_spider import combined_search_and_list +from src.financial_reports.router import financial_reports_router + +from xtquant import xtdata +from src.settings.config import app_configs, settings + +import adata +import akshare as ak + + +app = FastAPI(**app_configs) + +register_tortoise_orm(app) + +register_exception_handler(app) + +app.include_router(xtdata_router, prefix="/getwancedata", tags=["数据接口"]) +app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"]) +app.include_router(combine_router, prefix="/combine", tags=["组合接口"]) +app.include_router(financial_reports_router, prefix="/financial-reports", tags=["财报接口"]) + +if settings.ENVIRONMENT.is_deployed: + sentry_sdk.init( + dsn=settings.SENTRY_DSN, + environment=settings.ENVIRONMENT, + ) + +app.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_origin_regex=settings.CORS_ORIGINS_REGEX, + allow_credentials=True, + allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), + allow_headers=settings.CORS_HEADERS, +) + + +@app.get("/") +async def root(): + return {"message": "Hello, FastAPI!"} + +# 定时检查和数据抓取函数 +async def run_data_fetcher(): + print("财报抓取启动") + while True: + # 获取数据库中最新记录的时间 + latest_record = await FinancialReport .all().order_by("-created_at").first() + 
latest_db_date = latest_record.created_at if latest_record else pd.Timestamp("1970-01-01") + + # 将最新数据库日期设为无时区,以便比较 + latest_db_date = latest_db_date.replace(tzinfo=None) + + # 当前时间 + current_time = pd.Timestamp(datetime.now()).replace(tzinfo=None) + + # 检查当前时间是否超过数据库最新记录时间 12 个小时 + if (current_time - latest_db_date).total_seconds() > 43200: + print("启动财报数据抓取...") + await combined_search_and_list() + else: + print("未满足条件,跳过抓取任务。") + + # 休眠 12 小时(43200 秒),然后再次检查 + await asyncio.sleep(43200) + +async def test_liang_hua_ku(): + print("量化库测试函数启动") + + + +# 启动时立即运行检查任务 +@app.on_event("startup") +async def lifespan(): + # 启动时将任务放入后台,不阻塞启动流程 + asyncio.create_task(test_liang_hua_ku()) + asyncio.create_task(run_data_fetcher()) + + + +if __name__ == "__main__": + uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True) diff --git a/src/models/financial_reports.py b/src/models/financial_reports.py new file mode 100644 index 0000000..c2734b8 --- /dev/null +++ b/src/models/financial_reports.py @@ -0,0 +1,32 @@ +from tortoise import fields, Model + +class FinancialReport (Model): + """ + 财报数据 + """ + id = fields.IntField(pk=True, description="自增主键") # 自增主键 + date = fields.DateField(description="日期") # 日期 + period = fields.CharField(max_length=50, description="报告期") # 报告期 + year = fields.IntField(description="年份") # 年份 + stock_code = fields.CharField(max_length=10, description="股票代码") # 股票代码 + stock_name = fields.CharField(max_length=50, description="股票名称") # 股票名称 + industry_code = fields.CharField(max_length=10, null=True, description="行业代码") # 行业代码 + industry_name = fields.CharField(max_length=50, null=True, description="行业名称") # 行业名称 + pdf_name = fields.CharField(max_length=100, null=True, description="PDF 报告名称") # PDF 报告名称 + pdf_url = fields.CharField(max_length=255, null=True, description="PDF 报告链接") # PDF 报告链接 + revenue = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="收入") # 收入 + revenue_yoy = fields.DecimalField(max_digits=15, 
decimal_places=2, null=True, description="收入同比增长") # 收入同比增长 + pr_of_today = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="今日涨跌幅") # 今日涨跌幅 + pr_of_after_1_day = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="次日涨跌幅") # 次日涨跌幅 + pr_of_after_1_week = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一周后涨跌幅") # 一周后涨跌幅 + pr_of_after_1_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="一月后涨跌幅") # 一月后涨跌幅 + pr_of_after_6_month = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="六个月后涨跌幅") # 六个月后涨跌幅 + pr_of_that_year = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="年初至今涨跌幅") # 年初至今涨跌幅 + nincome_yoy_lower_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长下限") # 净利润同比增长下限 + nincome_yoy_upper_limit = fields.CharField(max_length=10, null=True, description="净利润同比增长上限") # 净利润同比增长上限 + nincome = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润") # 净利润 + nincome_yoy = fields.DecimalField(max_digits=15, decimal_places=2, null=True, description="净利润同比增长率") # 净利润同比增长率 + created_at = fields.DatetimeField(auto_now_add=True, description="创建时间") # 创建时间,自动生成当前时间 + + class Meta: + table = "financial_reports" diff --git a/src/models/stock_guo_zheng_2000.py b/src/models/stock_guo_zheng_2000.py new file mode 100644 index 0000000..e60cee7 --- /dev/null +++ b/src/models/stock_guo_zheng_2000.py @@ -0,0 +1,12 @@ +from tortoise import fields, Model + +class StockGuoZheng2000(Model): + """ + 国证2000成分股数据 + """ + code = fields.CharField(max_length=10, description="股票代码") # 股票代码 + name = fields.CharField(max_length=50, description="股票名称") # 股票名称 + included_date = fields.CharField(max_length=20, description="纳入日期") # 纳入日期 + + class Meta: + table = "stock_guo_zheng_2000" \ No newline at end of file diff --git a/src/models/stock_hu_shen300.py b/src/models/stock_hu_shen300.py new 
file mode 100644 index 0000000..4bf3730 --- /dev/null +++ b/src/models/stock_hu_shen300.py @@ -0,0 +1,12 @@ +from tortoise import fields, Model + +class StockHuShen300(Model): + """ + 沪深300成分股数据 + """ + symbol = fields.CharField(max_length=10, description="股票代码") # 股票代码 + code = fields.CharField(max_length=10, description="股票代码") # 股票代码 + name = fields.CharField(max_length=50, description="股票名称") # 股票名称 + + class Meta: + table = "stock_hu_shen_300" \ No newline at end of file diff --git a/src/models/stock_hu_shen_jing_a.py b/src/models/stock_hu_shen_jing_a.py new file mode 100644 index 0000000..9b88662 --- /dev/null +++ b/src/models/stock_hu_shen_jing_a.py @@ -0,0 +1,32 @@ +from tortoise import fields, Model + +class StockHuShenJingA(Model): + """ + 沪深京 A 股实时行情数据 + """ + seq = fields.IntField(description="序号") # 序号 + code = fields.CharField(max_length=10, description="股票代码") # 股票代码 + name = fields.CharField(max_length=50, description="股票名称") # 股票名称 + latest_price = fields.FloatField(null=True, description="最新价格") # 最新价格 + change_percent = fields.FloatField(null=True, description="涨跌幅") # 涨跌幅 + change_amount = fields.FloatField(null=True, description="涨跌额") # 涨跌额 + volume = fields.FloatField(null=True, description="成交量") # 成交量 + turnover = fields.FloatField(null=True, description="成交额") # 成交额 + amplitude = fields.FloatField(null=True, description="振幅") # 振幅 + high = fields.FloatField(null=True, description="最高价") # 最高价 + low = fields.FloatField(null=True, description="最低价") # 最低价 + open = fields.FloatField(null=True, description="今开价") # 今开价 + previous_close = fields.FloatField(null=True, description="昨收价") # 昨收价 + volume_ratio = fields.FloatField(null=True, description="量比") # 量比 + turnover_rate = fields.FloatField(null=True, description="换手率") # 换手率 + pe_dynamic = fields.FloatField(null=True, description="市盈率-动态") # 市盈率-动态 + pb = fields.FloatField(null=True, description="市净率") # 市净率 + total_market_value = fields.FloatField(null=True, description="总市值") # 总市值 + 
circulating_market_value = fields.FloatField(null=True, description="流通市值") # 流通市值 + rise_speed = fields.FloatField(null=True, description="涨速") # 涨速 + five_minute_change = fields.FloatField(null=True, description="5分钟涨跌幅") # 5分钟涨跌幅 + sixty_day_change = fields.FloatField(null=True, description="60日涨跌幅") # 60日涨跌幅 + year_to_date_change = fields.FloatField(null=True, description="年初至今涨跌幅") # 年初至今涨跌幅 + + class Meta: + table = "stock_hu_shen_jing_a" \ No newline at end of file diff --git a/src/models/stock_zhong_zheng_500.py b/src/models/stock_zhong_zheng_500.py new file mode 100644 index 0000000..8940f59 --- /dev/null +++ b/src/models/stock_zhong_zheng_500.py @@ -0,0 +1,12 @@ +from tortoise import fields, Model + +class StockZhongZheng500(Model): + """ + 中证500成分股数据 + """ + symbol = fields.CharField(max_length=10, description="股票代码") # 股票代码 + code = fields.CharField(max_length=10, description="股票代码") # 股票代码 + name = fields.CharField(max_length=50, description="股票名称") # 股票名称 + + class Meta: + table = "stock_zhong_zheng_500" diff --git a/src/pydantic/requirements.txt b/src/pydantic/requirements.txt new file mode 100644 index 0000000..c82145d Binary files /dev/null and b/src/pydantic/requirements.txt differ diff --git a/src/tortoises.py b/src/tortoises.py index 3f95d63..1ec0d14 100644 --- a/src/tortoises.py +++ b/src/tortoises.py @@ -1,186 +1,192 @@ -import logging -from fastapi import FastAPI -from tortoise.contrib.fastapi import register_tortoise -from src.settings.config import settings - -DATABASE_URL = str(settings.DATABASE_URL) -DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) - - -# 定义不同日志级别的颜色 -class ColoredFormatter(logging.Formatter): - COLORS = { - "DEBUG": "\033[94m", # 蓝色 - "INFO": "\033[92m", # 绿色 - "WARNING": "\033[93m", # 黄色 - "ERROR": "\033[91m", # 红色 - "CRITICAL": "\033[41m", # 红色背景 - } - RESET = "\033[0m" # 重置颜色 - - def format(self, record): - color = self.COLORS.get(record.levelname, self.RESET) - log_message = super().format(record) - return 
f"{color}{log_message}{self.RESET}" - - -# 配置日志 -logger_db_client = logging.getLogger("tortoise.db_client") -logger_db_client.setLevel(logging.DEBUG) - -# 创建一个控制台处理程序 -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) - -# 创建并设置格式化器 -formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -ch.setFormatter(formatter) - -# 将处理程序添加到日志记录器中 -logger_db_client.addHandler(ch) - -# 定义模型 -models = [ - "src.models.test_table", - "src.models.stock", - "src.models.security_account", - "src.models.order", - "src.models.snowball", - "src.models.backtest", - "src.models.strategy", - "src.models.stock_details", - "src.models.transaction", - "src.models.position", - "src.models.trand_info", - "src.models.tran_observer_data", - "src.models.tran_orders", - "src.models.tran_return", - "src.models.tran_trade_info", - "src.models.back_observed_data", - "src.models.back_observed_data_detail", - "src.models.back_position", - "src.models.back_result_indicator", - "src.models.back_trand_info", - "src.models.tran_position", - "src.models.stock_bt_history", - "src.models.stock_history", - "src.models.stock_data_processing", - "src.models.wance_data_stock", - "src.models.wance_data_storage_backtest", - "src.models.user_combination_history" - -] - -aerich_models = models -aerich_models.append("aerich.models") - -# Tortoise ORM 配置 -TORTOISE_ORM = { - "connections": {"default": DATABASE_CREATE_URL, - }, - "apps": { - "models": { - "models": aerich_models, - "default_connection": "default", - }, - }, -} - -config = { - 'connections': { - 'default': DATABASE_URL - }, - 'apps': { - 'models': { - 'models': models, - 'default_connection': 'default', - } - }, - 'use_tz': False, - 'timezone': 'Asia/Shanghai' -} - - -def register_tortoise_orm(app: FastAPI): - # 注册 Tortoise ORM 配置 - register_tortoise( - app, - config=config, - generate_schemas=False, - add_exception_handlers=True, - ) - - -""" -********** -以下内容对原有无改色配置文件进行注释,上述内容为修改后日志为蓝色 -********** -""" - -# import 
logging -# -# from fastapi import FastAPI -# from tortoise.contrib.fastapi import register_tortoise -# -# from src.settings.config import settings -# -# DATABASE_URL = str(settings.DATABASE_URL) -# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) -# -# # will print debug sql -# logging.basicConfig() -# logger_db_client = logging.getLogger("tortoise.db_client") -# logger_db_client.setLevel(logging.DEBUG) -# -# models = [ -# "src.models.test_table", -# "src.models.stock", -# "src.models.security_account", -# "src.models.order", -# "src.models.snowball", -# "src.models.backtest", -# "src.models.strategy", -# "src.models.stock_details", -# "src.models.transaction", -# ] -# -# aerich_models = models -# aerich_models.append("aerich.models") -# -# TORTOISE_ORM = { -# -# "connections": {"default": DATABASE_CREATE_URL}, -# "apps": { -# "models": { -# "models": aerich_models, -# "default_connection": "default", -# }, -# }, -# } -# -# config = { -# 'connections': { -# # Using a DB_URL string -# 'default': DATABASE_URL -# }, -# 'apps': { -# 'models': { -# 'models': models, -# 'default_connection': 'default', -# } -# }, -# 'use_tz': False, -# 'timezone': 'Asia/Shanghai' -# } -# -# -# def register_tortoise_orm(app: FastAPI): -# register_tortoise( -# app, -# config=config, -# # db_url=DATABASE_URL, -# # db_url="mysql://adams:adams@127.0.0.1:3306/adams", -# # modules={"models": models}, -# generate_schemas=False, -# add_exception_handlers=True, -# ) +import logging +from fastapi import FastAPI +from tortoise.contrib.fastapi import register_tortoise +from src.settings.config import settings + +DATABASE_URL = str(settings.DATABASE_URL) +DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) + + +# 定义不同日志级别的颜色 +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[94m", # 蓝色 + "INFO": "\033[92m", # 绿色 + "WARNING": "\033[93m", # 黄色 + "ERROR": "\033[91m", # 红色 + "CRITICAL": "\033[41m", # 红色背景 + } + RESET = "\033[0m" # 重置颜色 + + def format(self, record): + 
color = self.COLORS.get(record.levelname, self.RESET) + log_message = super().format(record) + return f"{color}{log_message}{self.RESET}" + + +# 配置日志 +logger_db_client = logging.getLogger("tortoise.db_client") +logger_db_client.setLevel(logging.DEBUG) + +# 创建一个控制台处理程序 +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) + +# 创建并设置格式化器 +formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +ch.setFormatter(formatter) + +# 将处理程序添加到日志记录器中 +logger_db_client.addHandler(ch) + +# 定义模型 +models = [ + "src.models.test_table", + "src.models.stock", + "src.models.security_account", + "src.models.order", + "src.models.snowball", + "src.models.backtest", + "src.models.strategy", + "src.models.stock_details", + "src.models.transaction", + "src.models.position", + "src.models.trand_info", + "src.models.tran_observer_data", + "src.models.tran_orders", + "src.models.tran_return", + "src.models.tran_trade_info", + "src.models.back_observed_data", + "src.models.back_observed_data_detail", + "src.models.back_position", + "src.models.back_result_indicator", + "src.models.back_trand_info", + "src.models.tran_position", + "src.models.stock_bt_history", + "src.models.stock_history", + "src.models.stock_data_processing", + "src.models.wance_data_stock", + "src.models.wance_data_storage_backtest", + "src.models.user_combination_history", + "src.models.financial_reports", + "src.models.stock_hu_shen300", + "src.models.stock_zhong_zheng_500", + "src.models.stock_guo_zheng_2000", + "src.models.stock_hu_shen_jing_a" + + +] + +aerich_models = models +aerich_models.append("aerich.models") + +# Tortoise ORM 配置 +TORTOISE_ORM = { + "connections": {"default": DATABASE_CREATE_URL, + }, + "apps": { + "models": { + "models": aerich_models, + "default_connection": "default", + }, + }, +} + +config = { + 'connections': { + 'default': DATABASE_URL + }, + 'apps': { + 'models': { + 'models': models, + 'default_connection': 'default', + } + }, + 'use_tz': False, + 
'timezone': 'Asia/Shanghai' +} + + +def register_tortoise_orm(app: FastAPI): + # 注册 Tortoise ORM 配置 + register_tortoise( + app, + config=config, + generate_schemas=False, + add_exception_handlers=True, + ) + + +""" +********** +以下内容对原有无改色配置文件进行注释,上述内容为修改后日志为蓝色 +********** +""" + +# import logging +# +# from fastapi import FastAPI +# from tortoise.contrib.fastapi import register_tortoise +# +# from src.settings.config import settings +# +# DATABASE_URL = str(settings.DATABASE_URL) +# DATABASE_CREATE_URL = str(settings.DATABASE_CREATE_URL) +# +# # will print debug sql +# logging.basicConfig() +# logger_db_client = logging.getLogger("tortoise.db_client") +# logger_db_client.setLevel(logging.DEBUG) +# +# models = [ +# "src.models.test_table", +# "src.models.stock", +# "src.models.security_account", +# "src.models.order", +# "src.models.snowball", +# "src.models.backtest", +# "src.models.strategy", +# "src.models.stock_details", +# "src.models.transaction", +# ] +# +# aerich_models = models +# aerich_models.append("aerich.models") +# +# TORTOISE_ORM = { +# +# "connections": {"default": DATABASE_CREATE_URL}, +# "apps": { +# "models": { +# "models": aerich_models, +# "default_connection": "default", +# }, +# }, +# } +# +# config = { +# 'connections': { +# # Using a DB_URL string +# 'default': DATABASE_URL +# }, +# 'apps': { +# 'models': { +# 'models': models, +# 'default_connection': 'default', +# } +# }, +# 'use_tz': False, +# 'timezone': 'Asia/Shanghai' +# } +# +# +# def register_tortoise_orm(app: FastAPI): +# register_tortoise( +# app, +# config=config, +# # db_url=DATABASE_URL, +# # db_url="mysql://adams:adams@127.0.0.1:3306/adams", +# # modules={"models": models}, +# generate_schemas=False, +# add_exception_handlers=True, +# ) diff --git a/src/utils/__pycache__/paginations.cpython-310.pyc b/src/utils/__pycache__/paginations.cpython-310.pyc index 82f7cd1..58023cf 100644 Binary files a/src/utils/__pycache__/paginations.cpython-310.pyc and 
b/src/utils/__pycache__/paginations.cpython-310.pyc differ diff --git a/src/utils/paginations.py b/src/utils/paginations.py index 21cf264..116db3e 100644 --- a/src/utils/paginations.py +++ b/src/utils/paginations.py @@ -1,74 +1,66 @@ -from typing import Generic, Optional, Sequence, TypeVar - -import math -from fastapi import Query -from pydantic import BaseModel -from tortoise.queryset import QuerySet - -T = TypeVar("T") - - -class PaginationPydantic(BaseModel, Generic[T]): - """分页模型""" - status_code: int = 200 - message: str = "Success" - total: int - page: int - size: int - total_pages: int - data: Sequence[T] - - -class Params(BaseModel): - """传参""" - # 设置默认值为1,不能够小于1 - page: int = Query(1, ge=1, description="Page number") - - # 设置默认值为10,最大为100 - size: int = Query(10, gt=0, le=200, description="Page size") - - # 默认值None表示选传 - order_by: Optional[str] = Query(None, max_length=32, description="Sort key") - - -async def pagination(pydantic_model, query_set: QuerySet, params: Params, callback=None): - """分页响应""" - """ - pydantic_model: Pydantic model - query_set: QuerySet - params: Params - callback: if you want to do something for query_set,it will be useful. 
- """ - page: int = params.page - size: int = params.size - order_by: str = params.order_by - total = await query_set.count() - - # 通过总数和每页数量计算出总页数 - total_pages = math.ceil(total / size) - - if page > total_pages and total: # 排除查询集为空时报错,即total=0时 - raise ValueError("页数输入有误") - - # 排序后分页 - if order_by: - order_by = order_by.split(',') - query_set = query_set.order_by(*order_by) - - # 分页 - query_set = query_set.offset((page - 1) * size) # 页数 * 页面大小=偏移量 - query_set = query_set.limit(size) - - if callback: - """对查询集操作""" - query_set = await callback(query_set) - - # query_set = await query_set - data = await pydantic_model.from_queryset(query_set) - return PaginationPydantic(**{ - "total": total, - "page": page, - "size": size, - "total_pages": total_pages, - "data": data, - }) +from typing import Type, TypeVar, Generic, Sequence, Optional, Callable, Awaitable +import math +from fastapi import Query +from pydantic import BaseModel +from tortoise.queryset import QuerySet +from tortoise.models import Model + +T = TypeVar('T', bound=BaseModel) +M = TypeVar('M', bound=Model) + +class PaginationPydantic(BaseModel, Generic[T]): + """分页模型""" + status_code: int = 200 + message: str = "Success" + total: int + page: int + size: int + total_pages: int + data: Sequence[T] + +class Params(BaseModel): + """传参""" + page: int = Query(1, ge=1, description="Page number") + size: int = Query(10, gt=0, le=200, description="Page size") + order_by: Optional[str] = Query(None, max_length=32, description="Sort key") + +async def pagination( + pydantic_model: Type[T], + query_set: QuerySet[M], + params: Params, + callback: Optional[Callable[[QuerySet[M]], Awaitable[QuerySet[M]]]] = None +) -> PaginationPydantic[T]: + """分页响应""" + page: int = params.page + size: int = params.size + order_by: Optional[str] = params.order_by + total = await query_set.count() + + # 计算总页数 + total_pages = math.ceil(total / size) if total > 0 else 1 + + if page > total_pages and total > 0: + raise 
"""
Incremental spider for financial research reports.

TODO: the logic for refreshing the full dataset is still to be completed.
"""


import time
import requests
import pandas as pd
from fastapi import APIRouter
from src.models.financial_reports import FinancialReport  # make sure this import path is correct


async def combined_search_and_list():
    """Fetch new financial-report records from the remote API, insert the ones
    missing from the database, then remove duplicate rows.

    Workflow:
        1. Ask the API for the total record count and compare with the DB.
        2. Page through the API, caching records, until a page predates the
           newest record already stored.
        3. Bulk-insert records not already present.
        4. Deduplicate on (date, period, stock_code), keeping the newest row.

    Returns:
        A completion-message dict, or None when the DB is already up to date.
    """
    # Paging parameters for the remote API.
    pageNo = 1
    pageSize = 50
    selector_url = "https://www.shidaotec.com/api/research/getFinaResearchPageList"

    # Request the first page to learn the total record count.
    # timeout guards against a stalled connection hanging the coroutine.
    response = requests.get(selector_url, params={
        "pageNo": pageNo,
        "pageSize": pageSize,
        "year": "2024",
        "type": ""
    }, timeout=30)
    response.raise_for_status()
    selector_data = response.json()

    # Compare the remote total against what the database already holds.
    total_records = selector_data.get("data", {}).get("total", 0)
    db_record_count = await FinancialReport.all().count()
    remaining_insert_count = total_records - db_record_count

    print(f"总记录数: {total_records},数据库已有记录数: {db_record_count},需要插入: {remaining_insert_count}")

    # Nothing to do when the DB already contains every remote record.
    if remaining_insert_count <= 0:
        print("数据库已包含所有记录,无需插入新数据。")
        return

    # Newest record already stored; None when the table is empty.
    # (The None guard fixes a crash: pd.Timestamp(None) raises, and a later
    # date comparison against None would TypeError on a fresh database.)
    latest_record = await FinancialReport.all().order_by("-date").first()
    latest_db_date = pd.Timestamp(latest_record.date) if latest_record else None

    # Cache of fetched records awaiting insertion.
    all_data = []

    # Page through the API until we reach data older than what we hold.
    while True:
        try:
            response = requests.get(selector_url, params={
                "pageNo": pageNo,
                "pageSize": pageSize,
                "year": "2024",
                "type": ""
            }, timeout=30)
            response.raise_for_status()
            selector_data = response.json()
            fina_data_list = selector_data.get("data", {}).get("list", [])

            # An empty page means the API is exhausted.
            if not fina_data_list:
                print("No more data available.")
                break

            # Cache this page: drop the unused "seq" field and map the API's
            # "-" placeholder to None.
            for record in fina_data_list:
                record.pop("seq", None)
                record = {k: (None if v == "-" else v) for k, v in record.items()}
                all_data.append(record)

            # If the first record of this page predates our newest stored
            # record, everything from here on is already known — stop.
            first_record_date = pd.to_datetime(fina_data_list[0]["date"])
            if latest_db_date is not None and first_record_date < latest_db_date:
                print(f"Data on page {pageNo} is older than the latest entry in the database. Stopping requests.")
                break

            print(f"第 {pageNo} 页成功获取并缓存 {len(fina_data_list)} 条记录")
            pageNo += 1
            time.sleep(2)  # polite crawl delay between requests

        except requests.RequestException as e:
            print(f"Error fetching data for page {pageNo}: {e}")
            break
        except Exception as e:
            print(f"Unexpected error on page {pageNo}: {e}")
            break

    # Bulk-fetch DB records that may collide with the cached keys.
    existing_records = await FinancialReport.filter(
        date__in=[record["date"] for record in all_data],
        period__in=[record["period"] for record in all_data],
        stock_code__in=[record["stockCode"] for record in all_data]
    ).all()

    # NOTE(review): this set holds DB-typed values (dates/floats) while the
    # membership test below builds tuples of raw API strings, so exact
    # matches may be rare — confirm the intended comparison semantics.
    existing_data_set = {
        (record.date, record.period, record.stock_code, record.stock_name,
         record.industry_code, record.industry_name, record.pdf_name,
         record.pdf_url, record.revenue, record.revenue_yoy,
         record.pr_of_today, record.pr_of_after_1_day,
         record.pr_of_after_1_week, record.pr_of_after_1_month,
         record.pr_of_after_6_month, record.pr_of_that_year,
         record.nincome_yoy_lower_limit, record.nincome_yoy_upper_limit,
         record.nincome, record.nincome_yoy)
        for record in existing_records
    }

    # Build the list of rows to insert from the cached data.
    success_count = 0
    records_to_insert = []

    for record in all_data:
        record_date = pd.to_datetime(record["date"])

        # Skip records that already exist in the database.
        if (record["date"], record["period"], record["stockCode"], record["stockName"],
                record["industryCode"], record["industryName"], record["pdfName"], record["pdfUrl"],
                record["revenue"], record["revenueYoy"], record["prOfToday"], record["prOfAfter1Day"],
                record["prOfAfter1Week"], record["prOfAfter1Month"], record["prOfAfter6Month"],
                record["prOfThatYear"], record["nincomeYoyLowerLimit"], record["nincomeYoyUpperLimit"],
                record["nincome"], record["nincomeYoy"]) in existing_data_set:
            print(f"Skipping existing record: {record}")
            continue

        print(f"Inserting record: {record}")

        # Insert records at or after the newest stored date; when the table
        # is empty (latest_db_date is None), every record qualifies.
        if latest_db_date is None or record_date >= latest_db_date:
            try:
                records_to_insert.append(
                    FinancialReport(
                        date=record["date"],
                        period=record["period"],
                        year=int(record["year"]),
                        stock_code=record["stockCode"],
                        stock_name=record["stockName"],
                        industry_code=record["industryCode"] if record["industryCode"] else None,
                        industry_name=record["industryName"] if record["industryName"] else None,
                        pdf_name=record["pdfName"] if record["pdfName"] else None,
                        pdf_url=record["pdfUrl"] if record["pdfUrl"] else None,
                        revenue=float(record["revenue"]) if record["revenue"] else None,
                        revenue_yoy=float(record["revenueYoy"]) if record["revenueYoy"] else None,
                        pr_of_today=float(record["prOfToday"]) if record["prOfToday"] else None,
                        pr_of_after_1_day=float(record["prOfAfter1Day"]) if record["prOfAfter1Day"] else None,
                        pr_of_after_1_week=float(record["prOfAfter1Week"]) if record["prOfAfter1Week"] else None,
                        pr_of_after_1_month=float(record["prOfAfter1Month"]) if record["prOfAfter1Month"] else None,
                        pr_of_after_6_month=float(record["prOfAfter6Month"]) if record["prOfAfter6Month"] else None,
                        pr_of_that_year=float(record["prOfThatYear"]) if record["prOfThatYear"] else None,
                        nincome_yoy_lower_limit=record["nincomeYoyLowerLimit"] if record["nincomeYoyLowerLimit"] else None,
                        nincome_yoy_upper_limit=record["nincomeYoyUpperLimit"] if record["nincomeYoyUpperLimit"] else None,
                        nincome=float(record["nincome"]) if record["nincome"] else None,
                        nincome_yoy=float(record["nincomeYoy"]) if record["nincomeYoy"] else None
                    )
                )
                success_count += 1

            except Exception as e:
                print(f"Error processing record {record}: {e}")

        # Stop once we have queued the required number of inserts.
        if success_count >= remaining_insert_count:
            print("this riqi:" + str(latest_db_date))
            break

    # Bulk-insert the queued rows.
    if records_to_insert:
        try:
            await FinancialReport.bulk_create(records_to_insert)
            print(f"成功插入 {len(records_to_insert)} 条记录")
        except Exception as e:
            print(f"Error during bulk insert: {e}")

    # Deduplication: delete all but the newest (max id) row for each
    # (date, period, stock_code) key.
    try:
        duplicates = await FinancialReport.raw("""
        SELECT id FROM financial_reports WHERE id NOT IN (
            SELECT id FROM (
                SELECT MAX(id) AS id
                FROM financial_reports
                GROUP BY date, period, stock_code
            ) AS temp
        )
        """)

        duplicate_ids = [record.id for record in duplicates]

        if duplicate_ids:
            await FinancialReport.filter(id__in=duplicate_ids).delete()
            print(f"删除了 {len(duplicate_ids)} 条重复记录")
        else:
            print("未发现重复记录")

    except Exception as e:
        print(f"去重过程中出错: {e}")

    return {"message": "数据获取、插入和去重已完成"}