后端代码

This commit is contained in:
xlmessage 2025-02-12 19:12:07 +08:00
parent 023cc64d20
commit 62e0ab2da7
67 changed files with 1414 additions and 485 deletions

@ -1,12 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="wangche@lawyer5.cn" uuid="7e849454-3537-491d-a390-aa5024bc7f42">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://lawyer5.cn:3308/wangche</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="wangche@lawyer5.cn" uuid="7e849454-3537-491d-a390-aa5024bc7f42">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://lawyer5.cn:3308/wangche</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="@localhost" uuid="7b52e754-e1c2-4a7c-ae09-3c939af31a22">
<driver-ref>mysql.8</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>com.mysql.cj.jdbc.Driver</jdbc-driver>
<jdbc-url>jdbc:mysql://localhost:3306</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.7" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (wance_data)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (testenv) (2)" project-jdk-type="Python SDK" />
</project>

@ -4,7 +4,7 @@
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Python 3.11 (testenv) (2)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">

Binary file not shown.

915
src/akshare_data/router.py Normal file

@ -0,0 +1,915 @@
from fastapi import APIRouter, Query, FastAPI, Body
from starlette.middleware.cors import CORSMiddleware
import akshare as ak
import pymysql
import json
from pydantic import BaseModel
from typing import List, Dict
import math
router = APIRouter() # 创建一个 FastAPI 路由器实例
# 数据库测试
@router.get("/userstrategy")
async def userstrategy():
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# 使用 execute()方法执行 SQL 查询
# 通配符 *,意思是查询表里所有内容
cursor.execute("select * from user_strategy where user_id = 100096")
# 使用 fetchone() 方法获取一行数据.
# data = cursor.fetchone()
data = cursor.fetchall()
# print(data)
# 关闭数据库连接
cursor.close()
# 将strategy_request字段中的JSON字符串转换为Python字典
for i in range(0, len(data)):
strategy_request = data[i]['strategy_request']
# 将JSON字符串转换为Python字典
data_dict = json.loads(strategy_request)
data[i]['strategy_request'] = data_dict
return data
# 定义Pydantic模型
class InfoItem(BaseModel):
code: str
name: str
market: str
newprice: float
amplitudetype: bool
amplitude: float
type: str
class MessageItem(BaseModel):
id: str
label: str
value: str
state: bool
info: List[InfoItem]
inputClass: str
fixchange: bool
class StrategyRequest(BaseModel):
id: int
strategy_name: str
message: List[MessageItem] or None
info: List[InfoItem] or None
class MyData(BaseModel):
mes: List[StrategyRequest]
# 新加的new_user_strategy数据库表
@router.post("/newuserstrategy")
async def newuserstrategy(strategy: StrategyRequest = Body(...)):
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# cursor1 = conn.cursor()
# cursor2 = conn.cursor()
# --------------------new_user_strategy数据------------------
# SQL 查询语句
sql1 = "SELECT strategy_name FROM new_user_strategy WHERE id = %s"
# 执行查询
cursor.execute(sql1, ('100096',))
# 获取所有数据
result1 = cursor.fetchall()
# 提取 strategy_name 列并转换为列表
strategy_names1 = [row['strategy_name'] for row in result1]
# print(strategy_names1)
# --------------------user_strategy数据------------------
# SQL 查询语句
sql2 = "SELECT strategy_name FROM user_strategy WHERE user_id = %s"
# 执行查询
cursor.execute(sql2, ('100096',))
# 获取所有数据
result2 = cursor.fetchall()
# 提取 strategy_name 列并转换为列表
strategy_names2 = [row['strategy_name'] for row in result2]
# print(strategy_names2)
# --------------------获取整个请求数据--------------------
request_data = strategy.dict()
# print(request_data)
# 准备SQL插入语句注意没有包含id列因为它可能是自动递增的
sql = "INSERT INTO new_user_strategy (id, strategy_name, message, info) VALUES (%s, %s, %s, %s)"
# 要插入的数据(确保数据类型与数据库表列匹配)
# 将message和info转换为JSON字符串
import json
message_json = json.dumps(request_data['message'])
info_json = json.dumps(request_data['info'])
values = (request_data['id'], request_data['strategy_name'], message_json, info_json)
# 是否进行写入数据库表中
set_strategy_names1 = set(strategy_names1)
set_strategy_names2 = set(strategy_names2)
if request_data['strategy_name'] in strategy_names1:
# print("信息已存在")
conn.close()
else:
# 执行SQL语句
cursor.execute(sql, values)
# 提交事务到数据库执行
conn.commit()
print("数据插入成功")
# 关闭数据库连接
conn.close()
return {"message": "数据插入成功"}
# for i in range(0, len(result2)):
# if result2[i] not in strategy_names1:
# # 执行SQL语句
# cursor.execute(sql, values)
# # 提交事务到数据库执行
# conn.commit()
# print("数据插入成功")
# # 关闭数据库连接
# conn.close()
# return {"message": "数据插入成功"}
# conn.close()
@router.post("/newadd")
async def newadd(strategy: StrategyRequest = Body(...)):
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# SQL 查询语句
sql1 = "SELECT strategy_name FROM new_user_strategy WHERE id = %s"
# 执行查询
cursor.execute(sql1, ('100096',))
# 获取所有数据
result1 = cursor.fetchall()
# 获取整个请求数据
request_data = strategy.dict()
print(request_data)
import json
message_json = json.dumps(request_data['message'])
info_json = json.dumps(request_data['info'])
print(result1)
for item in result1:
if request_data['strategy_name'] == item['strategy_name']:
return {
"code": 204,
"message": "该分组已经存在"
}
sql = "INSERT INTO new_user_strategy (id, strategy_name, message, info) VALUES (%s, %s, %s, %s)"
# # 执行 SQL
cursor.execute(sql, (request_data["id"], request_data['strategy_name'], message_json, info_json))
#
# # 提交事务到数据库执行
conn.commit()
# print("更新数据成功")
# 关闭数据库连接
conn.close()
return {
"code": 200,
"message": "新建分组成功!"
}
# 获取数据
@router.get("/newget")
async def newget():
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# 使用 execute()方法执行 SQL 查询
# 通配符 *,意思是查询表里所有内容
cursor.execute("select * from new_user_strategy where id = 100096")
# 使用 fetchone() 方法获取一行数据.
# data = cursor.fetchone()
data = cursor.fetchall()
# print(data)
# 关闭数据库连接
cursor.close()
# 将strategy_request字段中的JSON字符串转换为Python字典
for i in range(0, len(data)):
strategy_request1 = data[i]['message']
strategy_request2 = data[i]['info']
# 将JSON字符串转换为Python字典
data_dict1 = json.loads(strategy_request1)
data_dict2 = json.loads(strategy_request2)
data[i]['message'] = data_dict1
data[i]['info'] = data_dict2
return {
"code": 200,
"data": data
}
# 新增分组
@router.post("/newupdata")
async def newupdata(strategy: StrategyRequest = Body(...)):
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# 获取整个请求数据
request_data = strategy.dict()
print(request_data)
import json
message_json = json.dumps(request_data['message'])
info_json = json.dumps(request_data['info'])
# SQL 语句
sql = """
UPDATE new_user_strategy
SET message = %s, info = %s
WHERE strategy_name = %s;
"""
# 执行 SQL
cursor.execute(sql, (message_json, info_json, request_data['strategy_name']))
# 提交事务到数据库执行
conn.commit()
print("更新数据成功")
# 关闭数据库连接
conn.close()
return "更新成功"
class delItem(BaseModel):
strategy_name: str
# 删除分组
@router.post("/newdel")
async def newdel(delitem: delItem):
delitem = delitem.strategy_name
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
sql1 = "DELETE FROM new_user_strategy WHERE strategy_name = %s"
sql2 = "DELETE FROM user_strategy WHERE strategy_name = %s"
cursor.execute(sql1, (delitem,))
cursor.execute(sql2, (delitem,))
# 提交事务到数据库执行
conn.commit()
print("数据删除成功")
# 关闭数据库连接
conn.close()
return "删除成功"
@router.get("/newmodify")
async def newmodify(
strategy_name: str = Query(..., description="原始值"),
new_strategy_name: str = Query(..., description="更改值")
):
print(strategy_name)
print(new_strategy_name)
# return "success"
# pass
# 创建数据库连接
conn = pymysql.connect(
host='cqxqg.tech', # MySQL服务器地址
user='wangche', # MySQL用户名
password='fN7sXX8saiQKXWbG', # MySQL密码
database='wangche', # 要连接的数据库名
port=3308,
charset='utf8mb4', # 字符集,确保支持中文等
cursorclass=pymysql.cursors.DictCursor # 使用字典形式返回结果
)
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = conn.cursor()
# 更新 strategy_name
update_sql = "UPDATE new_user_strategy SET strategy_name = %s WHERE strategy_name = %s"
cursor.execute(update_sql, (new_strategy_name, strategy_name))
# 提交事务到数据库执行
conn.commit()
print("重命名成功")
# 关闭数据库连接
conn.close()
return "重命名成功"
# 侧边栏webview数据
@router.post("/asidestrinfo/")
async def asidestrinfo():
pass
# 股票数据
stock_data = None
@router.get("/stock")
async def stock(
symbol: str = Query(..., description="股票代码"),
start_date: str = Query(..., description="起始日期"),
end_date: str = Query(..., description="结束日期"),
):
# 获取股票日线行情数据
# print(symbol, start_date, end_date)
# print(symbol)
global stock_data
try:
stock_zh_a_daily_df = ak.stock_zh_a_daily(symbol=symbol, start_date=start_date, end_date=end_date, adjust="qfq")
# 获取所有的code
all_dates = stock_zh_a_daily_df['date']
# 如果你想要一个列表而不是Pandas Series
dates_list = all_dates.tolist()
all_opens = stock_zh_a_daily_df['open']
opens_list = all_opens.tolist()
cleaned_opens_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in opens_list
]
all_closes = stock_zh_a_daily_df['close']
close_list = all_closes.tolist()
cleaned_close_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in close_list
]
all_highs = stock_zh_a_daily_df['high']
high_list = all_highs.tolist()
cleaned_high_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in high_list
]
all_lows = stock_zh_a_daily_df['low']
low_list = all_lows.tolist()
cleaned_low_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in low_list
]
all_volumes = stock_zh_a_daily_df['volume']
volume_list = all_volumes.tolist()
cleaned_volume_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in volume_list
]
all_amounts = stock_zh_a_daily_df['amount']
amount_lists = all_amounts.tolist()
cleaned_amount_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amount_lists
]
global stock_data
stock_data = {
"amount": cleaned_amount_list,
"close": cleaned_close_list,
"date": dates_list,
"high": cleaned_high_list,
"low": cleaned_low_list,
"open": cleaned_opens_list,
"outstanding_share": [],
"turnover": [],
"volume": cleaned_volume_list
}
except Exception as e:
print(e)
print("无法使用该方式请求股票数据当前股票可能不是A股")
stock_data = {
"amount": [],
"close": [],
"date": [],
"high": [],
"low": [],
"open": [],
"outstanding_share": [],
"turnover": [],
"volume": []
}
finally:
return {"message": stock_data}
# 前端获取数据接口
@router.get("/kdata")
async def kdata():
global stock_data
if stock_data is None:
stock_data = {
"amount": [],
"close": [],
"date": [],
"high": [],
"low": [],
"open": [],
"outstanding_share": [],
"turnover": [],
"volume": []
}
return {"message": stock_data}
else:
return {"message": stock_data}
# ---------------------------------------------------------------------
# 港股代码数据
@router.get("/ganggudata")
async def ganggudata():
stock_hk_spot_em_df = ak.stock_hk_spot_em()
# print(stock_hk_spot_em_df)
# 获取所有的code
all_codes = stock_hk_spot_em_df['代码']
# 如果你想要一个列表而不是Pandas Series
codes_list = all_codes.tolist()
all_names = stock_hk_spot_em_df['名称']
names_list = all_names.tolist()
all_prices = stock_hk_spot_em_df['最新价']
price_list = all_prices.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_price_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in price_list
]
all_amplitudes = stock_hk_spot_em_df['涨跌幅']
amplitudes_list = all_amplitudes.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_amplitudes_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amplitudes_list
]
# 返回的数据
ggstocking = []
for i in range(9):
if cleaned_price_list[i] >= 0:
flag = True
else:
flag = False
ggstocking.append({
'code': codes_list[i],
'name': names_list[i],
'market': '港股',
'newprice': cleaned_price_list[i],
'amplitudetype': flag,
'amplitude': cleaned_amplitudes_list[i],
'type': 'ganggu'
})
# 返回清理后的列表
return ggstocking
# 港股K线图历史数据
@router.get("/ganggudataK")
async def ganggudataK(
symbol: str = Query(..., description="股票代码"),
start_date: str = Query(..., description="起始日期"),
end_date: str = Query(..., description="结束日期"),
):
try:
stock_hk_hist_df = ak.stock_hk_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date,
adjust="")
# 获取所有的code
all_dates = stock_hk_hist_df['日期']
# 如果你想要一个列表而不是Pandas Series
dates_list = all_dates.tolist()
all_opens = stock_hk_hist_df['开盘']
opens_list = all_opens.tolist()
cleaned_opens_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in opens_list
]
all_closes = stock_hk_hist_df['收盘']
close_list = all_closes.tolist()
cleaned_close_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in close_list
]
all_highs = stock_hk_hist_df['最高']
high_list = all_highs.tolist()
cleaned_high_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in high_list
]
all_lows = stock_hk_hist_df['最低']
low_list = all_lows.tolist()
cleaned_low_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in low_list
]
all_volumes = stock_hk_hist_df['成交量']
volume_list = all_volumes.tolist()
cleaned_volume_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in volume_list
]
all_amounts = stock_hk_hist_df['成交额']
amount_list = all_amounts.tolist()
cleaned_amount_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amount_list
]
global stock_data
stock_data = {
"amount": cleaned_amount_list,
"close": cleaned_close_list,
"date": dates_list,
"high": cleaned_high_list,
"low": cleaned_low_list,
"open": cleaned_opens_list,
"outstanding_share": [],
"turnover": [],
"volume": cleaned_volume_list
}
except Exception as e:
print(e)
stock_data = {
"amount": [],
"close": [],
"date": [],
"high": [],
"low": [],
"open": [],
"outstanding_share": [],
"turnover": [],
"volume": []
}
finally:
return {"message": stock_data}
# ---------------------------------------------------------------------
# 美股代码数据
@router.get("/meigudata")
async def meigudata():
stock_us_spot_em_df = ak.stock_us_spot_em()
# print(stock_us_spot_em_df)
all_codes = stock_us_spot_em_df['代码']
# 如果你想要一个列表而不是Pandas Series
codes_list = all_codes.tolist()
all_names = stock_us_spot_em_df['名称']
names_list = all_names.tolist()
all_prices = stock_us_spot_em_df['最新价']
price_list = all_prices.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_price_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in price_list
]
all_amplitudes = stock_us_spot_em_df['涨跌幅']
amplitudes_list = all_amplitudes.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_amplitudes_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amplitudes_list
]
# 返回的数据
mgstocking = []
for i in range(9):
if cleaned_price_list[i] >= 0:
flag = True
else:
flag = False
mgstocking.append({
'code': codes_list[i],
'name': names_list[i],
'market': '港股',
'newprice': cleaned_price_list[i],
'amplitudetype': flag,
'amplitude': cleaned_amplitudes_list[i],
'type': 'meigu'
})
# 返回清理后的列表
return mgstocking
# 美股K线图历史数据
@router.get("/meigudataK")
async def meigudataK(
symbol: str = Query(..., description="股票代码"),
start_date: str = Query(..., description="起始日期"),
end_date: str = Query(..., description="结束日期"),
):
try:
stock_us_hist_df = ak.stock_us_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date,
adjust="qfq")
# 获取所有的code
all_dates = stock_us_hist_df['日期']
# 如果你想要一个列表而不是Pandas Series
dates_list = all_dates.tolist()
all_opens = stock_us_hist_df['开盘']
opens_list = all_opens.tolist()
cleaned_opens_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in opens_list
]
all_closes = stock_us_hist_df['收盘']
close_list = all_closes.tolist()
cleaned_close_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in close_list
]
all_highs = stock_us_hist_df['最高']
high_list = all_highs.tolist()
cleaned_high_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in high_list
]
all_lows = stock_us_hist_df['最低']
low_list = all_lows.tolist()
cleaned_low_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in low_list
]
all_volumes = stock_us_hist_df['成交量']
volume_list = all_volumes.tolist()
cleaned_volume_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in volume_list
]
all_amounts = stock_us_hist_df['成交额']
amount_list = all_amounts.tolist()
cleaned_amount_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amount_list
]
global stock_data
stock_data = {
"amount": cleaned_amount_list,
"close": cleaned_close_list,
"date": dates_list,
"high": cleaned_high_list,
"low": cleaned_low_list,
"open": cleaned_opens_list,
"outstanding_share": [],
"turnover": [],
"volume": cleaned_volume_list
}
except Exception as e:
print(e)
stock_data = {
"amount": [],
"close": [],
"date": [],
"high": [],
"low": [],
"open": [],
"outstanding_share": [],
"turnover": [],
"volume": []
}
finally:
return {"message": stock_data}
# ---------------------------------------------------------------------
# 沪深代码数据
@router.get("/hushendata")
async def hushendata():
try:
stock_zh_a_spot_df = ak.stock_kc_a_spot_em()
except Exception as e:
print(e)
# print(stock_zh_a_spot_df)
all_codes = stock_zh_a_spot_df['代码']
# 如果你想要一个列表而不是Pandas Series
codes_list = all_codes.tolist()
all_names = stock_zh_a_spot_df['名称']
names_list = all_names.tolist()
all_prices = stock_zh_a_spot_df['最新价']
price_list = all_prices.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_price_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in price_list
]
all_amplitudes = stock_zh_a_spot_df['涨跌幅']
amplitudes_list = all_amplitudes.tolist()
# 清理非法浮点数值NaN, Infinity, -Infinity
cleaned_amplitudes_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amplitudes_list
]
# 返回的数据
hsstocking = []
# for i in range(len(codes_list)):
# if cleaned_price_list[i] >= 0:
# flag = True
# else:
# flag = False
# hsstocking.append({
# 'code': codes_list[i],
# 'name': names_list[i],
# 'market': '港股',
# 'newprice': cleaned_price_list[i],
# 'amplitudetype': flag,
# 'amplitude': cleaned_amplitudes_list[i],
# })
for i in range(9):
if cleaned_price_list[i] >= 0:
flag = True
else:
flag = False
hsstocking.append({
'code': codes_list[i],
'name': names_list[i],
'market': '港股',
'newprice': cleaned_price_list[i],
'amplitudetype': flag,
'amplitude': cleaned_amplitudes_list[i],
'type': 'hushen'
})
# 返回清理后的列表
return hsstocking
@router.get("/hushendataK")
async def hushendataK(
symbol: str = Query(..., description="股票代码"),
start_date: str = Query(..., description="起始日期"),
end_date: str = Query(..., description="结束日期"),
):
try:
stock_zh_a_daily_qfq_df = ak.stock_zh_a_daily(symbol='sh' + symbol, start_date=start_date, end_date=end_date,
adjust="qfq")
# 获取所有的code
all_dates = stock_zh_a_daily_qfq_df['date']
# 如果你想要一个列表而不是Pandas Series
dates_list = all_dates.tolist()
all_opens = stock_zh_a_daily_qfq_df['open']
opens_list = all_opens.tolist()
cleaned_opens_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in opens_list
]
all_closes = stock_zh_a_daily_qfq_df['close']
close_list = all_closes.tolist()
cleaned_close_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in close_list
]
all_highs = stock_zh_a_daily_qfq_df['high']
high_list = all_highs.tolist()
cleaned_high_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in high_list
]
all_lows = stock_zh_a_daily_qfq_df['low']
low_list = all_lows.tolist()
cleaned_low_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in low_list
]
all_volumes = stock_zh_a_daily_qfq_df['volume']
volume_list = all_volumes.tolist()
cleaned_volume_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in volume_list
]
all_amounts = stock_zh_a_daily_qfq_df['amount']
amount_lists = all_amounts.tolist()
cleaned_amount_list = [
value if not (math.isnan(value) or math.isinf(value)) else 0.00
for value in amount_lists
]
global stock_data
stock_data = {
"amount": cleaned_amount_list,
"close": cleaned_close_list,
"date": dates_list,
"high": cleaned_high_list,
"low": cleaned_low_list,
"open": cleaned_opens_list,
"outstanding_share": [],
"turnover": [],
"volume": cleaned_volume_list
}
except Exception as e:
print(e)
stock_data = {
"amount": [],
"close": [],
"date": [],
"high": [],
"low": [],
"open": [],
"outstanding_share": [],
"turnover": [],
"volume": []
}
finally:
return {"message": stock_data}

