1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00

Merge branch 'main' into main

This commit is contained in:
you-n-g
2020-11-14 17:17:02 +08:00
committed by GitHub
10 changed files with 230 additions and 86 deletions

View File

@@ -15,7 +15,9 @@
Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.
With Qlib, you can easily try your ideas to create better Quant investment strategies.
It contains the full ML pipeline of data processing, model training, and back-testing, and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
With Qlib, users can easily try ideas to create better Quant investment strategies.
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).

View File

@@ -221,12 +221,13 @@ class QlibConfig(Config):
self["provider_uri"] = str(Path(self["provider_uri"]).expanduser().resolve())
def get_uri_type(self):
rm = re.match("^[^/ ]+:.+", self["provider_uri"])
# Windows path is "C:\\"
if rm is None or Path(self["provider_uri"]).exists():
return QlibConfig.LOCAL_URI
else:
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = re.match("^[^/]+:.+", self["provider_uri"]) is not None # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
else:
return QlibConfig.LOCAL_URI
def get_data_path(self):
if self.get_uri_type() == QlibConfig.LOCAL_URI:

View File

@@ -49,42 +49,38 @@ class Account:
return self.current.position["cash"]
def update_state_from_order(self, order, trade_val, cost, trade_price):
# update cash
if order.direction == Order.SELL: # 0 for sell
self.current.position["cash"] += trade_val - cost
elif order.direction == Order.BUY: # 1 for buy
self.current.position["cash"] -= trade_val + cost
else:
raise NotImplementedError("{} ".format(order.direction))
# update turnover
self.to += trade_val
# update cost
self.ct += cost
# update return
# update self.rtn from order
trade_amount = trade_val / trade_price
if order.direction == Order.SELL: # 0 for sell
# when sell stock, get profit from price change
profit = trade_val - self.current.get_stock_price(order.stock_id) * order.deal_amount
profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
self.rtn += profit # note here do not consider cost
elif order.direction == Order.BUY: # 1 for buy
# when buy stock, we get return for the rtn computing method
# profit in buy order is to make self.rtn is consistent with self.earning at the end of date
profit = self.current.get_stock_price(order.stock_id) * order.deal_amount - trade_val
profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
self.rtn += profit
def update_order(self, order, trade_val, cost, trade_price):
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
# if stock is bought, there is no stock in current position, update current, then update account
# The cost will be subtracted from the cash at the end, so the trading logic can ignore the cost calculation
trade_amount = trade_val / trade_price
if order.direction == Order.SELL:
# sell stock
self.update_state_from_order(order, trade_val, cost, trade_price)
# update current position
# for may sell all of stock_id
self.current.update_order(order, trade_price)
self.current.update_order(order, trade_val, cost, trade_price)
else:
# buy stock
# deal order, then update state
self.current.update_order(order, trade_price)
self.current.update_order(order, trade_val, cost, trade_price)
self.update_state_from_order(order, trade_val, cost, trade_price)
def update_daily_end(self, today, trader):

View File

@@ -208,14 +208,9 @@ class Exchange:
# If the order can only be dealt with 0 trade_val, nothing should be updated;
# otherwise, it would leave a stock with 0 amount in the position
if trade_account:
trade_account.update_order(
order=order,
trade_val=trade_val,
cost=trade_cost,
trade_price=trade_price,
)
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
elif position:
position.update_order(order, trade_price)
position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
return trade_val, trade_cost, trade_price

View File

@@ -43,38 +43,44 @@ class Position:
self.position[stock_id]["price"] = price
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
def buy_stock(self, stock_id, amount, price):
def buy_stock(self, stock_id, trade_val, cost, trade_price):
trade_amount = trade_val / trade_price
if stock_id not in self.position:
self.init_stock(stock_id=stock_id, amount=amount, price=price)
self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
else:
# exist, add amount
self.position[stock_id]["amount"] += amount
self.position[stock_id]["amount"] += trade_amount
def sell_stock(self, stock_id, amount):
self.position["cash"] -= trade_val + cost
def sell_stock(self, stock_id, trade_val, cost, trade_price):
trade_amount = trade_val / trade_price
if stock_id not in self.position:
raise KeyError("{} not in current position".format(stock_id))
else:
# decrease the amount of stock
self.position[stock_id]["amount"] -= amount
self.position[stock_id]["amount"] -= trade_amount
# check if to delete
if self.position[stock_id]["amount"] < -1e-5:
raise ValueError(
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, amount)
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
)
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
self.del_stock(stock_id)
self.position["cash"] += trade_val - cost
def del_stock(self, stock_id):
del self.position[stock_id]
def update_order(self, order, trade_price):
def update_order(self, order, trade_val, cost, trade_price):
# handle order; order is an Order object (the Order class is defined in exchange.py)
if order.direction == Order.BUY:
# BUY
self.buy_stock(stock_id=order.stock_id, amount=order.deal_amount, price=trade_price)
self.buy_stock(order.stock_id, trade_val, cost, trade_price)
elif order.direction == Order.SELL:
# SELL
self.sell_stock(stock_id=order.stock_id, amount=order.deal_amount)
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
else:
raise NotImplementedError("do not suppotr order direction {}".format(order.direction))

