新增识别关键词类工具类,新增根据股票代码生成简拼列工具类(待抽象),补全财报接口搜索和模糊搜索逻辑
This commit is contained in:
parent
cfa9ba98d7
commit
0a749ff31a
Binary file not shown.
@ -8,6 +8,18 @@ financial_reports_router = APIRouter()
|
|||||||
|
|
||||||
@financial_reports_router.post("/query")
|
@financial_reports_router.post("/query")
|
||||||
async def financial_repoets_query(request: FinancialReportQuery )-> JSONResponse:
|
async def financial_repoets_query(request: FinancialReportQuery )-> JSONResponse:
|
||||||
|
"""
|
||||||
|
搜索接口
|
||||||
|
"""
|
||||||
result = await financial_reports_query_service(request)
|
result = await financial_reports_query_service(request)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@financial_reports_router.post("/keyword_search")
|
||||||
|
async def financial_repoets_query_keyword_search(query: SearchKeywordQuery) -> list:
|
||||||
|
"""
|
||||||
|
模糊查询接口
|
||||||
|
"""
|
||||||
|
result = await user_keyword_search_service(query.searchKeyword)
|
||||||
|
|
||||||
return result
|
return result
|
@ -13,4 +13,7 @@ class FinancialReportQuery(BaseModel):
|
|||||||
revenueYoyType: Optional[int] = Field(None, ge=1, le=10, description="营业收入同比,必须大于等于1且小于等于10,或留空")
|
revenueYoyType: Optional[int] = Field(None, ge=1, le=10, description="营业收入同比,必须大于等于1且小于等于10,或留空")
|
||||||
nIncomeYoyType: Optional[int] = Field(None, ge=1, le=10, description="净利润同比,必须大于等于1且小于等于10,或留空")
|
nIncomeYoyType: Optional[int] = Field(None, ge=1, le=10, description="净利润同比,必须大于等于1且小于等于10,或留空")
|
||||||
revenueYoyTop10: Optional[int] = Field(None, ge=1, le=10, description="营业收入行业前10%,必须大于等于1且小于等于10,或留空")
|
revenueYoyTop10: Optional[int] = Field(None, ge=1, le=10, description="营业收入行业前10%,必须大于等于1且小于等于10,或留空")
|
||||||
|
searchKeyword: Optional[str] = Field(None, min_length=1, max_length=30, description="搜索关键词,可选,长度在1到30字符之间")
|
||||||
|
|
||||||
|
class SearchKeywordQuery(BaseModel):
|
||||||
|
searchKeyword: Optional[str] = Field(None, min_length=1, max_length=30, description="搜索关键词,可选,长度在1到30字符之间")
|
@ -1,8 +1,4 @@
|
|||||||
import json
|
|
||||||
import adata
|
|
||||||
|
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Any, Dict
|
|
||||||
from fastapi import Request, HTTPException, BackgroundTasks
|
from fastapi import Request, HTTPException, BackgroundTasks
|
||||||
from tortoise.expressions import Subquery,Q
|
from tortoise.expressions import Subquery,Q
|
||||||
from tortoise.contrib.pydantic import pydantic_model_creator
|
from tortoise.contrib.pydantic import pydantic_model_creator
|
||||||
@ -12,9 +8,11 @@ from src.models.stock_hu_shen300 import *
|
|||||||
from src.models.stock_zhong_zheng_500 import *
|
from src.models.stock_zhong_zheng_500 import *
|
||||||
from src.models.stock_guo_zheng_2000 import *
|
from src.models.stock_guo_zheng_2000 import *
|
||||||
from src.models.stock_hu_shen_jing_a import *
|
from src.models.stock_hu_shen_jing_a import *
|
||||||
|
from src.models.stock import *
|
||||||
from src.financial_reports.schemas import *
|
from src.financial_reports.schemas import *
|
||||||
from src.utils.paginations import *
|
from src.utils.paginations import *
|
||||||
from src.financial_reports.mappings import *
|
from src.financial_reports.mappings import *
|
||||||
|
from src.utils.identify_keyword_type import * #引入判断汉字,拼音,数字的工具类
|
||||||
|
|
||||||
|
|
||||||
# 创建一个不包含 "created_at" 字段的 Pydantic 模型用于响应
|
# 创建一个不包含 "created_at" 字段的 Pydantic 模型用于响应
|
||||||
@ -27,9 +25,40 @@ async def financial_reports_query_service(
|
|||||||
根据选择器和关键词查询财报,并进行分页
|
根据选择器和关键词查询财报,并进行分页
|
||||||
"""
|
"""
|
||||||
# 构建查询集,这里假设财报数据模型为 FinancialReport
|
# 构建查询集,这里假设财报数据模型为 FinancialReport
|
||||||
query_set: QuerySet = FinancialReport.all()
|
query_set: QuerySet = FinancialReport.all()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 判空处理,如果 searchKeyword 没有值,跳过关键词筛选逻辑
|
||||||
|
if not request.searchKeyword:
|
||||||
|
# 如果 searchKeyword 为空,跳过关键词筛选逻辑
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# 根据 searchKeyword 的类型选择不同的查询字段
|
||||||
|
search_keyword = request.searchKeyword
|
||||||
|
keyword_type = identify_keyword_type_simple(search_keyword)
|
||||||
|
|
||||||
|
# 根据类型选择相应的字段进行查询
|
||||||
|
if keyword_type == "pinyin":
|
||||||
|
# 调用 user_keyword_search_service 获取匹配的股票信息列表
|
||||||
|
matching_stocks = await user_keyword_search_service(search_keyword)
|
||||||
|
|
||||||
|
# 提取股票代码
|
||||||
|
matching_stock_codes = [stock["stock_code"] for stock in matching_stocks]
|
||||||
|
if matching_stock_codes:
|
||||||
|
# 根据返回的股票代码进行筛选
|
||||||
|
query_set = query_set.filter(stock_code__in=matching_stock_codes)
|
||||||
|
else:
|
||||||
|
# 如果没有匹配的股票代码,返回空查询
|
||||||
|
query_set = query_set.none()
|
||||||
|
elif keyword_type == "code":
|
||||||
|
# 精确查询股票代码字段
|
||||||
|
query_set = query_set.filter(stock_code=search_keyword)
|
||||||
|
elif keyword_type == "chinese":
|
||||||
|
# 模糊查询股票名称字段
|
||||||
|
query_set = query_set.filter(stock_name__icontains=search_keyword)
|
||||||
|
else:
|
||||||
|
raise ValueError("无效的关键词类型")
|
||||||
|
|
||||||
# 年度映射
|
# 年度映射
|
||||||
if request.year:
|
if request.year:
|
||||||
mapped_year = year_mapping.get(request.year)
|
mapped_year = year_mapping.get(request.year)
|
||||||
@ -120,6 +149,7 @@ async def financial_reports_query_service(
|
|||||||
query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy")
|
query_set = query_set.order_by("-revenue_yoy").limit(limit_count).order_by("-revenue_yoy")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 调用分页函数进行分页处理
|
# 调用分页函数进行分页处理
|
||||||
params = Params(page=request.pageNo, size=request.pageSize) #parms作用是传页码和页面大小
|
params = Params(page=request.pageNo, size=request.pageSize) #parms作用是传页码和页面大小
|
||||||
result = await pagination(pydantic_model=FinancialReport_Pydantic,
|
result = await pagination(pydantic_model=FinancialReport_Pydantic,
|
||||||
@ -131,4 +161,41 @@ async def financial_reports_query_service(
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="内部服务器错误")
|
raise HTTPException(status_code=500, detail="内部服务器错误")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def user_keyword_search_service(searchKeyword: Optional[str] = None) -> list:
|
||||||
|
"""
|
||||||
|
根据用户输入的关键词实时模糊匹配股票信息,支持股票代码、拼音简拼和汉字名称。
|
||||||
|
:param searchKeyword: 用户输入的关键词,支持股票代码、拼音简拼和汉字。
|
||||||
|
:return: 匹配的股票列表。
|
||||||
|
"""
|
||||||
|
if not searchKeyword:
|
||||||
|
# 如果搜索关键词为空,返回空列表
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 确定关键词类型(拼音简写、股票代码或汉字)
|
||||||
|
from src.utils.identify_keyword_type import identify_keyword_type_simple
|
||||||
|
keyword_type = identify_keyword_type_simple(searchKeyword)
|
||||||
|
|
||||||
|
# 构建查询集
|
||||||
|
query_set: QuerySet = Stock.all()
|
||||||
|
|
||||||
|
if keyword_type == "pinyin":
|
||||||
|
# 如果关键词是拼音简写,则匹配 stock_pinyin
|
||||||
|
query_set = query_set.filter(stock_pinyin__icontains=searchKeyword)
|
||||||
|
elif keyword_type == "code":
|
||||||
|
# 如果关键词是股票代码,则匹配 stock_code
|
||||||
|
query_set = query_set.filter(stock_code__icontains=searchKeyword)
|
||||||
|
elif keyword_type == "chinese":
|
||||||
|
# 如果关键词是汉字,则匹配 stock_name
|
||||||
|
query_set = query_set.filter(stock_name__icontains=searchKeyword)
|
||||||
|
else:
|
||||||
|
# 如果关键词类型无法识别,返回空列表
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取匹配的股票列表,限制返回的数量(例如最多返回 20 条数据)
|
||||||
|
matching_stocks = await query_set.values("stock_name","stock_code")
|
||||||
|
|
||||||
|
return matching_stocks
|
@ -22,6 +22,7 @@ from src.combination.router import router as combine_router
|
|||||||
from src.models.financial_reports import FinancialReport
|
from src.models.financial_reports import FinancialReport
|
||||||
from src.utils.update_financial_reports_spider import combined_search_and_list
|
from src.utils.update_financial_reports_spider import combined_search_and_list
|
||||||
from src.financial_reports.router import financial_reports_router
|
from src.financial_reports.router import financial_reports_router
|
||||||
|
from src.utils.generate_pinyin_abbreviation import generate_pinyin_abbreviation
|
||||||
|
|
||||||
from xtquant import xtdata
|
from xtquant import xtdata
|
||||||
from src.settings.config import app_configs, settings
|
from src.settings.config import app_configs, settings
|
||||||
@ -89,7 +90,6 @@ async def test_liang_hua_ku():
|
|||||||
print("量化库测试函数启动")
|
print("量化库测试函数启动")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 启动时立即运行检查任务
|
# 启动时立即运行检查任务
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def lifespan():
|
async def lifespan():
|
||||||
|
14
src/utils/generate_pinyin_abbreviation.py
Normal file
14
src/utils/generate_pinyin_abbreviation.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
from src.models.stock import Stock
|
||||||
|
|
||||||
|
async def generate_pinyin_abbreviation():
|
||||||
|
"""
|
||||||
|
更新数据库中所有股票记录的拼音缩写,如果字段为空则生成并保存。
|
||||||
|
"""
|
||||||
|
stocks = await Stock.filter(stock_pinyin__isnull=True) # 查找拼音字段为空的记录
|
||||||
|
for stock in stocks:
|
||||||
|
if stock.stock_name: # 确保股票名称存在
|
||||||
|
# 将每个汉字的第一个字母转为小写,并合并为拼音缩写
|
||||||
|
stock.stock_pinyin = ''.join(lazy_pinyin(stock.stock_name, style=Style.FIRST_LETTER)).lower()
|
||||||
|
await stock.save() # 保存拼音缩写到数据库
|
||||||
|
print("拼音更新完成")
|
19
src/utils/identify_keyword_type.py
Normal file
19
src/utils/identify_keyword_type.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
判断关键词类的工具类,选中第一个字符做判断
|
||||||
|
"""
|
||||||
|
def identify_keyword_type_simple(keyword: str) -> str:
|
||||||
|
if not keyword:
|
||||||
|
return "unknown" # 如果字符串为空,返回 "unknown"
|
||||||
|
|
||||||
|
first_char = keyword[0]
|
||||||
|
|
||||||
|
# 判断第一个字符
|
||||||
|
if '\u4e00' <= first_char <= '\u9fa5': # 汉字
|
||||||
|
return "chinese"
|
||||||
|
elif first_char.isalpha(): # 字母
|
||||||
|
return "pinyin"
|
||||||
|
elif first_char.isdigit(): # 数字
|
||||||
|
return "code"
|
||||||
|
else:
|
||||||
|
return "unknown"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user