@ -0,0 +1,6 @@
from fastapi import FastAPI,Query
from starlette.middleware.cors import CORSMiddleware
import akshare as ak
async def get_day_k_data():
pass

@ -1,92 +1,93 @@
from typing import Any
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from src.utils.helpers import first
def register_exception_handler(app: FastAPI):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""请求参数验证错误处理"""
error = first(exc.errors())
field = error.get('loc')[1] or ''
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({
"status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
"message": "{} param {}".format(field, error.get('msg')),
"detail": exc.errors()}),
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""Http 异常处理"""
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder({
"status_code": exc.status_code,
"message": exc.detail}),
)
@app.exception_handler(Exception)
async def exception_callback(request: Request, exc: Exception):
"""其他异常处理,遇到其他问题再自定义异常"""
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=jsonable_encoder({
"status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
"message": "Internal Server Error",
# "detail": ''.join(exc.args)
}),
)
class CommonHttpException(HTTPException):
def __init__(self, detail, status_code, **kwargs: dict[str, Any]) -> None:
super().__init__(status_code=status_code, detail=detail, **kwargs)
class DetailedHTTPException(HTTPException):
STATUS_CODE = status.HTTP_500_INTERNAL_SERVER_ERROR
DETAIL = "Server error"
def __init__(self, **kwargs: dict[str, Any]) -> None:
super().__init__(status_code=self.STATUS_CODE, detail=self.DETAIL, **kwargs)
class PermissionDenied(DetailedHTTPException):
STATUS_CODE = status.HTTP_403_FORBIDDEN
DETAIL = "Permission denied"
class NotFound(DetailedHTTPException):
STATUS_CODE = status.HTTP_404_NOT_FOUND
class BadRequest(DetailedHTTPException):
STATUS_CODE = status.HTTP_400_BAD_REQUEST
DETAIL = "Bad Request"
class UnprocessableEntity(DetailedHTTPException):
STATUS_CODE = status.HTTP_422_UNPROCESSABLE_ENTITY
DETAIL = "Unprocessable entity"
class NotAuthenticated(DetailedHTTPException):
STATUS_CODE = status.HTTP_401_UNAUTHORIZED
DETAIL = "User not authenticated"
class WxResponseError(DetailedHTTPException):
STATUS_CODE = status.HTTP_400_BAD_REQUEST
DETAIL = "请求微信异常"
def __init__(self) -> None:
super().__init__(headers={"WWW-Authenticate": "Bearer"})
3
from typing import Any
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from src.utils.helpers import first
def register_exception_handler(app: FastAPI):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""请求参数验证错误处理"""
error = first(exc.errors())
field = error.get('loc')[1] or ''
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({
"status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
"message": "{} param {}".format(field, error.get('msg')),
"detail": exc.errors()}),
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""Http 异常处理"""
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder({
"status_code": exc.status_code,
"message": exc.detail}),
)
@app.exception_handler(Exception)
async def exception_callback(request: Request, exc: Exception):
"""其他异常处理,遇到其他问题再自定义异常"""
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=jsonable_encoder({
"status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
"message": "Internal Server Error",
# "detail": ''.join(exc.args)
}),
)
class CommonHttpException(HTTPException):
def __init__(self, detail, status_code, **kwargs: dict[str, Any]) -> None:
super().__init__(status_code=status_code, detail=detail, **kwargs)
class DetailedHTTPException(HTTPException):
STATUS_CODE = status.HTTP_500_INTERNAL_SERVER_ERROR
DETAIL = "Server error"
def __init__(self, **kwargs: dict[str, Any]) -> None:
super().__init__(status_code=self.STATUS_CODE, detail=self.DETAIL, **kwargs)
class PermissionDenied(DetailedHTTPException):
STATUS_CODE = status.HTTP_403_FORBIDDEN
DETAIL = "Permission denied"
class NotFound(DetailedHTTPException):
STATUS_CODE = status.HTTP_404_NOT_FOUND
class BadRequest(DetailedHTTPException):
STATUS_CODE = status.HTTP_400_BAD_REQUEST
DETAIL = "Bad Request"
class UnprocessableEntity(DetailedHTTPException):
STATUS_CODE = status.HTTP_422_UNPROCESSABLE_ENTITY
DETAIL = "Unprocessable entity"
class NotAuthenticated(DetailedHTTPException):
STATUS_CODE = status.HTTP_401_UNAUTHORIZED
DETAIL = "User not authenticated"
class WxResponseError(DetailedHTTPException):
STATUS_CODE = status.HTTP_400_BAD_REQUEST
DETAIL = "请求微信异常"
def __init__(self) -> None:
super().__init__(headers={"WWW-Authenticate": "Bearer"})