View File

@@ -239,7 +239,6 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
sell_order_list = []
buy_order_list = []
# load score
cash = current_temp.get_cash()
current_stock_list = current_temp.get_stock_list()
last = score_series.reindex(current_stock_list).sort_values(ascending=False).index
today = (
@@ -276,8 +275,6 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
if trade_exchange.check_order(sell_order):
sell_order_list.append(sell_order)
trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp)
# update cash
cash += trade_val - trade_cost
# sold
del self.stock_count[code]
else:
@@ -293,7 +290,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
# buy new stock
# note the current has been changed
current_stock_list = current_temp.get_stock_list()
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
value = current_temp.get_cash() * self.risk_degree / len(buy) if len(buy) > 0 else 0
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line

View File

@@ -5,6 +5,7 @@
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
import pandas as pd
@@ -19,7 +20,7 @@ try:
except ImportError as err:
print(err)
print("Do not import qlib package in the repository directory")
exit(-1)
sys.exit(-1)
__all__ = (
"Ref",

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
import abc
import sys
import bisect
from io import BytesIO
@@ -18,14 +19,12 @@ sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import get_hs_calendar_list as get_calendar_list
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls"
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
CSI300_START_DATE = pd.Timestamp("2005-01-01")
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
class CSI300:
class CSIIndex:
REMOVE = "remove"
ADD = "add"
@@ -45,6 +44,9 @@ class CSI300:
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self._calendar_list = None
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
@property
def calendar_list(self) -> list:
"""get history trading date
@@ -52,7 +54,41 @@ class CSI300:
Returns
-------
"""
return get_calendar_list(bench=True)
return get_calendar_list(bench_code=self.index_name.upper())
@property
def new_companies_url(self):
return NEW_COMPANIES_URL.format(index_code=self.index_code)
@property
def changes_url(self):
return INDEX_CHANGES_URL
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
raise NotImplementedError()
@property
@abc.abstractmethod
def index_code(self):
raise NotImplementedError()
@property
@abc.abstractmethod
def index_name(self):
raise NotImplementedError()
@property
@abc.abstractmethod
def html_table_index(self):
"""Which table of changes in html
CSI300: 0
CSI100: 1
:return:
"""
raise NotImplementedError()
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
"""get trading date by shift
@@ -119,14 +155,18 @@ class CSI300:
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
logger.info(f"get {add_date} changes")
try:
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
_io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content)
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
_io = BytesIO(content)
df_map = pd.read_excel(_io, sheet_name=None)
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(content)
tmp = []
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]]
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = ["symbol"]
_df["type"] = _type
@@ -135,9 +175,13 @@ class CSI300:
df = pd.concat(tmp)
except Exception:
df = None
_tmp_count = 0
for _df in pd.read_html(resp.content):
if _df.shape[-1] != 4:
continue
_tmp_count += 1
if self.html_table_index + 1 > _tmp_count:
continue
tmp = []
for _s, _type, _date in [
(_df.iloc[2:, 0], self.REMOVE, remove_date),
@@ -149,31 +193,42 @@ class CSI300:
_tmp_df["date"] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
df.to_csv(
str(
self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
).resolve()
)
)
break
return df
@staticmethod
def _get_change_notices_url() -> list:
def _get_change_notices_url(self) -> list:
"""get change notices url
Returns
-------
"""
resp = requests.get(CSI300_CHANGES_URL)
resp = requests.get(self.changes_url)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")
def _get_new_companies(self):
logger.info("get new companies")
_io = BytesIO(requests.get(NEW_COMPANIES_URL).content)
context = requests.get(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(context)
_io = BytesIO(context)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = ["end_date", "symbol"]
df["symbol"] = df["symbol"].map(self.normalize_symbol)
df["end_date"] = pd.to_datetime(df["end_date"])
df["start_date"] = CSI300_START_DATE
df["start_date"] = self.bench_start_date
return df
def parse_instruments(self):
@@ -183,7 +238,7 @@ class CSI300:
-------
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info("start parse csi300 companies.....")
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = ["symbol", "start_date", "end_date"]
changers_df = self._get_changes()
new_df = self._get_new_companies()
@@ -196,15 +251,65 @@ class CSI300:
] = _row.date
else:
_tmp_df = pd.DataFrame(
[[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"]
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info("parse csi300 companies finished.")
logger.info(f"parse {self.index_name.lower()} companies finished.")
class CSI300(CSIIndex):
@property
def index_code(self):
return "000300"
@property
def index_name(self):
return "csi300"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2005-01-01")
@property
def html_table_index(self):
return 0
class CSI100(CSIIndex):
@property
def index_code(self):
return "000903"
@property
def index_name(self):
return "csi100"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2006-05-29")
@property
def html_table_index(self):
return 1
def parse_instruments(qlib_dir: str):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
qlib_dir.mkdir(exist_ok=True, parents=True)
CSI300(qlib_dir).parse_instruments()
CSI100(qlib_dir).parse_instruments()
if __name__ == "__main__":
fire.Fire(CSI300)
fire.Fire(parse_instruments)

View File

@@ -2,7 +2,10 @@
# Licensed under the MIT License.
import re
import time
import pickle
import requests
from pathlib import Path
import pandas as pd
from lxml import etree
@@ -11,39 +14,46 @@ SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_ty
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
# NOTE: Use the time series of SH600000 as the sequence of all stocks
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
}
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_CALENDAR_MAP = {}
# NOTE: Until 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900
def get_hs_calendar_list(bench=False) -> list:
def get_hs_calendar_list(bench_code="CSI300") -> list:
"""get SH/SZ history calendar list
Parameters
----------
bench: bool
whether to get the bench calendar list, by default False
bench_code: str
value from ["CSI300", "CSI100", "ALL"]
Returns
-------
history calendar list
"""
global _ALL_CALENDAR_LIST
global _BENCH_CALENDAR_LIST
def _get_calendar(url):
_value_list = requests.get(url).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
# TODO: get calendar from MSN
if bench:
if _BENCH_CALENDAR_LIST is None:
_BENCH_CALENDAR_LIST = _get_calendar(CSI300_BENCH_URL)
return _BENCH_CALENDAR_LIST
if _ALL_CALENDAR_LIST is None:
_ALL_CALENDAR_LIST = _get_calendar(SH600000_BENCH_URL)
return _ALL_CALENDAR_LIST
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
_CALENDAR_MAP[bench_code] = calendar
return calendar
def get_hs_stock_symbols() -> list:
@@ -54,7 +64,8 @@ def get_hs_stock_symbols() -> list:
stock symbols
"""
global _HS_SYMBOLS
if _HS_SYMBOLS is None:
def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
@@ -64,7 +75,27 @@ def get_hs_stock_symbols() -> list:
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
_HS_SYMBOLS = sorted(list(_res))
return _res
if _HS_SYMBOLS is None:
symbols = set()
_retry = 60
# It may take multiple attempts to get the complete symbol list
while len(symbols) < MINIMUM_SYMBOLS_NUM:
symbols |= _get_symbol()
time.sleep(3)
symbol_cache_path = Path("~/.cache/hs_symbols_cache.pkl").expanduser().resolve()
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
if symbol_cache_path.exists():
with symbol_cache_path.open("rb") as fp:
cache_symbols = pickle.load(fp)
symbols |= cache_symbols
with symbol_cache_path.open("wb") as fp:
pickle.dump(symbols, fp)
_HS_SYMBOLS = sorted(list(symbols))
return _HS_SYMBOLS
@@ -104,3 +135,7 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
"""
res = f"{symbol[:-2]}.{symbol[-2:]}"
return res.upper() if capital else res.lower()
if __name__ == '__main__':
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -19,7 +19,7 @@ sys.path.append(str(CUR_DIR.parent.parent))
from dump_bin import DumpData
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
MIN_NUMBERS_TRADING = 252 / 4
@@ -130,17 +130,23 @@ class YahooCollector:
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
self.download_csi300_data()
self.download_index_data()
def download_csi300_data(self):
def download_index_data(self):
# TODO: from MSN
logger.info(f"get bench data: csi300(SH000300)......")
df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"]))
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath("sh000300.csv"), index=False)
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
)
)
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
class Run:
@@ -192,7 +198,7 @@ class Run:
df = df[~df.index.duplicated(keep="first")]
# using China stock market data calendar
df = df.reindex(pd.Index(get_calendar_list()))
df = df.reindex(pd.Index(get_calendar_list("ALL")))
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
@@ -274,8 +280,8 @@ class Run:
delay=delay,
).collector_data()
def download_csi300_data(self):
YahooCollector(self.source_dir).download_csi300_data()
def download_index_data(self):
YahooCollector(self.source_dir).download_index_data()
def download_bench_data(self):
"""download bench stock data(SH000300)"""