200 lines
9.5 KiB
Python
200 lines
9.5 KiB
Python
from math import ceil
|
||
from fastapi import Request, HTTPException, BackgroundTasks
|
||
from tortoise.expressions import Subquery,Q
|
||
from tortoise.contrib.pydantic import pydantic_model_creator
|
||
|
||
from src.models.financial_reports import *
|
||
from src.models.stock_hu_shen300 import *
|
||
from src.models.stock_zhong_zheng_500 import *
|
||
from src.models.stock_guo_zheng_2000 import *
|
||
from src.models.stock_hu_shen_jing_a import *
|
||
from src.models.stock import *
|
||
from src.financial_reports.schemas import *
|
||
from src.utils.paginations import *
|
||
from src.financial_reports.mappings import *
|
||
from src.utils.identify_keyword_type import * #引入判断汉字,拼音,数字的工具类
|
||
|
||
|
||
# 创建一个不包含 "created_at" 字段的 Pydantic 模型用于响应
|
||
FinancialReport_Pydantic = pydantic_model_creator(FinancialReport, exclude=("created_at",))
|
||
|
||
async def financial_reports_query_service(
|
||
request: FinancialReportQuery
|
||
) -> PaginationPydantic:
|
||
"""
|
||
根据选择器和关键词查询财报,并进行分页
|
||
"""
|
||
# 构建查询集,这里假设财报数据模型为 FinancialReport
|
||
query_set: QuerySet = FinancialReport.all()
|
||
|
||
try:
|
||
# 判空处理,如果 searchKeyword 没有值,跳过关键词筛选逻辑
|
||
if not request.searchKeyword:
|
||
pass
|
||
else:
|
||
# 根据 searchKeyword 的类型选择不同的查询字段
|
||
search_keyword = request.searchKeyword
|
||
keyword_type = identify_keyword_type_simple(search_keyword)
|
||
|
||
# 根据类型选择相应的字段进行查询
|
||
if keyword_type == "pinyin":
|
||
# 调用 user_keyword_search_service 获取匹配的股票信息列表
|
||
matching_stocks = await user_keyword_search_service(search_keyword)
|
||
|
||
# 提取股票代码
|
||
matching_stock_codes = [stock["stock_code"] for stock in matching_stocks]
|
||
if matching_stock_codes:
|
||
# 根据返回的股票代码进行筛选
|
||
query_set = query_set.filter(stock_code__in=matching_stock_codes)
|
||
else:
|
||
# 如果没有匹配的股票代码,返回空查询
|
||
query_set = query_set.none()
|
||
elif keyword_type == "code":
|
||
# 模糊查询代码
|
||
query_set = query_set.filter(stock_code__icontains=search_keyword)
|
||
elif keyword_type == "chinese":
|
||
# 模糊查询股票名称字段
|
||
query_set = query_set.filter(stock_name__icontains=search_keyword)
|
||
else:
|
||
raise ValueError("无效的关键词类型")
|
||
|
||
# 年度映射
|
||
if request.year:
|
||
mapped_year = year_mapping.get(request.year)
|
||
if mapped_year is not None:
|
||
query_set = query_set.filter(year=mapped_year)
|
||
else:
|
||
raise ValueError("无效的年份参数")
|
||
|
||
# 股票池映射
|
||
if request.stockPoolCode is None:
|
||
# 未提供股票池代码,返回所有记录
|
||
pass
|
||
else:
|
||
# 获取映射值,如果不存在则返回 None
|
||
mapped_stock_pool = stock_pool_mapping.get(request.stockPoolCode)
|
||
|
||
# 检查 mapped_stock_pool 是否为 None(即不在映射表中)
|
||
if mapped_stock_pool is None:
|
||
# 如果 stockPoolCode 不在映射表中,抛出无效股票池参数错误
|
||
raise ValueError("无效的股票池参数")
|
||
|
||
# 如果存在有效的映射值,执行相应的过滤
|
||
elif mapped_stock_pool == "000300.SH":
|
||
subquery = await StockHuShen300.filter().values_list('code', flat=True)
|
||
query_set = query_set.filter(stock_code__in=subquery)
|
||
elif mapped_stock_pool == "000905.SH":
|
||
subquery = await StockZhongZheng500.filter().values_list('code', flat=True)
|
||
query_set = query_set.filter(stock_code__in=subquery)
|
||
elif mapped_stock_pool == "399303.SZ":
|
||
subquery = await StockGuoZheng2000.filter().values_list('code', flat=True)
|
||
query_set = query_set.filter(stock_code__in=subquery)
|
||
elif mapped_stock_pool in ["large_mv", "medium_mv", "small_mv"]:
|
||
# 先获取所有有总市值的关联的股票代码
|
||
subquery = await StockHuShenJingA.filter(total_market_value__isnull=False).values_list('code', 'total_market_value')
|
||
|
||
# 转换为 DataFrame 或者直接进行排序
|
||
stock_list = sorted(subquery, key=lambda x: x[1], reverse=True) # 按 total_market_value 降序排序
|
||
|
||
total_count = len(stock_list)
|
||
|
||
if mapped_stock_pool == "large_mv":
|
||
# 获取前 30% 的数据
|
||
limit_count = ceil(total_count * 0.3)
|
||
selected_stocks = [stock[0] for stock in stock_list[:limit_count]]
|
||
elif mapped_stock_pool == "medium_mv":
|
||
# 获取中间 40% 的数据
|
||
start_offset = ceil(total_count * 0.3)
|
||
limit_count = ceil(total_count * 0.4)
|
||
selected_stocks = [stock[0] for stock in stock_list[start_offset:start_offset + limit_count]]
|
||
elif mapped_stock_pool == "small_mv":
|
||
# 获取后 30% 的数据
|
||
start_offset = ceil(total_count * 0.7)
|
||
selected_stocks = [stock[0] for stock in stock_list[start_offset:]]
|
||
|
||
# 对 FinancialReport 表进行筛选
|
||
query_set = query_set.filter(stock_code__in=selected_stocks)
|
||
|
||
# 财报周期映射
|
||
if request.period is not None:
|
||
mapped_period = period_mapping.get(request.period)
|
||
if mapped_period is not None:
|
||
# 如果映射到有效的周期,则进行过滤
|
||
query_set = query_set.filter(period=mapped_period)
|
||
else:
|
||
# 如果找不到有效的映射,抛出错误
|
||
raise ValueError("无效的财报周期参数")
|
||
|
||
# 检查是否所有筛选条件都为空
|
||
if not request.revenueYoyType and not request.nIncomeYoyType and not request.revenueYoyTop10:
|
||
# 如果全部为空,则按 date 和 created_at 降序排序
|
||
query_set = query_set.order_by("-date", "-created_at")
|
||
|
||
# 筛选 revenueYoyType,如果是 1 则筛选 revenue_yoy 大于 10.0 的记录
|
||
if request.revenueYoyType == 1:
|
||
query_set = query_set.filter(revenue_yoy__gt=10.0).order_by("-revenue_yoy")
|
||
|
||
# 筛选 nIncomeYoyType,如果是 1 则筛选 nincome_yoy 大于 10.0 的记录
|
||
elif request.nIncomeYoyType == 1:
|
||
query_set = query_set.filter(nincome_yoy__gt=10.0).order_by("-nincome_yoy")
|
||
|
||
# 如果 revenueYoyTop10 为 1,则筛选 revenue_yoy 前 10% 的记录
|
||
elif request.revenueYoyTop10 == 1:
|
||
# 计算前 10% 的数量
|
||
total_count = await FinancialReport.all().count() # 获取总记录数
|
||
limit_count = ceil(total_count * 0.1)
|
||
|
||
# 按 revenue_yoy 降序排列,获取前 10% 记录
|
||
query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy")
|
||
|
||
|
||
|
||
# 调用分页函数进行分页处理
|
||
params = Params(page=request.pageNo, size=request.pageSize) #parms作用是传页码和页面大小
|
||
result = await pagination(pydantic_model=FinancialReport_Pydantic,
|
||
query_set=query_set,
|
||
params=params)
|
||
|
||
return result
|
||
|
||
except ValueError as e:
|
||
raise HTTPException(status_code=400, detail=str(e))
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail="内部服务器错误")
|
||
|
||
|
||
|
||
async def user_keyword_search_service(searchKeyword: Optional[str] = None) -> list:
|
||
"""
|
||
根据用户输入的关键词实时模糊匹配股票信息,支持股票代码、拼音简拼和汉字名称。
|
||
:param searchKeyword: 用户输入的关键词,支持股票代码、拼音简拼和汉字。
|
||
:return: 匹配的股票列表。
|
||
"""
|
||
if not searchKeyword:
|
||
# 如果搜索关键词为空,返回空列表
|
||
return []
|
||
|
||
# 确定关键词类型(拼音简写、股票代码或汉字)
|
||
from src.utils.identify_keyword_type import identify_keyword_type_simple
|
||
keyword_type = identify_keyword_type_simple(searchKeyword)
|
||
|
||
# 构建查询集
|
||
query_set: QuerySet = Stock.all()
|
||
|
||
if keyword_type == "pinyin":
|
||
# 如果关键词是拼音简写,则匹配 stock_pinyin
|
||
query_set = query_set.filter(stock_pinyin__icontains=searchKeyword)
|
||
elif keyword_type == "code":
|
||
# 如果关键词是股票代码,则匹配 stock_code
|
||
query_set = query_set.filter(stock_code__icontains=searchKeyword)
|
||
elif keyword_type == "chinese":
|
||
# 如果关键词是汉字,则匹配 stock_name
|
||
query_set = query_set.filter(stock_name__icontains=searchKeyword)
|
||
else:
|
||
# 如果关键词类型无法识别,返回空列表
|
||
return []
|
||
|
||
# 获取匹配的股票列表,限制返回的数量(例如最多返回 20 条数据)
|
||
matching_stocks = await query_set.values("stock_name","stock_code")
|
||
|
||
return matching_stocks |