@ -8,7 +8,7 @@ from src.responses import response_list_response
financial_reports_router = APIRouter()
@financial_reports_router.post("/query")
async def financial_repoets_query(request: FinancialReportQuery )-> JSONResponse:
async def financial_repoets_query(request: FinancialReportQuery) -> JSONResponse:
"""
搜索接口
"""

@ -11,7 +11,6 @@ import asyncio
from fastapi import FastAPI
from datetime import datetime
from starlette.middleware.cors import CORSMiddleware
from src.exceptions import register_exception_handler
@ -19,17 +18,19 @@ from src.tortoises import register_tortoise_orm
from src.xtdata.router import router as xtdata_router
from src.backtest.router import router as backtest_router
from src.combination.router import router as combine_router
from src.models.financial_reports import FinancialReport
from src.models.financial_reports import FinancialReport
from src.utils.update_financial_reports_spider import combined_search_and_list
from src.financial_reports.router import financial_reports_router
from src.utils.generate_pinyin_abbreviation import generate_pinyin_abbreviation
from src.composite.router import composite_router
from src.akshare_data.router import router as akshare_data_router
from xtquant import xtdata
from src.settings.config import app_configs, settings
import adata
import akshare as ak
# import adata
# import akshare as ak
app = FastAPI(**app_configs)
@ -43,6 +44,7 @@ app.include_router(backtest_router, prefix="/backtest", tags=["回测接口"])
app.include_router(combine_router, prefix="/combine", tags=["组合接口"])
app.include_router(financial_reports_router, prefix="/financial-reports", tags=["财报接口"])
app.include_router(composite_router, prefix="/composite", tags=["vacode组合接口"])
app.include_router(akshare_data_router, prefix="/akshare", tags=["数据接口"])
if settings.ENVIRONMENT.is_deployed:
sentry_sdk.init(
@ -52,11 +54,14 @@ if settings.ENVIRONMENT.is_deployed:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
# allow_origins=settings.CORS_ORIGINS,
# allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_origins=["*"],
allow_origin_regex=None,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
# allow_headers=settings.CORS_HEADERS,
allow_headers=["*"], # 允许所有请求头
)
@ -64,12 +69,13 @@ app.add_middleware(
async def root():
return {"message": "Hello, FastAPI!"}
# 定时检查和数据抓取函数
async def run_data_fetcher():
print("财报抓取启动")
while True:
# 获取数据库中最新记录的时间
latest_record = await FinancialReport .all().order_by("-created_at").first()
latest_record = await FinancialReport.all().order_by("-created_at").first()
latest_db_date = latest_record.created_at if latest_record else pd.Timestamp("1970-01-01")
# 将最新数据库日期设为无时区,以便比较
@ -78,8 +84,8 @@ async def run_data_fetcher():
# 当前时间
current_time = pd.Timestamp(datetime.now()).replace(tzinfo=None)
# 检查当前时间是否超过数据库最新记录时间 12 个小时
if (current_time - latest_db_date).total_seconds() > 43200:
# 检查当前时间是否超过数据库最新记录时间 12 个小时
if (current_time - latest_db_date).total_seconds() > 43200:
print("启动财报数据抓取...")
await combined_search_and_list()
else:
@ -88,18 +94,18 @@ async def run_data_fetcher():
# 休眠 12 小时43200 秒),然后再次检查
await asyncio.sleep(43200)
async def test_liang_hua_ku():
print("量化库测试函数启动")
# 启动时立即运行检查任务
@app.on_event("startup")
async def lifespan():
# 启动时将任务放入后台,不阻塞启动流程
# 启动时将任务放入后台,不阻塞启动流程
asyncio.create_task(test_liang_hua_ku())
asyncio.create_task(run_data_fetcher())
if __name__ == "__main__":
uvicorn.run('src.main:app', host="0.0.0.0", port=8012, reload=True)
uvicorn.run('src.main:app', host="127.0.0.1", port=8012, reload=True)

@ -44,9 +44,3 @@ app_configs: dict[str, Any] = {
"docs_url": "/api/docs",
}
# app_configs['debug'] = True
# if settings.ENVIRONMENT.is_deployed:
# app_configs["root_path"] = f"/v{settings.APP_VERSION}"
#
# if not settings.ENVIRONMENT.is_debug:
# app_configs["openapi_url"] = None # hide docs

@ -1,358 +1,358 @@
import numpy as np # 导入 numpy 库
from xtquant import xtdata # 导入 xtquant 库的 xtdata 模块
from src.utils.backtest_until import convert_pandas_to_json_serializable
from src.models.wance_data_stock import WanceDataStock
from src.pydantic.factor_request import StockQuery
from src.utils.history_data_processing_utils import translation_dict
# 获取股票池中的所有股票的股票代码
async def get_full_tick_keys_service(code_list: list):
"""
获取所有股票的逐笔成交键若未提供股票列表则默认使用 ['SH', 'SZ']
"""
if len(code_list) == 0: # 如果股票代码列表为空
code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
result = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 调用 xtdata 库的函数获取逐笔成交数据
return list(result.keys()) # 返回所有股票代码的列表
# 获取股票池中的所有股票的逐笔成交数据
async def get_full_tick_service(code_list: list):
"""
获取所有股票的逐笔成交数据若未提供股票列表则默认使用 ['SH', 'SZ']
"""
if len(code_list) == 0: # 如果股票代码列表为空
code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
result = xtdata.get_full_tick(code_list=code_list) # 获取逐笔成交数据
return result # 返回逐笔成交数据
# 获取板块信息
async def get_sector_list_service():
"""
获取所有板块的列表
"""
result = xtdata.get_sector_list() # 调用 xtdata 库的函数获取板块信息
return result # 返回板块列表
# 通过板块名称获取股票列表
async def get_stock_list_in_sector_service(sector_name):
"""
获取指定板块中的所有股票名称
"""
result = xtdata.get_stock_list_in_sector(sector_name=sector_name) # 根据板块名称获取股票列表
return result # 返回股票列表
# 下载板块信息
async def download_sector_data_service():
"""
下载板块的详细数据
"""
result = xtdata.download_sector_data() # 调用 xtdata 库的函数下载板块数据
return result # 返回下载结果
# 获取股票的详细信息
async def get_instrument_detail_service(stock_code: str, iscomplete: bool):
"""
获取股票的详细信息
"""
result = xtdata.get_instrument_detail(stock_code=stock_code, iscomplete=iscomplete) # 获取股票的详细信息
return result # 返回股票的详细信息
# 获取市场行情数据
async def get_market_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int, dividend_type: str, fill_data: bool):
"""
获取指定条件下的市场行情数据
"""
result = xtdata.get_market_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type,
fill_data=fill_data) # 获取市场数据
return result # 返回市场数据
# 获取最新的 K 线数据
async def get_full_kline_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int,
dividend_type: str, fill_data: bool):
"""
获取指定条件下的完整 K 线数据
"""
result = xtdata.get_full_kline(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type,
fill_data=fill_data) # 获取完整的 K 线数据
return result # 返回 K 线数据
# 下载历史数据
async def download_history_data_service(stock_code: str, period: str, start_time: str, end_time: str,
incrementally: bool):
"""
下载指定股票的历史数据
"""
xtdata.download_history_data(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
incrementally=incrementally) # 下载股票历史数据
# 获取本地数据
async def get_local_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int, dividend_type: str, fill_data: bool, data_dir: str):
"""
@param field_list:
@type field_list:
@param stock_list:
@type stock_list:
@param period:
@type period:
@param start_time:
@type start_time:
@param end_time:
@type end_time:
@param count:
@type count:
@param dividend_type:
@type dividend_type:
@param fill_data:
@type fill_data:
@param data_dir:
@type data_dir:
@return:
@rtype:
"""
"""
获取本地存储的股票数据
"""
return_list = []
result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
data_dir=data_dir) # 获取本地数据
for i in stock_list:
return_list.append({
i : convert_pandas_to_json_serializable(result.get(i))
})
return return_list
# 订阅全市场行情推送
async def subscribe_whole_quote_service(code_list, callback=None):
"""
订阅全市场行情推送
"""
if callback:
result = xtdata.subscribe_whole_quote(code_list, callback=on_data) # 有回调函数时调用
else:
result = xtdata.subscribe_whole_quote(code_list, callback=None) # 无回调函数时调用
return result # 返回订阅结果
# 订阅单只股票行情推送
async def subscribe_quote_service(stock_code: str, period: str, start_time: str, end_time: str, count: int,
callback: bool):
"""
订阅单只股票的行情推送
"""
if callback:
result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
count=count, callback=on_data) # 有回调函数时调用
else:
result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
count=count, callback=None) # 无回调函数时调用
return result # 返回订阅结果
def on_data(datas):
"""
行情数据推送的回调函数处理并打印接收到的行情数据
"""
for stock_code in datas:
print(stock_code, datas[stock_code]) # 打印接收到的行情数据
# 批量增量下载或批量下载
async def download_history_data2_service(stock_list: list, period: str, start_time: str, end_time: str, callback: bool):
"""
批量下载股票历史数据支持进度回调
"""
if callback:
xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
callback=on_progress) # 有回调函数时调用
else:
xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
callback=None) # 无回调函数时调用
def on_progress(data):
"""
数据下载进度的回调函数实时显示进度
"""
print(data) # 打印下载进度
# 单只股票下载示例
def download_data():
"""
测试函数下载单只股票的历史数据并获取相关信息
"""
xtdata.download_history_data("000300.SH", "1d", start_time='', end_time='', incrementally=None) # 下载历史数据
a = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 获取逐笔成交数据
print(list(a.keys())) # 打印获取的股票代码
async def update_stock_data_service():
"""
# 用于本地xtdata数据更新
@return:
@rtype:
"""
stock_list = await get_full_tick_keys_service(['SH', 'SZ'])
for stock in stock_list:
await download_history_data_service(stock_code=stock,
period='1d',
start_time='',
end_time='',
incrementally=True)
async def get_stock_factor_service(query_params: StockQuery):
"""
动态查询方法根据传入的字段和值进行条件构造
:param kwargs: 查询条件键为字段名值为查询值
:return: 符合条件的记录列表
"""
"""
使用示例
{
"financial_asset_value": 10.7169,
"financial_cash_flow": -0.42982,
"gt": {
"profit_gross_rate": 4.0,
"valuation_PB_TTM": 0.8
},
"lt": {
"market_indicator": -0.6
},
"gte": {
"valuation_market_percentile":5.38793
},
"lte": {
"valuation_PEGTTM_ratio": 15.2773
},
"between": {
"financial_dividend": {
"min": 1.0,
"max": 2.0
},
"growth_Income_rate": {
"min": -0.6,
"max": 1.0
}
},
"stock_code": "600051.SH",
"stock_name": "宁波联合",
"stock_sector": ["上证A股", "沪深A股"],
"stock_type": ["stock"],
"time_start": "0",
"time_end": "20240911",
"valuation_PEG_percentile": 15.2222,
"valuation_PTS_TTM": 2.3077
}
financial_asset_value financial_cash_flow 字段为普通等于条件
gt 字段用于指定字段大于某个值的条件例如 profit_gross_rate 大于 15.0
lt 字段用于指定字段小于某个值的条件例如 market_indicator 小于 30.0
gte 字段用于指定字段大于等于某个值的条件例如 valuation_market_percentile 大于等于 0.5
lte 字段用于指定字段小于等于某个值的条件例如 valuation_PEGTTM_ratio 小于等于 2.0
between 字段用于指定字段在两个值之间的范围例如 financial_dividend 1.0 5.0 之间growth_Income_rate 10.0 25.0 之间
stock_code stock_name 字段为普通等于条件
stock_sector stock_type 字段为列表类型的条件可能会进行包含匹配
time_start time_end 字段为普通等于条件
valuation_PEG_percentile valuation_PTS_TTM 字段为普通等于条件
比较大小字段说明
'gt' 对应 __gt大于
'lt' 对应 __lt小于
'gte' 对应 __gte大于等于
'lte' 对应 __lte小于等于
'between' 对应在某个值之间的条件使用 __gt __lt 组合实现
时间字段说明
time_start={'recent_years': 2} 近2年
time_start={'recent_months': 2} 近2月
"""
filters = {}
# 处理大于、小于等操作
if query_params.gt:
for field, value in query_params.gt.items():
filters[f"{field}__gt"] = value
if query_params.lt:
for field, value in query_params.lt.items():
filters[f"{field}__lt"] = value
if query_params.gte:
for field, value in query_params.gte.items():
filters[f"{field}__gte"] = value
if query_params.lte:
for field, value in query_params.lte.items():
filters[f"{field}__lte"] = value
# 处理范围查询
if query_params.between:
for field, range_values in query_params.between.items():
if 'min' in range_values and 'max' in range_values:
filters[f"{field}__gte"] = range_values['min']
filters[f"{field}__lte"] = range_values['max']
# 处理普通字段
for field, value in query_params.model_dump(exclude_unset=True).items():
if field not in ['gt', 'lt', 'gte', 'lte', 'between']:
if isinstance(value, float):
filters[f"{field}__gte"] = value
elif isinstance(value, str):
filters[f"{field}__icontains"] = value
elif isinstance(value, list):
if field == "stock_sector":
# 转换 `stock_sector` 中的中文板块名到英文
translated_sectors = [translation_dict.get(sector, sector) for sector in value]
filters[f"{field}__contains"] = translated_sectors
else:
filters[f"{field}__contains"] = value
try:
stocks = await WanceDataStock.filter(**filters).all()
return stocks
except Exception as e:
print(f"Error occurred when querying stocks: {e}")
return []
# 将 numpy 类型转换为原生 Python 类型
async def convert_numpy_to_native(obj):
"""
递归地将 numpy 类型转换为原生 Python 类型
"""
if isinstance(obj, np.generic): # 如果是 numpy 类型
return obj.item() # 转换为 Python 原生类型
elif isinstance(obj, dict): # 如果是字典类型
return {k: convert_numpy_to_native(v) for k, v in obj.items()} # 递归转换字典中的值
elif isinstance(obj, list): # 如果是列表类型
return [convert_numpy_to_native(i) for i in obj] # 递归转换列表中的值
elif isinstance(obj, tuple): # 如果是元组类型
return tuple(convert_numpy_to_native(i) for i in obj) # 递归转换元组中的值
else:
return obj # 如果是其他类型,直接返回
if __name__ == '__main__':
# result = asyncio.run(get_instrument_detail_service("000300.SH",False))
# print(result)
# resulta = asyncio.run(get_sector_list_service())
# print(resulta)
pass
import numpy as np # 导入 numpy 库
from xtquant import xtdata # 导入 xtquant 库的 xtdata 模块
from src.utils.backtest_until import convert_pandas_to_json_serializable
from src.models.wance_data_stock import WanceDataStock
from src.pydantic.factor_request import StockQuery
from src.utils.history_data_processing_utils import translation_dict
# 获取股票池中的所有股票的股票代码
async def get_full_tick_keys_service(code_list: list):
"""
获取所有股票的逐笔成交键若未提供股票列表则默认使用 ['SH', 'SZ']
"""
if len(code_list) == 0: # 如果股票代码列表为空
code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
result = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 调用 xtdata 库的函数获取逐笔成交数据
return list(result.keys()) # 返回所有股票代码的列表
# 获取股票池中的所有股票的逐笔成交数据
async def get_full_tick_service(code_list: list):
"""
获取所有股票的逐笔成交数据若未提供股票列表则默认使用 ['SH', 'SZ']
"""
if len(code_list) == 0: # 如果股票代码列表为空
code_list = ['SH', 'SZ'] # 默认使用上证和深证股票池
result = xtdata.get_full_tick(code_list=code_list) # 获取逐笔成交数据
return result # 返回逐笔成交数据
# 获取板块信息
async def get_sector_list_service():
"""
获取所有板块的列表
"""
result = xtdata.get_sector_list() # 调用 xtdata 库的函数获取板块信息
return result # 返回板块列表
# 通过板块名称获取股票列表
async def get_stock_list_in_sector_service(sector_name):
"""
获取指定板块中的所有股票名称
"""
result = xtdata.get_stock_list_in_sector(sector_name=sector_name) # 根据板块名称获取股票列表
return result # 返回股票列表
# 下载板块信息
async def download_sector_data_service():
"""
下载板块的详细数据
"""
result = xtdata.download_sector_data() # 调用 xtdata 库的函数下载板块数据
return result # 返回下载结果
# 获取股票的详细信息
async def get_instrument_detail_service(stock_code: str, iscomplete: bool):
"""
获取股票的详细信息
"""
result = xtdata.get_instrument_detail(stock_code=stock_code, iscomplete=iscomplete) # 获取股票的详细信息
return result # 返回股票的详细信息
# 获取市场行情数据
async def get_market_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int, dividend_type: str, fill_data: bool):
"""
获取指定条件下的市场行情数据
"""
result = xtdata.get_market_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type,
fill_data=fill_data) # 获取市场数据
return result # 返回市场数据
# 获取最新的 K 线数据
async def get_full_kline_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int,
dividend_type: str, fill_data: bool):
"""
获取指定条件下的完整 K 线数据
"""
result = xtdata.get_full_kline(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type,
fill_data=fill_data) # 获取完整的 K 线数据
return result # 返回 K 线数据
# 下载历史数据
async def download_history_data_service(stock_code: str, period: str, start_time: str, end_time: str,
incrementally: bool):
"""
下载指定股票的历史数据
"""
xtdata.download_history_data(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
incrementally=incrementally) # 下载股票历史数据
# 获取本地数据
async def get_local_data_service(field_list: list, stock_list: list, period: str, start_time: str, end_time: str,
count: int, dividend_type: str, fill_data: bool, data_dir: str):
"""
@param field_list:
@type field_list:
@param stock_list:
@type stock_list:
@param period:
@type period:
@param start_time:
@type start_time:
@param end_time:
@type end_time:
@param count:
@type count:
@param dividend_type:
@type dividend_type:
@param fill_data:
@type fill_data:
@param data_dir:
@type data_dir:
@return:
@rtype:
"""
"""
获取本地存储的股票数据
"""
return_list = []
result = xtdata.get_local_data(field_list=field_list, stock_list=stock_list, period=period, start_time=start_time,
end_time=end_time, count=count, dividend_type=dividend_type, fill_data=fill_data,
data_dir=data_dir) # 获取本地数据
for i in stock_list:
return_list.append({
i : convert_pandas_to_json_serializable(result.get(i))
})
return return_list
# 订阅全市场行情推送
async def subscribe_whole_quote_service(code_list, callback=None):
"""
订阅全市场行情推送
"""
if callback:
result = xtdata.subscribe_whole_quote(code_list, callback=on_data) # 有回调函数时调用
else:
result = xtdata.subscribe_whole_quote(code_list, callback=None) # 无回调函数时调用
return result # 返回订阅结果
# 订阅单只股票行情推送
async def subscribe_quote_service(stock_code: str, period: str, start_time: str, end_time: str, count: int,
callback: bool):
"""
订阅单只股票的行情推送
"""
if callback:
result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
count=count, callback=on_data) # 有回调函数时调用
else:
result = xtdata.subscribe_quote(stock_code=stock_code, period=period, start_time=start_time, end_time=end_time,
count=count, callback=None) # 无回调函数时调用
return result # 返回订阅结果
def on_data(datas):
"""
行情数据推送的回调函数处理并打印接收到的行情数据
"""
for stock_code in datas:
print(stock_code, datas[stock_code]) # 打印接收到的行情数据
# 批量增量下载或批量下载
async def download_history_data2_service(stock_list: list, period: str, start_time: str, end_time: str, callback: bool):
"""
批量下载股票历史数据支持进度回调
"""
if callback:
xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
callback=on_progress) # 有回调函数时调用
else:
xtdata.download_history_data2(stock_list=stock_list, period=period, start_time=start_time, end_time=end_time,
callback=None) # 无回调函数时调用
def on_progress(data):
"""
数据下载进度的回调函数实时显示进度
"""
print(data) # 打印下载进度
# 单只股票下载示例
def download_data():
"""
测试函数下载单只股票的历史数据并获取相关信息
"""
xtdata.download_history_data("000300.SH", "1d", start_time='', end_time='', incrementally=None) # 下载历史数据
a = xtdata.get_full_tick(code_list=['SH', 'SZ']) # 获取逐笔成交数据
print(list(a.keys())) # 打印获取的股票代码
async def update_stock_data_service():
"""
# 用于本地xtdata数据更新
@return:
@rtype:
"""
stock_list = await get_full_tick_keys_service(['SH', 'SZ'])
for stock in stock_list:
await download_history_data_service(stock_code=stock,
period='1d',
start_time='',
end_time='',
incrementally=True)
async def get_stock_factor_service(query_params: StockQuery):
"""
动态查询方法根据传入的字段和值进行条件构造
:param kwargs: 查询条件键为字段名值为查询值
:return: 符合条件的记录列表
"""
"""
使用示例
{
"financial_asset_value": 10.7169,
"financial_cash_flow": -0.42982,
"gt": {
"profit_gross_rate": 4.0,
"valuation_PB_TTM": 0.8
},
"lt": {
"market_indicator": -0.6
},
"gte": {
"valuation_market_percentile":5.38793
},
"lte": {
"valuation_PEGTTM_ratio": 15.2773
},
"between": {
"financial_dividend": {
"min": 1.0,
"max": 2.0
},
"growth_Income_rate": {
"min": -0.6,
"max": 1.0
}
},
"stock_code": "600051.SH",
"stock_name": "宁波联合",
"stock_sector": ["上证A股", "沪深A股"],
"stock_type": ["stock"],
"time_start": "0",
"time_end": "20240911",
"valuation_PEG_percentile": 15.2222,
"valuation_PTS_TTM": 2.3077
}
financial_asset_value financial_cash_flow 字段为普通等于条件
gt 字段用于指定字段大于某个值的条件例如 profit_gross_rate 大于 15.0
lt 字段用于指定字段小于某个值的条件例如 market_indicator 小于 30.0
gte 字段用于指定字段大于等于某个值的条件例如 valuation_market_percentile 大于等于 0.5
lte 字段用于指定字段小于等于某个值的条件例如 valuation_PEGTTM_ratio 小于等于 2.0
between 字段用于指定字段在两个值之间的范围例如 financial_dividend 1.0 5.0 之间growth_Income_rate 10.0 25.0 之间
stock_code stock_name 字段为普通等于条件
stock_sector stock_type 字段为列表类型的条件可能会进行包含匹配
time_start time_end 字段为普通等于条件
valuation_PEG_percentile valuation_PTS_TTM 字段为普通等于条件
比较大小字段说明
'gt' 对应 __gt大于
'lt' 对应 __lt小于
'gte' 对应 __gte大于等于
'lte' 对应 __lte小于等于
'between' 对应在某个值之间的条件使用 __gt __lt 组合实现
时间字段说明
time_start={'recent_years': 2} 近2年
time_start={'recent_months': 2} 近2月
"""
filters = {}
# 处理大于、小于等操作
if query_params.gt:
for field, value in query_params.gt.items():
filters[f"{field}__gt"] = value
if query_params.lt:
for field, value in query_params.lt.items():
filters[f"{field}__lt"] = value
if query_params.gte:
for field, value in query_params.gte.items():
filters[f"{field}__gte"] = value
if query_params.lte:
for field, value in query_params.lte.items():
filters[f"{field}__lte"] = value
# 处理范围查询
if query_params.between:
for field, range_values in query_params.between.items():
if 'min' in range_values and 'max' in range_values:
filters[f"{field}__gte"] = range_values['min']
filters[f"{field}__lte"] = range_values['max']
# 处理普通字段
for field, value in query_params.model_dump(exclude_unset=True).items():
if field not in ['gt', 'lt', 'gte', 'lte', 'between']:
if isinstance(value, float):
filters[f"{field}__gte"] = value
elif isinstance(value, str):
filters[f"{field}__icontains"] = value
elif isinstance(value, list):
if field == "stock_sector":
# 转换 `stock_sector` 中的中文板块名到英文
translated_sectors = [translation_dict.get(sector, sector) for sector in value]
filters[f"{field}__contains"] = translated_sectors
else:
filters[f"{field}__contains"] = value
try:
stocks = await WanceDataStock.filter(**filters).all()
return stocks
except Exception as e:
print(f"Error occurred when querying stocks: {e}")
return []
# 将 numpy 类型转换为原生 Python 类型
async def convert_numpy_to_native(obj):
"""
递归地将 numpy 类型转换为原生 Python 类型
"""
if isinstance(obj, np.generic): # 如果是 numpy 类型
return obj.item() # 转换为 Python 原生类型
elif isinstance(obj, dict): # 如果是字典类型
return {k: convert_numpy_to_native(v) for k, v in obj.items()} # 递归转换字典中的值
elif isinstance(obj, list): # 如果是列表类型
return [convert_numpy_to_native(i) for i in obj] # 递归转换列表中的值
elif isinstance(obj, tuple): # 如果是元组类型
return tuple(convert_numpy_to_native(i) for i in obj) # 递归转换元组中的值
else:
return obj # 如果是其他类型,直接返回
if __name__ == '__main__':
# result = asyncio.run(get_instrument_detail_service("000300.SH",False))
# print(result)
# resulta = asyncio.run(get_sector_list_service())
# print(resulta)
pass