From 719074d30673d9ba256bf69df396854c3e498c47 Mon Sep 17 00:00:00 2001 From: wangershi Date: Thu, 25 Feb 2021 19:20:14 +0800 Subject: [PATCH 01/77] touch file --- scripts/data_collector/fund/README.md | 49 ++++++++++++++++++++ scripts/data_collector/fund/collector.py | 0 scripts/data_collector/fund/requirements.txt | 0 3 files changed, 49 insertions(+) create mode 100644 scripts/data_collector/fund/README.md create mode 100644 scripts/data_collector/fund/collector.py create mode 100644 scripts/data_collector/fund/requirements.txt diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md new file mode 100644 index 000000000..c7b91a3f5 --- /dev/null +++ b/scripts/data_collector/fund/README.md @@ -0,0 +1,49 @@ +# Collect Fund Data + +> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)* + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + + +### CN Data + +#### 1d from East Money + +```bash + +# download from yahoo finance +python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + + +# dump data +cd qlib/scripts +python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol + +``` + +### using data + +```python +import qlib +from qlib.data import D + +qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN") +df = D.features(D.instruments("all"), ["$close"], freq="day") +``` + + +### Help +```bash +pythono collector.py collector_data --help +``` + +## Parameters + +- interval: 1min or 1d +- region: CN or US diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/data_collector/fund/requirements.txt b/scripts/data_collector/fund/requirements.txt new file mode 100644 index 000000000..e69de29bb From 6e5639621710bae40d508357d77b9d60a104aff8 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 12:24:26 +0800 Subject: [PATCH 02/77] add crawler --- scripts/data_collector/fund/README.md | 4 +- scripts/data_collector/fund/collector.py | 408 +++++++++++++++++++++++ scripts/data_collector/utils.py | 37 ++ 3 files changed, 447 insertions(+), 2 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index c7b91a3f5..b14938a3d 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -17,8 +17,8 @@ pip install -r requirements.txt ```bash -# download from yahoo finance -python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# download from eastmoney.com +python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # dump data diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index e69de29bb..404b6af0e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -0,0 +1,408 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import abc +import sys +import copy +import time +import datetime +import importlib +import json +from abc import ABC +from pathlib import Path +from typing import Iterable, Type +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor + +import fire +import requests +import numpy as np +import pandas as pd +from tqdm import tqdm +from loguru import logger +from dateutil.tz import tzlocal +from qlib.utils import code_to_fname, fname_to_code + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.utils import get_calendar_list, get_en_fund_symbols + +INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" +REGION_CN = "CN" +REGION_US = "US" + + +class FundData: + START_DATETIME = pd.Timestamp("2000-01-01") + HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) + END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + INTERVAL_1d = "1d" + + def __init__( + self, + timezone: str = None, + start=None, + end=None, + interval="1d", + delay=0, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + timezone: str + The timezone where the data is located + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1d], default 1d + start: str + start datetime, default None + end: str + end datetime, default None + show_1min_logging: bool + show 1min logging, by default False; if True, there may be many warning logs + """ + self._timezone = tzlocal() if timezone is None else timezone + self._delay = delay + self._interval = interval + self._show_1min_logging = show_1min_logging + self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME + self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) + if self._interval != self.INTERVAL_1d: + raise ValueError(f"interval error: {self._interval}") + + # using for 1min + self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) + self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt + + def _sleep(self): + time.sleep(self._delay) + + @staticmethod + def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + error_msg = f"{symbol}-{interval}-{start}-{end}" + + try: + _resp = None + + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) + + if resp.status_code != 200: + raise ValueError("request error") + try: + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + _resp = pd.DataFrame( + data["Data"]["LSJZList"] + ) + + except Exception as e: + logger.warning(f"request error: {e}") + raise + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data(self, symbol: str) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self._sleep() + _remote_interval = self._interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + show_1min_logging=self._show_1min_logging, + ) + + if self._interval == self.INTERVAL_1d: + _result = _get_simple(self.start_datetime, self.end_datetime) + else: + raise ValueError(f"cannot support {self._interval}") + return _result + + +class FundCollector: + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=2, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + show_1min_logging: bool = False, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1min + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + """ + self.save_dir = Path(save_dir).expanduser().resolve() + self.save_dir.mkdir(parents=True, exist_ok=True) + + self._delay = delay + self.max_workers = max_workers + self._max_collector_count = max_collector_count + self._mini_symbol_map = {} + self._interval = interval + self._check_small_data = check_data_length + + self.fund_list = sorted(set(self.get_fund_list())) + if limit_nums is not None: + try: + self.fund_list = self.fund_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") + + self.fund_data = FundData( + timezone=self._timezone, + start=start, + end=end, + interval=interval, + delay=delay, + show_1min_logging=show_1min_logging, + ) + + @property + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("rewrite min_numbers_trading") + + @abc.abstractmethod + def get_fund_list(self): + raise NotImplementedError("rewrite get_fund_list") + + @property + @abc.abstractmethod + def _timezone(self): + raise NotImplementedError("rewrite get_timezone") + + def save_fund(self, symbol, df: pd.DataFrame): + """save fund data to file + + Parameters + ---------- + symbol: str + fund code + df : pd.DataFrame + df.columns must contain "symbol" and "datetime" + """ + if df.empty: + logger.warning(f"{symbol} is empty") + return + + symbol = code_to_fname(symbol) + stock_path = self.save_dir.joinpath(f"{symbol}.csv") + df["symbol"] = symbol + if stock_path.exists(): + _old_df = pd.read_csv(stock_path) + df = _old_df.append(df, sort=False) + df.to_csv(stock_path, index=False) + + def _save_small_data(self, symbol, df): + if len(df) <= self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") + _temp = self._mini_symbol_map.setdefault(symbol, []) + _temp.append(df.copy()) + return None + else: + if symbol in self._mini_symbol_map: + self._mini_symbol_map.pop(symbol) + return symbol + + def _get_data(self, symbol): + _result = None + df = self.fund_data.get_data(symbol) + if isinstance(df, pd.DataFrame): + if not df.empty: + if self._check_small_data: + if self._save_small_data(symbol, df) is not None: + _result = symbol + self.save_fund(symbol, df) + else: + _result = symbol + self.save_fund(symbol, df) + return _result + + def _collector(self, fund_list): + + error_symbol = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(fund_list)) as p_bar: + for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): + if _result is None: + error_symbol.append(_symbol) + p_bar.update() + print(error_symbol) + logger.info(f"error symbol nums: {len(error_symbol)}") + logger.info(f"current get symbol nums: {len(fund_list)}") + error_symbol.extend(self._mini_symbol_map.keys()) + return sorted(set(error_symbol)) + + def collector_data(self): + """collector data""" + logger.info("start collector fund data......") + fund_list = self.fund_list + for i in range(self._max_collector_count): + if not fund_list: + break + logger.info(f"getting data: {i+1}") + fund_list = self._collector(fund_list) + logger.info(f"{i+1} finish.") + for _symbol, _df_list in self._mini_symbol_map.items(): + self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) + if self._mini_symbol_map: + logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + +class FundollectorCN(FundCollector, ABC): + def get_fund_list(self): + logger.info("get cn fund symbols......") + symbols = get_en_fund_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + @property + def _timezone(self): + return "Asia/Shanghai" + + +class FundCollectorCN1d(FundollectorCN): + @property + def min_numbers_trading(self): + return 252 / 4 + +class Run: + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + region: str + region, value from ["CN", "US"], default "CN" + """ + if source_dir is None: + source_dir = CUR_DIR.joinpath("source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.region = region + + def download_data( + self, + max_collector_count=2, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + show_1min_logging=False, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1d + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + show_1min_logging: bool + show 1m logging, by default False; if True, there may be many warning logs + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + """ + + _class = getattr( + self._cur_module, f"FundCollector{self.region.upper()}{interval}" + ) # type: Type[FundCollector] + _class( + self.source_dir, + max_workers=self.max_workers, + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, + show_1min_logging=show_1min_logging, + ).collector_data() + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5f34aae7d..3319025fc 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -34,6 +34,7 @@ _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None _US_SYMBOLS = None +_EN_FUND_SYMBOLS = None _CALENDAR_MAP = {} # NOTE: Until 2020-10-20 20:00:00 @@ -220,6 +221,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: return _US_SYMBOLS +def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: + """get en fund symbols + + Returns + ------- + fund symbols in China + """ + global _EN_FUND_SYMBOLS + + @deco_retry + def _get_eastmoney(): + url = "http://fund.eastmoney.com/js/fundcode_search.js" + resp = requests.get(url) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [] + for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + data = sub_data.replace("\"","").replace("'","") + # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] + _symbols.append(data.split(",")[0]) + except Exception as e: + logger.warning(f"request error: {e}") + raise + if len(_symbols) < 8000: + raise ValueError("request error") + return _symbols + + if _EN_FUND_SYMBOLS is None: + _all_symbols = _get_eastmoney() + + _EN_FUND_SYMBOLS = sorted(set(_all_symbols)) + + return _EN_FUND_SYMBOLS + + def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix From db80b620d8408998cab30e4a6d51b9d038264c7e Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 17:03:14 +0800 Subject: [PATCH 03/77] ready for collector --- scripts/data_collector/fund/collector.py | 79 ++++++-------------- scripts/data_collector/fund/requirements.txt | 10 +++ 2 files changed, 32 insertions(+), 57 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 404b6af0e..f9b2a6775 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -20,20 +20,16 @@ import pandas as pd from tqdm import tqdm from loguru import logger from dateutil.tz import tzlocal -from qlib.utils import code_to_fname, fname_to_code CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_calendar_list, get_en_fund_symbols +from data_collector.utils import get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" -REGION_US = "US" - class FundData: START_DATETIME = pd.Timestamp("2000-01-01") - HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) INTERVAL_1d = "1d" @@ -44,7 +40,6 @@ class FundData: end=None, interval="1d", delay=0, - show_1min_logging: bool = False, ): """ @@ -60,22 +55,15 @@ class FundData: start datetime, default None end: str end datetime, default None - show_1min_logging: bool - show 1min logging, by default False; if True, there may be many warning logs """ self._timezone = tzlocal() if timezone is None else timezone self._delay = delay self._interval = interval - self._show_1min_logging = show_1min_logging self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) if self._interval != self.INTERVAL_1d: raise ValueError(f"interval error: {self._interval}") - # using for 1min - self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) - self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) @@ -92,33 +80,26 @@ class FundData: time.sleep(self._delay) @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + def get_data_from_remote(symbol, interval, start, end): error_msg = f"{symbol}-{interval}-{start}-{end}" try: - _resp = None - # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=100, startDate=start, endDate=end) + url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end) resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) if resp.status_code != 200: raise ValueError("request error") - try: - data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + data = json.loads(resp.text.split("(")[-1].split(")")[0]) - # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html - SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") - _resp = pd.DataFrame( - data["Data"]["LSJZList"] - ) - - except Exception as e: - logger.warning(f"request error: {e}") - raise + # TODO: should we sort the value by datetime? + _resp = pd.DataFrame(data["Data"]["LSJZList"]) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() @@ -134,7 +115,6 @@ class FundData: interval=_remote_interval, start=start_, end=end_, - show_1min_logging=self._show_1min_logging, ) if self._interval == self.INTERVAL_1d: @@ -156,14 +136,13 @@ class FundCollector: delay=0, check_data_length: bool = False, limit_nums: int = None, - show_1min_logging: bool = False, ): """ Parameters ---------- save_dir: str - stock save dir + fund save dir max_workers: int workers, default 4 max_collector_count: int @@ -180,8 +159,6 @@ class FundCollector: check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs """ self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) @@ -206,7 +183,6 @@ class FundCollector: end=end, interval=interval, delay=delay, - show_1min_logging=show_1min_logging, ) @property @@ -240,13 +216,14 @@ class FundCollector: logger.warning(f"{symbol} is empty") return - symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") + fund_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) + if fund_path.exists(): + # TODO: read the fund code as str, not int, like "000001" shouldn't be "1" + _old_df = pd.read_csv(fund_path) + # TODO: remove the duplicate date df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + df.to_csv(fund_path, index=False) def _save_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -274,7 +251,6 @@ class FundCollector: return _result def _collector(self, fund_list): - error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: with tqdm(total=len(fund_list)) as p_bar: @@ -301,7 +277,7 @@ class FundCollector: for _symbol, _df_list in self._mini_symbol_map.items(): self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}") + logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") class FundollectorCN(FundCollector, ABC): @@ -322,30 +298,23 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): + def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): """ Parameters ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - normalize_dir: str - Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 region: str - region, value from ["CN", "US"], default "CN" + region, value from ["CN"], default "CN" """ if source_dir is None: source_dir = CUR_DIR.joinpath("source") self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - self._cur_module = importlib.import_module("collector") self.max_workers = max_workers self.region = region @@ -359,7 +328,6 @@ class Run: interval="1d", check_data_length=False, limit_nums=None, - show_1min_logging=False, ): """download data from Internet @@ -375,12 +343,10 @@ class Run: start datetime, default "2000-01-01" end: str end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` - check_data_length: bool + check_data_length: bool # if this param useful? check data length, by default False limit_nums: int using for debug, by default None - show_1min_logging: bool - show 1m logging, by default False; if True, there may be many warning logs Examples --------- @@ -401,7 +367,6 @@ class Run: interval=interval, check_data_length=check_data_length, limit_nums=limit_nums, - show_1min_logging=show_1min_logging, ).collector_data() if __name__ == "__main__": diff --git a/scripts/data_collector/fund/requirements.txt b/scripts/data_collector/fund/requirements.txt index e69de29bb..11c6730c0 100644 --- a/scripts/data_collector/fund/requirements.txt +++ b/scripts/data_collector/fund/requirements.txt @@ -0,0 +1,10 @@ +loguru +fire +requests +numpy +pandas +tqdm +lxml +loguru +yahooquery +json \ No newline at end of file From 3082f6ac1ba201ebcdeb115f27e97f75d2017439 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 28 Feb 2021 19:06:40 +0800 Subject: [PATCH 04/77] ready for dump_bin --- scripts/data_collector/fund/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index b14938a3d..bcbbbcba7 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -23,7 +23,7 @@ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d -- # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/qlib_cn_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ ``` @@ -33,8 +33,8 @@ python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qli import qlib from qlib.data import D -qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN") -df = D.features(D.instruments("all"), ["$close"], freq="day") +qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data") +df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day") ``` @@ -45,5 +45,5 @@ pythono collector.py collector_data --help ## Parameters -- interval: 1min or 1d -- region: CN or US +- interval: 1d +- region: CN From 82353b20e18462d8b1a67742b2b6210dc80461ee Mon Sep 17 00:00:00 2001 From: wangershi Date: Mon, 1 Mar 2021 21:10:46 +0800 Subject: [PATCH 05/77] black format --- scripts/data_collector/fund/collector.py | 14 +++++++++----- scripts/data_collector/utils.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index f9b2a6775..a2b7089a1 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -28,6 +28,7 @@ from data_collector.utils import get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" + class FundData: START_DATETIME = pd.Timestamp("2000-01-01") END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) @@ -85,12 +86,14 @@ class FundData: try: # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format(index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end) + url = INDEX_BENCH_URL.format( + index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + ) resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) if resp.status_code != 200: raise ValueError("request error") - + data = json.loads(resp.text.split("(")[-1].split(")")[0]) # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html @@ -280,6 +283,7 @@ class FundCollector: logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + class FundollectorCN(FundCollector, ABC): def get_fund_list(self): logger.info("get cn fund symbols......") @@ -297,6 +301,7 @@ class FundCollectorCN1d(FundollectorCN): def min_numbers_trading(self): return 252 / 4 + class Run: def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): """ @@ -354,9 +359,7 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - _class = getattr( - self._cur_module, f"FundCollector{self.region.upper()}{interval}" - ) # type: Type[FundCollector] + _class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector] _class( self.source_dir, max_workers=self.max_workers, @@ -369,5 +372,6 @@ class Run: limit_nums=limit_nums, ).collector_data() + if __name__ == "__main__": fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 3319025fc..5d5822f91 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -239,7 +239,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: try: _symbols = [] for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): - data = sub_data.replace("\"","").replace("'","") + data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) except Exception as e: From 34b7da1dd88ad6886d36f44ddf7d8ccf44a66eda Mon Sep 17 00:00:00 2001 From: wangershi Date: Wed, 3 Mar 2021 22:49:48 +0800 Subject: [PATCH 06/77] add calendar list by threshold --- scripts/data_collector/utils.py | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 5d5822f91..1a08c514f 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import re +import os import time import bisect import pickle @@ -14,6 +15,9 @@ import pandas as pd from lxml import etree from loguru import logger from yahooquery import Ticker +from tqdm import tqdm +from functools import partial +from concurrent.futures import ProcessPoolExecutor HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" @@ -94,6 +98,57 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar +def return_date_list(source_dir, date_field_name, file_path): + df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0) + + return df[date_field_name].to_list() + + +def get_calendar_list_by_ratio( + source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16 +) -> list: + """get calendar list by selecting the date when few funds trade in this day + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + date_field_name: str + date field name, default is date + threshold: float + threshold to exclude some days when few funds trade in this day, default 0.5 + max_workers: int + Concurrent number, default is 16 + + Returns + ------- + history calendar list + """ + logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......") + + _number_all_funds = len(os.listdir(source_dir)) + + _list_all_date = dict() + + _fun = partial(return_date_list, source_dir, date_field_name) + + with tqdm(total=_number_all_funds) as p_bar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for date_list in executor.map(_fun, os.listdir(source_dir)): + for date in date_list: + if date in _list_all_date.keys(): + _list_all_date[date] += 1 + else: + _list_all_date[date] = 0 + + p_bar.update() + + _threshold_number = int(_number_all_funds * threshold) + calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number] + + return calendar + + def get_hs_stock_symbols() -> list: """get SH/SZ stock symbols From 11412727ef9089863b88f4d58b332513350cb115 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 18:51:38 +0800 Subject: [PATCH 07/77] add normalizer --- scripts/data_collector/fund/README.md | 4 +- scripts/data_collector/fund/collector.py | 171 ++++++++++++++++++++++- scripts/data_collector/utils.py | 43 ++++-- 3 files changed, 200 insertions(+), 18 deletions(-) diff --git a/scripts/data_collector/fund/README.md b/scripts/data_collector/fund/README.md index bcbbbcba7..c729b7eaa 100644 --- a/scripts/data_collector/fund/README.md +++ b/scripts/data_collector/fund/README.md @@ -20,10 +20,12 @@ pip install -r requirements.txt # download from eastmoney.com python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d +# normalize +python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ # dump data cd qlib/scripts -python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ +python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ ``` diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index a2b7089a1..795d8848e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -23,7 +23,7 @@ from dateutil.tz import tzlocal CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_en_fund_symbols +from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" REGION_CN = "CN" @@ -302,14 +302,149 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 +class FundNormalize: + COLUMNS = ["open", "close", "high", "low", "volume"] + DAILY_FORMAT = "%Y-%m-%d" + + def __init__( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + self._date_field_name = date_field_name + self._symbol_field_name = symbol_field_name + + self._calendar_list = self._get_calendar_list() + print (self._calendar_list) + + @staticmethod + def normalize_fund( + df: pd.DataFrame, + calendar_list: list = None, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + if df.empty: + return df + df = df.copy() + df.set_index(date_field_name, inplace=True) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + if calendar_list is not None: + df = df.reindex( + pd.DataFrame(index=calendar_list) + .loc[ + pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + + pd.Timedelta(hours=23, minutes=59) + ] + .index + ) + df.sort_index(inplace=True) + + df.index.names = [date_field_name] + return df.reset_index() + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + # normalize + df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + return df + + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") + + +class FundNormalize1d(FundNormalize, ABC): + DAILY_FORMAT = "%Y-%m-%d" + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + df = super(FundNormalize, self).normalize(df) + return df + + +class FundNormalizeCN: + def _get_calendar_list(self): + return get_calendar_list("ALL") + + +class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): + pass + + +class Normalize: + def __init__( + self, + source_dir: [str, Path], + target_dir: [str, Path], + normalize_class: Type[FundNormalize], + max_workers: int = 16, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + ): + """ + + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data + normalize_class: Type[FundNormalize] + normalize class + max_workers: int + Concurrent number, default is 16 + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._target_dir.mkdir(parents=True, exist_ok=True) + + self._max_workers = max_workers + + self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) + + def _executor(self, file_path: Path): + file_path = Path(file_path) + df = pd.read_csv(file_path) + df = self._normalize_obj.normalize(df) + if not df.empty: + df.to_csv(self._target_dir.joinpath(file_path.name), index=False) + + def normalize(self): + logger.info("normalize data......") + + with ProcessPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) + with tqdm(total=len(file_list)) as p_bar: + for _ in worker.map(self._executor, file_list): + p_bar.update() + + class Run: - def __init__(self, source_dir=None, max_workers=4, region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): """ Parameters ---------- source_dir: str The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 region: str @@ -320,6 +455,11 @@ class Run: self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + self._cur_module = importlib.import_module("collector") self.max_workers = max_workers self.region = region @@ -372,6 +512,33 @@ class Run: limit_nums=limit_nums, ).collector_data() + def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + """normalize data + + Parameters + ---------- + interval: str + freq, value from [1d], default 1d + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ + """ + _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") + fc = Normalize( + source_dir=self.source_dir, + target_dir=self.normalize_dir, + normalize_class=_class, + max_workers=self.max_workers, + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + ) + fc.normalize() + if __name__ == "__main__": fire.Fire(Run) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 1a08c514f..56d010974 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -98,14 +98,14 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar -def return_date_list(source_dir, date_field_name, file_path): - df = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0) - - return df[date_field_name].to_list() +def return_date_list(source_dir, date_field_name: str, file_path: Path): + file_path = Path(file_path) + date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list() + return sorted(map(lambda x: pd.Timestamp(x), date_list)) def get_calendar_list_by_ratio( - source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, max_workers: int = 16 + source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16 ) -> list: """get calendar list by selecting the date when few funds trade in this day @@ -117,6 +117,8 @@ def get_calendar_list_by_ratio( date field name, default is date threshold: float threshold to exclude some days when few funds trade in this day, default 0.5 + minimum_count: int + minimum count of funds should trade in one day max_workers: int Concurrent number, default is 16 @@ -126,25 +128,36 @@ def get_calendar_list_by_ratio( """ logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......") - _number_all_funds = len(os.listdir(source_dir)) + source_dir = Path(source_dir).expanduser() + file_list = list(source_dir.glob("*.csv")) - _list_all_date = dict() + _number_all_funds = len(file_list) + logger.info(f"count how many funds trade in this day......") + _dict_count_trade = dict() # dict{date:count} _fun = partial(return_date_list, source_dir, date_field_name) - with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, os.listdir(source_dir)): + for date_list in executor.map(_fun, file_list[:_number_all_funds]): for date in date_list: - if date in _list_all_date.keys(): - _list_all_date[date] += 1 - else: - _list_all_date[date] = 0 + if date not in _dict_count_trade.keys(): + _dict_count_trade[date] = 0 + + _dict_count_trade[date] += 1 p_bar.update() + + logger.info(f"count how many funds have founded in this day......") + _dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + with tqdm(total=_number_all_funds) as p_bar: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for date_list in executor.map(_fun, file_list[:_number_all_funds]): + oldest_date = sorted(date_list)[0] # this fund haven't found before this day + for date in _dict_count_founding.keys(): + if date < oldest_date: + _dict_count_founding[date] -= 1 - _threshold_number = int(_number_all_funds * threshold) - calendar = [date for date in _list_all_date if _list_all_date[date] >= _threshold_number] + calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)] return calendar From 6bcd88973b3ff906ac61e174185a28ea2b1d8ea0 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 19:32:37 +0800 Subject: [PATCH 08/77] resolve one bug --- scripts/data_collector/fund/collector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 795d8848e..08773cb9f 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -324,7 +324,6 @@ class FundNormalize: self._symbol_field_name = symbol_field_name self._calendar_list = self._get_calendar_list() - print (self._calendar_list) @staticmethod def normalize_fund( @@ -368,7 +367,7 @@ class FundNormalize1d(FundNormalize, ABC): DAILY_FORMAT = "%Y-%m-%d" def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - df = super(FundNormalize, self).normalize(df) + df = super(FundNormalize1d, self).normalize(df) return df From 9df0361262eadd065ab783bb5349ef16203d04b4 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 7 Mar 2021 19:35:50 +0800 Subject: [PATCH 09/77] black --- scripts/data_collector/utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 56d010974..ed14ad6e1 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -105,7 +105,11 @@ def return_date_list(source_dir, date_field_name: str, file_path: Path): def get_calendar_list_by_ratio( - source_dir: [str, Path], date_field_name: str = "date", threshold: float = 0.5, minimum_count: int = 10, max_workers: int = 16 + source_dir: [str, Path], + date_field_name: str = "date", + threshold: float = 0.5, + minimum_count: int = 10, + max_workers: int = 16, ) -> list: """get calendar list by selecting the date when few funds trade in this day @@ -134,7 +138,7 @@ def get_calendar_list_by_ratio( _number_all_funds = len(file_list) logger.info(f"count how many funds trade in this day......") - _dict_count_trade = dict() # dict{date:count} + _dict_count_trade = dict() # dict{date:count} _fun = partial(return_date_list, source_dir, date_field_name) with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: @@ -146,9 +150,9 @@ def get_calendar_list_by_ratio( _dict_count_trade[date] += 1 p_bar.update() - + logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date:_number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: for date_list in executor.map(_fun, file_list[:_number_all_funds]): @@ -157,7 +161,11 @@ def get_calendar_list_by_ratio( if date < oldest_date: _dict_count_founding[date] -= 1 - calendar = [date for date in _dict_count_trade if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)] + calendar = [ + date + for date in _dict_count_trade + if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) + ] return calendar From 4e7a147759286b8b20729a07381652030d428a58 Mon Sep 17 00:00:00 2001 From: wangershi Date: Sun, 14 Mar 2021 14:24:14 +0800 Subject: [PATCH 10/77] use base.py --- scripts/data_collector/base.py | 48 +-- scripts/data_collector/fund/collector.py | 417 ++++++----------------- 2 files changed, 121 insertions(+), 344 deletions(-) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index ccd8f59e5..12983f6a5 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -46,7 +46,7 @@ class BaseCollector(abc.ABC): Parameters ---------- save_dir: str - stock save dir + instrument save dir max_workers: int workers, default 4 max_collector_count: int @@ -77,11 +77,11 @@ class BaseCollector(abc.ABC): self.start_datetime = self.normalize_start_datetime(start) self.end_datetime = self.normalize_end_datetime(end) - self.stock_list = sorted(set(self.get_stock_list())) + self.instrument_list = sorted(set(self.get_instrument_list())) if limit_nums is not None: try: - self.stock_list = self.stock_list[: int(limit_nums)] + self.instrument_list = self.instrument_list[: int(limit_nums)] except Exception as e: logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") @@ -108,8 +108,8 @@ class BaseCollector(abc.ABC): raise NotImplementedError("rewrite min_numbers_trading") @abc.abstractmethod - def get_stock_list(self): - raise NotImplementedError("rewrite get_stock_list") + def get_instrument_list(self): + raise NotImplementedError("rewrite get_instrument_list") @abc.abstractmethod def normalize_symbol(self, symbol: str): @@ -158,27 +158,27 @@ class BaseCollector(abc.ABC): return _result def save_instrument(self, symbol, df: pd.DataFrame): - """save stock data to file + """save instrument data to file Parameters ---------- symbol: str - stock code + instrument code df : pd.DataFrame df.columns must contain "symbol" and "datetime" """ - if df.empty: + if df is None or df.empty: logger.warning(f"{symbol} is empty") return symbol = self.normalize_symbol(symbol) symbol = code_to_fname(symbol) - stock_path = self.save_dir.joinpath(f"{symbol}.csv") + instrument_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - if stock_path.exists(): - _old_df = pd.read_csv(stock_path) + if instrument_path.exists(): + _old_df = pd.read_csv(instrument_path) df = _old_df.append(df, sort=False) - df.to_csv(stock_path, index=False) + df.to_csv(instrument_path, index=False) def cache_small_data(self, symbol, df): if len(df) <= self.min_numbers_trading: @@ -191,38 +191,38 @@ class BaseCollector(abc.ABC): self.mini_symbol_map.pop(symbol) return self.NORMAL_FLAG - def _collector(self, stock_list): + def _collector(self, instrument_list): error_symbol = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(stock_list)) as p_bar: - for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)): + with tqdm(total=len(instrument_list)) as p_bar: + for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)): if _result != self.NORMAL_FLAG: error_symbol.append(_symbol) p_bar.update() print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(stock_list)}") + logger.info(f"current get symbol nums: {len(instrument_list)}") error_symbol.extend(self.mini_symbol_map.keys()) return sorted(set(error_symbol)) def collector_data(self): """collector data""" logger.info("start collector data......") - stock_list = self.stock_list + instrument_list = self.instrument_list for i in range(self.max_collector_count): - if not stock_list: + if not instrument_list: break logger.info(f"getting data: {i+1}") - stock_list = self._collector(stock_list) + instrument_list = self._collector(instrument_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self.mini_symbol_map.items(): self.save_instrument( _symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]) ) if self.mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}") - logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}") + logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}") + logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") class BaseNormalize(abc.ABC): @@ -386,9 +386,9 @@ class BaseRun(abc.ABC): Examples --------- # get daily data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d # get 1m data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] @@ -416,7 +416,7 @@ class BaseRun(abc.ABC): Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d + $ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d """ _class = getattr(self._cur_module, self.normalize_class_name) yc = Normalize( diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 08773cb9f..1e0d2d8bf 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -11,123 +11,24 @@ import json from abc import ABC from pathlib import Path from typing import Iterable, Type -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor import fire import requests import numpy as np import pandas as pd -from tqdm import tqdm from loguru import logger from dateutil.tz import tzlocal +from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.base import BaseCollector, BaseNormalize, BaseRun from data_collector.utils import get_calendar_list, get_en_fund_symbols INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}" -REGION_CN = "CN" -class FundData: - START_DATETIME = pd.Timestamp("2000-01-01") - END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - INTERVAL_1d = "1d" - - def __init__( - self, - timezone: str = None, - start=None, - end=None, - interval="1d", - delay=0, - ): - """ - - Parameters - ---------- - timezone: str - The timezone where the data is located - delay: float - time.sleep(delay), default 0 - interval: str - freq, value from [1d], default 1d - start: str - start datetime, default None - end: str - end datetime, default None - """ - self._timezone = tzlocal() if timezone is None else timezone - self._delay = delay - self._interval = interval - self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME - self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME) - if self._interval != self.INTERVAL_1d: - raise ValueError(f"interval error: {self._interval}") - - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) - self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) - - @staticmethod - def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): - try: - dt = pd.Timestamp(dt, tz=timezone).timestamp() - dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") - except ValueError as e: - pass - return dt - - def _sleep(self): - time.sleep(self._delay) - - @staticmethod - def get_data_from_remote(symbol, interval, start, end): - error_msg = f"{symbol}-{interval}-{start}-{end}" - - try: - # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg - url = INDEX_BENCH_URL.format( - index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end - ) - resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - - if resp.status_code != 200: - raise ValueError("request error") - - data = json.loads(resp.text.split("(")[-1].split(")")[0]) - - # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html - SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") - - # TODO: should we sort the value by datetime? - _resp = pd.DataFrame(data["Data"]["LSJZList"]) - - if isinstance(_resp, pd.DataFrame): - return _resp.reset_index() - except Exception as e: - logger.warning(f"{error_msg}:{e}") - - def get_data(self, symbol: str) -> [pd.DataFrame]: - def _get_simple(start_, end_): - self._sleep() - _remote_interval = self._interval - return self.get_data_from_remote( - symbol, - interval=_remote_interval, - start=start_, - end=end_, - ) - - if self._interval == self.INTERVAL_1d: - _result = _get_simple(self.start_datetime, self.end_datetime) - else: - raise ValueError(f"cannot support {self._interval}") - return _result - - -class FundCollector: +class FundCollector(BaseCollector): def __init__( self, save_dir: [str, Path], @@ -163,134 +64,108 @@ class FundCollector: limit_nums: int using for debug, by default None """ - self.save_dir = Path(save_dir).expanduser().resolve() - self.save_dir.mkdir(parents=True, exist_ok=True) - - self._delay = delay - self.max_workers = max_workers - self._max_collector_count = max_collector_count - self._mini_symbol_map = {} - self._interval = interval - self._check_small_data = check_data_length - - self.fund_list = sorted(set(self.get_fund_list())) - if limit_nums is not None: - try: - self.fund_list = self.fund_list[: int(limit_nums)] - except Exception as e: - logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") - - self.fund_data = FundData( - timezone=self._timezone, + super(FundCollector, self).__init__( + save_dir=save_dir, start=start, end=end, interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, ) - @property - @abc.abstractmethod - def min_numbers_trading(self): - # daily, one year: 252 / 4 - # us 1min, a week: 6.5 * 60 * 5 - # cn 1min, a week: 4 * 60 * 5 - raise NotImplementedError("rewrite min_numbers_trading") + self.init_datetime() - @abc.abstractmethod - def get_fund_list(self): - raise NotImplementedError("rewrite get_fund_list") + def init_datetime(self): + if self.interval == self.INTERVAL_1min: + self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + elif self.interval == self.INTERVAL_1d: + pass + else: + raise ValueError(f"interval error: {self.interval}") + + self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) + self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) + + @staticmethod + def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone): + try: + dt = pd.Timestamp(dt, tz=timezone).timestamp() + dt = pd.Timestamp(dt, tz=tzlocal(), unit="s") + except ValueError as e: + pass + return dt @property @abc.abstractmethod def _timezone(self): raise NotImplementedError("rewrite get_timezone") - def save_fund(self, symbol, df: pd.DataFrame): - """save fund data to file + @staticmethod + def get_data_from_remote(symbol, interval, start, end): + error_msg = f"{symbol}-{interval}-{start}-{end}" - Parameters - ---------- - symbol: str - fund code - df : pd.DataFrame - df.columns must contain "symbol" and "datetime" - """ - if df.empty: - logger.warning(f"{symbol} is empty") - return + try: + # TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg + url = INDEX_BENCH_URL.format( + index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + ) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) - fund_path = self.save_dir.joinpath(f"{symbol}.csv") - df["symbol"] = symbol - if fund_path.exists(): - # TODO: read the fund code as str, not int, like "000001" shouldn't be "1" - _old_df = pd.read_csv(fund_path) - # TODO: remove the duplicate date - df = _old_df.append(df, sort=False) - df.to_csv(fund_path, index=False) + if resp.status_code != 200: + raise ValueError("request error") - def _save_small_data(self, symbol, df): - if len(df) <= self.min_numbers_trading: - logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") - _temp = self._mini_symbol_map.setdefault(symbol, []) - _temp.append(df.copy()) - return None + data = json.loads(resp.text.split("(")[-1].split(")")[0]) + + # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html + SYType = data["Data"]["SYType"] + if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): + raise Exception("The fund contains 每*份收益") + + # TODO: should we sort the value by datetime? + _resp = pd.DataFrame(data["Data"]["LSJZList"]) + + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> [pd.DataFrame]: + def _get_simple(start_, end_): + self.sleep() + _remote_interval = interval + return self.get_data_from_remote( + symbol, + interval=_remote_interval, + start=start_, + end=end_, + ) + + if interval == self.INTERVAL_1d: + _result = _get_simple(start_datetime, end_datetime) else: - if symbol in self._mini_symbol_map: - self._mini_symbol_map.pop(symbol) - return symbol - - def _get_data(self, symbol): - _result = None - df = self.fund_data.get_data(symbol) - if isinstance(df, pd.DataFrame): - if not df.empty: - if self._check_small_data: - if self._save_small_data(symbol, df) is not None: - _result = symbol - self.save_fund(symbol, df) - else: - _result = symbol - self.save_fund(symbol, df) + raise ValueError(f"cannot support {interval}") return _result - def _collector(self, fund_list): - error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - with tqdm(total=len(fund_list)) as p_bar: - for _symbol, _result in zip(fund_list, executor.map(self._get_data, fund_list)): - if _result is None: - error_symbol.append(_symbol) - p_bar.update() - print(error_symbol) - logger.info(f"error symbol nums: {len(error_symbol)}") - logger.info(f"current get symbol nums: {len(fund_list)}") - error_symbol.extend(self._mini_symbol_map.keys()) - return sorted(set(error_symbol)) - def collector_data(self): """collector data""" - logger.info("start collector fund data......") - fund_list = self.fund_list - for i in range(self._max_collector_count): - if not fund_list: - break - logger.info(f"getting data: {i+1}") - fund_list = self._collector(fund_list) - logger.info(f"{i+1} finish.") - for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_fund(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])) - if self._mini_symbol_map: - logger.warning(f"less than {self.min_numbers_trading} fund list: {list(self._mini_symbol_map.keys())}") - logger.info(f"total {len(self.fund_list)}, error: {len(set(fund_list))}") + super(FundCollector, self).collector_data() class FundollectorCN(FundCollector, ABC): - def get_fund_list(self): + def get_instrument_list(self): logger.info("get cn fund symbols......") symbols = get_en_fund_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols + def normalize_symbol(self, symbol): + return symbol + @property def _timezone(self): return "Asia/Shanghai" @@ -302,29 +177,9 @@ class FundCollectorCN1d(FundollectorCN): return 252 / 4 -class FundNormalize: - COLUMNS = ["open", "close", "high", "low", "volume"] +class FundNormalize(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" - def __init__( - self, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self._date_field_name = date_field_name - self._symbol_field_name = symbol_field_name - - self._calendar_list = self._get_calendar_list() - @staticmethod def normalize_fund( df: pd.DataFrame, @@ -357,11 +212,6 @@ class FundNormalize: df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) return df - @abc.abstractmethod - def _get_calendar_list(self): - """Get benchmark calendar""" - raise NotImplementedError("") - class FundNormalize1d(FundNormalize, ABC): DAILY_FORMAT = "%Y-%m-%d" @@ -380,62 +230,8 @@ class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): pass -class Normalize: - def __init__( - self, - source_dir: [str, Path], - target_dir: [str, Path], - normalize_class: Type[FundNormalize], - max_workers: int = 16, - date_field_name: str = "date", - symbol_field_name: str = "symbol", - ): - """ - - Parameters - ---------- - source_dir: str or Path - The directory where the raw data collected from the Internet is saved - target_dir: str or Path - Directory for normalize data - normalize_class: Type[FundNormalize] - normalize class - max_workers: int - Concurrent number, default is 16 - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - if not (source_dir and target_dir): - raise ValueError("source_dir and target_dir cannot be None") - self._source_dir = Path(source_dir).expanduser() - self._target_dir = Path(target_dir).expanduser() - self._target_dir.mkdir(parents=True, exist_ok=True) - - self._max_workers = max_workers - - self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name) - - def _executor(self, file_path: Path): - file_path = Path(file_path) - df = pd.read_csv(file_path) - df = self._normalize_obj.normalize(df) - if not df.empty: - df.to_csv(self._target_dir.joinpath(file_path.name), index=False) - - def normalize(self): - logger.info("normalize data......") - - with ProcessPoolExecutor(max_workers=self._max_workers) as worker: - file_list = list(self._source_dir.glob("*.csv")) - with tqdm(total=len(file_list)) as p_bar: - for _ in worker.map(self._executor, file_list): - p_bar.update() - - -class Run: - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): """ Parameters @@ -446,23 +242,26 @@ class Run: Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int Concurrent number, default is 4 + interval: str + freq, value from [1min, 1d], default 1d region: str region, value from ["CN"], default "CN" """ - if source_dir is None: - source_dir = CUR_DIR.joinpath("source") - self.source_dir = Path(source_dir).expanduser().resolve() - self.source_dir.mkdir(parents=True, exist_ok=True) - - if normalize_dir is None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) - - self._cur_module = importlib.import_module("collector") - self.max_workers = max_workers + super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region + @property + def collector_class_name(self): + return f"FundCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"FundNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + def download_data( self, max_collector_count=2, @@ -498,26 +297,13 @@ class Run: $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - _class = getattr(self._cur_module, f"FundCollector{self.region.upper()}{interval}") # type: Type[FundCollector] - _class( - self.source_dir, - max_workers=self.max_workers, - max_collector_count=max_collector_count, - delay=delay, - start=start, - end=end, - interval=interval, - check_data_length=check_data_length, - limit_nums=limit_nums, - ).collector_data() + super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) - def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters ---------- - interval: str - freq, value from [1d], default 1d date_field_name: str date field name, default date symbol_field_name: str @@ -527,16 +313,7 @@ class Run: --------- $ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ """ - _class = getattr(self._cur_module, f"FundNormalize{self.region.upper()}{interval}") - fc = Normalize( - source_dir=self.source_dir, - target_dir=self.normalize_dir, - normalize_class=_class, - max_workers=self.max_workers, - date_field_name=date_field_name, - symbol_field_name=symbol_field_name, - ) - fc.normalize() + super(Run, self).normalize_data(date_field_name, symbol_field_name) if __name__ == "__main__": From d245242f2f859da841c6de70a2091f0cc78e9421 Mon Sep 17 00:00:00 2001 From: zhupr Date: Thu, 18 Mar 2021 11:21:25 +0800 Subject: [PATCH 11/77] Fix dump_bin.py && check_dump_bin.py --- scripts/check_dump_bin.py | 3 ++- scripts/dump_bin.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py index 7c2e837af..ef8023219 100644 --- a/scripts/check_dump_bin.py +++ b/scripts/check_dump_bin.py @@ -66,7 +66,7 @@ class CheckBin: self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path]) if check_fields is None: - check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin"))) + check_fields = list(map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin"))) else: check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields self.check_fields = list(map(lambda x: x.strip(), check_fields)) @@ -91,6 +91,7 @@ class CheckBin: origin_df[self.symbol_field_name] = symbol origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True) origin_df.index.names = qlib_df.index.names + origin_df = origin_df.reindex(qlib_df.index) try: compare = datacompy.Compare( origin_df, diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 4811fd486..38dd1f3af 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -219,7 +219,7 @@ class DumpDataBase: # used when creating a bin file date_index = self.get_datetime_index(_df, calendar_list) for field in self.get_dump_fields(_df.columns): - bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}") + bin_path = features_dir.joinpath(f"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}") if field not in _df.columns: continue if bin_path.exists() and self._mode == self.UPDATE_MODE: From 38f35658e755451497c745842c80e9f6acbc4082 Mon Sep 17 00:00:00 2001 From: Wendi Li Date: Thu, 18 Mar 2021 13:19:27 +0800 Subject: [PATCH 12/77] Update __init__.py --- qlib/data/dataset/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 0cdb598f2..0f5d2baba 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -259,7 +259,7 @@ class TSDataSampler: self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) - self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But + self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! # NOTE: append last line with full NaN for better performance in `__getitem__` self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0) self.nan_idx = -1 # The last line is all NaN @@ -267,7 +267,6 @@ class TSDataSampler: # the data type will be changed # The index of usable data is between start_idx and end_idx self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) - # self.index_link = self.build_link(self.data) self.idx_df, self.idx_map = self.build_index(self.data) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance From d3160e94399ef395fdd12116b416a84bea3a87b4 Mon Sep 17 00:00:00 2001 From: wangershi Date: Thu, 18 Mar 2021 21:15:45 +0800 Subject: [PATCH 13/77] remove some useless code --- scripts/data_collector/fund/collector.py | 12 ++---------- scripts/data_collector/utils.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 1e0d2d8bf..10800a7a3 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -151,10 +151,6 @@ class FundCollector(BaseCollector): raise ValueError(f"cannot support {interval}") return _result - def collector_data(self): - """collector data""" - super(FundCollector, self).collector_data() - class FundollectorCN(FundCollector, ABC): def get_instrument_list(self): @@ -213,12 +209,8 @@ class FundNormalize(BaseNormalize): return df -class FundNormalize1d(FundNormalize, ABC): - DAILY_FORMAT = "%Y-%m-%d" - - def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - df = super(FundNormalize1d, self).normalize(df) - return df +class FundNormalize1d(FundNormalize): + pass class FundNormalizeCN: diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index ed14ad6e1..e8c9b9dc4 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -98,9 +98,8 @@ def get_calendar_list(bench_code="CSI300") -> list: return calendar -def return_date_list(source_dir, date_field_name: str, file_path: Path): - file_path = Path(file_path) - date_list = pd.read_csv(Path(source_dir).joinpath(file_path), sep=",", index_col=0)[date_field_name].to_list() +def return_date_list(date_field_name: str, file_path: Path): + date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list() return sorted(map(lambda x: pd.Timestamp(x), date_list)) @@ -139,10 +138,13 @@ def get_calendar_list_by_ratio( logger.info(f"count how many funds trade in this day......") _dict_count_trade = dict() # dict{date:count} - _fun = partial(return_date_list, source_dir, date_field_name) + _fun = partial(return_date_list, date_field_name) + all_oldest_list = [] with tqdm(total=_number_all_funds) as p_bar: with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, file_list[:_number_all_funds]): + for date_list in executor.map(_fun, file_list): + if date_list: + all_oldest_list.append(date_list[0]) for date in date_list: if date not in _dict_count_trade.keys(): _dict_count_trade[date] = 0 @@ -154,12 +156,10 @@ def get_calendar_list_by_ratio( logger.info(f"count how many funds have founded in this day......") _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: - with ProcessPoolExecutor(max_workers=max_workers) as executor: - for date_list in executor.map(_fun, file_list[:_number_all_funds]): - oldest_date = sorted(date_list)[0] # this fund haven't found before this day - for date in _dict_count_founding.keys(): - if date < oldest_date: - _dict_count_founding[date] -= 1 + for oldest_date in all_oldest_list: + for date in _dict_count_founding.keys(): + if date < oldest_date: + _dict_count_founding[date] -= 1 calendar = [ date From 598ee875a03753f64491a378d40d1907102f10a3 Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 22 Mar 2021 10:29:07 +0800 Subject: [PATCH 14/77] Fix yahoo_collector --- scripts/data_collector/yahoo/collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index eadc381ec..f0e110694 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -185,7 +185,7 @@ class YahooCollector(BaseCollector): class YahooCollectorCN(YahooCollector, ABC): - def get_stock_list(self): + def get_instrument_list(self): logger.info("get HS stock symbos......") symbols = get_hs_stock_symbols() logger.info(f"get {len(symbols)} symbols.") @@ -249,7 +249,7 @@ class YahooCollectorCN1min(YahooCollectorCN): class YahooCollectorUS(YahooCollector, ABC): - def get_stock_list(self): + def get_instrument_list(self): logger.info("get US stock symbols......") symbols = get_us_stock_symbols() + [ "^GSPC", From 1ad237f89fc5197a6629b8e2df2217dd3e2fb712 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 14:20:44 +0800 Subject: [PATCH 15/77] update high freq demo --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 65 ++++++++ qlib/contrib/eva/alpha.py | 40 +++++ qlib/contrib/model/highfreq_gdbt_model.py | 157 ++++++++++++++++++ qlib/workflow/record_temp.py | 50 +++++- 4 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml create mode 100644 qlib/contrib/model/highfreq_gdbt_model.py diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml new file mode 100644 index 000000000..ca8e92d08 --- /dev/null +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -0,0 +1,65 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + region: cn +market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +start_time: &start_time "2020-09-15 00:00:00" +end_time: &end_time "2021-01-18 16:00:00" +train_end_time: &train_end_time "2020-11-15 16:00:00" +valid_start_time: &valid_start_time "2020-11-16 00:00:00" +valid_end_time: &valid_end_time "2020-11-30 16:00:00" +test_start_time: &test_start_time "2020-12-01 00:00:00" +data_handler_config: &data_handler_config + start_time: *start_time + end_time: *end_time + fit_start_time: *start_time + fit_end_time: *train_end_time + instruments: *market + freq: '1min' + infer_processors: + - class: 'RobustZScoreNorm' + kwargs: + fields_group: 'feature' + clip_outlier: false + - class: "Fillna" + kwargs: + fields_group: 'feature' + learn_processors: + - class: 'DropnaLabel' + - class: 'CSRankNorm' + kwargs: + fields_group: 'label' + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +task: + model: + class: "HF_LGBModel" + module_path: "qlib.contrib.model.highfreq_gdbt_model" + kwargs: + objective: 'binary' + metric: ['binary_logloss','auc'] + verbosity: -1 + learning_rate: 0.01 + max_depth: 8 + num_leaves: 150 + lambda_l1: 1.5 + lambda_l2: 1 + num_threads: 20 + dataset: + class: "DatasetH" + module_path: "qlib.data.dataset" + kwargs: + handler: + class: "Alpha158" + module_path: "qlib.contrib.data.handler" + kwargs: *data_handler_config + segments: + train: [*start_time, *train_end_time] + valid: [*train_end_time, *valid_end_time] + test: [*test_start_time, *end_time] + record: + - class: "SignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} + - class: "HFSignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} \ No newline at end of file diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index c68571853..e2beafc13 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,6 +7,46 @@ import pandas as pd from typing import Tuple +def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: + """ calculate the precision + pred : + pred + label : + label + date_col : + date_col + + Returns + ------- + (pd.Series, pd.Series) + long precision and short precision in time level + """ + if is_alpha: + label = label - label.mean(level=0) + if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + raise ValueError("Need more instruments to calculate precision") + + + df = pd.DataFrame({"pred": pred, "label": label}) + if dropna: + df.dropna(inplace = True) + + group = df.groupby(level=date_col) + + N = lambda x: int(len(x) * quantile) + # find the top/low quantile of prediction and treat them as long and short target + long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) + short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) + + groupll = long.groupby(date_col) + ll_ration = groupll.apply(lambda x: x > 0) + ll_c = groupll.count() + + groups = short.groupby(date_col) + s_ration = groups.apply(lambda x: x < 0) + s_c = groups.count() + return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: """calc_ic. diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py new file mode 100644 index 000000000..62e45c841 --- /dev/null +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd +import lightgbm as lgb + +from qlib.model.base import ModelFT +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +import warnings + + +class HF_LGBModel(ModelFT): + """LightGBM Model""" + + def __init__(self, loss="mse", **kwargs): + if loss not in {"mse", "binary"}: + raise NotImplementedError + self.params = {"objective": loss, "verbosity": -1} + self.params.update(kwargs) + self.model = None + + def _cal_signal_metrics(self, y_test, l_cut, r_cut): + """ + Calcaute the signal metrics by daily level + """ + up_pre, down_pre = [], [] + up_alpha_ll, down_alpha_ll = [], [] + for date in y_test.index.get_level_values(0).unique(): + df_res = y_test.loc[date].sort_values("pred") + if int(l_cut * len(df_res)) < 10: + warnings.warn("Warning: threhold is too low or instruments number is not enough") + continue + top = df_res.iloc[: int(l_cut * len(df_res))] + bottom = df_res.iloc[int(r_cut * len(df_res)) :] + + down_precision = len(top[top[top.columns[0]] < 0]) / (len(top)) + up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom)) + + down_alpha = top[top.columns[0]].mean() + up_alpha = bottom[bottom.columns[0]].mean() + + up_pre.append(up_precision) + down_pre.append(down_precision) + up_alpha_ll.append(up_alpha) + down_alpha_ll.append(down_alpha) + + return ( + np.array(up_pre).mean(), + np.array(down_pre).mean(), + np.array(up_alpha_ll).mean(), + np.array(down_alpha_ll).mean(), + ) + + def hf_signal_test(self, dataset: DatasetH, threhold=0.2): + """ + Test the sigal in high frequency test set + """ + if self.model == None: + raise ValueError("Model hasn't been trained yet") + df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + df_test.dropna(inplace=True) + x_test, y_test = df_test["feature"], df_test["label"] + # Convert label into alpha + y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0) + + res = pd.Series(self.model.predict(x_test.values), index=x_test.index) + y_test["pred"] = res + + up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold) + print("===============================") + print("High frequency signal test") + print("===============================") + print("Test set precision: ") + print("Positive precision: {}, Negative precision: {}".format(up_p, down_p)) + print("Test Alpha Average in test set: ") + print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a)) + + def _prepare_data(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_train["feature"], df_valid["label"] + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + l_name = df_train["label"].columns[0] + # Convert label into alpha + df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0) + df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0) + mapping_fn = lambda x: 0 if x < 0 else 1 + df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn) + df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn) + x_train, y_train = df_train["feature"], df_train["label_c"].values + x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train.values, label=y_train) + dvalid = lgb.Dataset(x_valid.values, label=y_valid) + return dtrain, dvalid + + def fit( + self, + dataset: DatasetH, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): + dtrain, dvalid = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + + def predict(self, dataset): + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + return pd.Series(self.model.predict(x_test.values), index=x_test.index) + + def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): + """ + finetune model + + Parameters + ---------- + dataset : DatasetH + dataset for finetuning + num_boost_round : int + number of round to finetune model + verbose_eval : int + verbose level + """ + # Based on existing model and finetune by train more rounds + dtrain, _ = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + init_model=self.model, + valid_sets=[dtrain], + valid_names=["train"], + verbose_eval=verbose_eval, + ) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..8ab8405a5 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -154,6 +154,54 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) + + +class HFSignalRecord(SignalRecord): + """ + This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. + """ + artifact_path = "hg_sig_analysis" + + def __init__(self, recorder, **kwargs): + super().__init__(recorder=recorder) + + def generate(self): + pred = self.load("pred.pkl") + raw_label = self.load("label.pkl") + + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics = { + "IC": ic.mean(), + "ICIR": ic.mean() / ic.std(), + "Rank IC": ric.mean(), + "Rank ICIR": ric.mean() / ric.std(), + "Long precision": long_pre.mean(), + "Short precision": short_pre.mean() + } + objects = {"ic.pkl": ic, "ric.pkl": ric} + objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) + long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics.update( + { + "Long-Short Average Return": long_short_r.mean(), + "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(), + } + ) + objects.update( + { + "long_short_r.pkl": long_short_r, + "long_avg_r.pkl": long_avg_r, + } + ) + self.recorder.log_metrics(**metrics) + self.recorder.save_objects(**objects, artifact_path=self.get_path()) + pprint(metrics) + + def list(self): + paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + return paths class SigAnaRecord(SignalRecord): From 3bf6c7f95f5cc77d4025358e618d5f688138f5cc Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 15:37:54 +0800 Subject: [PATCH 16/77] update format --- qlib/contrib/eva/alpha.py | 24 +++++++++++++----------- qlib/workflow/record_temp.py | 16 +++++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index e2beafc13..8078dd4ed 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,15 +7,18 @@ import pandas as pd from typing import Tuple -def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: - """ calculate the precision + +def calc_prec( + pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False +) -> Tuple[pd.Series, pd.Series]: + """calculate the precision pred : pred label : label date_col : date_col - + Returns ------- (pd.Series, pd.Series) @@ -23,29 +26,28 @@ def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: """ if is_alpha: label = label - label.mean(level=0) - if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") - df = pd.DataFrame({"pred": pred, "label": label}) if dropna: - df.dropna(inplace = True) - + df.dropna(inplace=True) + group = df.groupby(level=date_col) - + N = lambda x: int(len(x) * quantile) # find the top/low quantile of prediction and treat them as long and short target long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) - + groupll = long.groupby(date_col) ll_ration = groupll.apply(lambda x: x > 0) ll_c = groupll.count() - + groups = short.groupby(date_col) s_ration = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 8ab8405a5..c47b999f3 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -154,12 +154,13 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) - - + + class HFSignalRecord(SignalRecord): """ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. """ + artifact_path = "hg_sig_analysis" def __init__(self, recorder, **kwargs): @@ -169,7 +170,7 @@ class HFSignalRecord(SignalRecord): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -177,7 +178,7 @@ class HFSignalRecord(SignalRecord): "Rank IC": ric.mean(), "Rank ICIR": ric.mean() / ric.std(), "Long precision": long_pre.mean(), - "Short precision": short_pre.mean() + "Short precision": short_pre.mean(), } objects = {"ic.pkl": ic, "ric.pkl": ric} objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) @@ -199,7 +200,12 @@ class HFSignalRecord(SignalRecord): pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths = [ + self.get_path("ic.pkl"), + self.get_path("ric.pkl"), + self.get_path("long_pre.pkl"), + self.get_path("short_pre.pkl"), + ] paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From c6b67cb8fe89fbe71759f0a91f2cf229625f7cd1 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 22 Mar 2021 18:37:13 +0800 Subject: [PATCH 17/77] fix doc --- docs/component/data.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/component/data.rst b/docs/component/data.rst index ba7979e23..89cc918c1 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -298,9 +298,9 @@ Here are some important interfaces that ``DataHandlerLP`` provides: .. autoclass:: qlib.data.dataset.handler.DataHandlerLP :members: __init__, fetch, get_cols -If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass. +If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``. -If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`. +Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess method for features defined by config into the new handler. Processor @@ -337,7 +337,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h .. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_. - .. code-block:: Python import qlib @@ -364,6 +363,7 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h # fetch all the features print(h.fetch(col_set="feature")) + API --------- From 7370d5af9e7f6d24fba90597a3e3097e21820c1a Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 22 Mar 2021 18:37:44 +0800 Subject: [PATCH 18/77] add label doc --- docs/component/data.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/component/data.rst b/docs/component/data.rst index 89cc918c1..ce639d8fa 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -363,6 +363,7 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h # fetch all the features print(h.fetch(col_set="feature")) +..note :: In the ``Alpha158``, ``Qlib`` use the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day. API --------- From 4b56a4e907ae956995aaa8badc616c592d4d1b7c Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 22 Mar 2021 18:45:27 +0800 Subject: [PATCH 19/77] fix doc --- docs/component/data.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/component/data.rst b/docs/component/data.rst index ce639d8fa..26f44a076 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -298,9 +298,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides: .. autoclass:: qlib.data.dataset.handler.DataHandlerLP :members: __init__, fetch, get_cols + If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``. -Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess method for features defined by config into the new handler. +Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler. Processor @@ -363,7 +364,8 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h # fetch all the features print(h.fetch(col_set="feature")) -..note :: In the ``Alpha158``, ``Qlib`` use the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day. + +.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day. API --------- From 0a0c6a3185ac6bcec38b756f039b9ccc64b41827 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 23 Mar 2021 10:10:17 +0000 Subject: [PATCH 20/77] Add load_object function for R --- qlib/workflow/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 3d787562e..678ae99a8 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -416,6 +416,12 @@ class QlibRecorder: """ self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs) + def load_object(self, name: Text): + """ + Method for loading an object from artifacts in the experiment in the uri. + """ + return self.get_exp().get_recorder().load_object(name) + def log_params(self, **kwargs): """ Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API. From e490e83a163d00d9304554e356790359b8495d5a Mon Sep 17 00:00:00 2001 From: Flouse Date: Wed, 24 Mar 2021 11:37:09 +0800 Subject: [PATCH 21/77] fix docs --- qlib/contrib/strategy/strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py index 550ff649d..4f8eb0ab1 100644 --- a/qlib/contrib/strategy/strategy.py +++ b/qlib/contrib/strategy/strategy.py @@ -251,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): """ - Gnererate order list according to score_series at trade_date, will not change current. + Generate order list according to score_series at trade_date, will not change current. Parameters ----------- From e3739bb980b5347520a13fd510bf9bf7180c8905 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 24 Mar 2021 15:47:26 +0800 Subject: [PATCH 22/77] fix naming and code style --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 2 +- qlib/contrib/eva/alpha.py | 29 +++++++++++++------ qlib/contrib/model/highfreq_gdbt_model.py | 4 +-- qlib/workflow/record_temp.py | 8 ++--- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index ca8e92d08..c21ef1da3 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -32,7 +32,7 @@ data_handler_config: &data_handler_config task: model: - class: "HF_LGBModel" + class: "HFLGBModel" module_path: "qlib.contrib.model.highfreq_gdbt_model" kwargs: objective: 'binary' diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 8078dd4ed..fadef9d16 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -8,12 +8,23 @@ import pandas as pd from typing import Tuple -def calc_prec( +def calc_long_short_prec( pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False ) -> Tuple[pd.Series, pd.Series]: - """calculate the precision - pred : - pred + """ + calculate the precision for long and short operation + + + :param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**. + + .. code-block:: python + score + datetime instrument + 2020-12-01 09:30:00 SH600068 0.553634 + SH600195 0.550017 + SH600276 0.540321 + SH600584 0.517297 + SH600715 0.544674 label : label date_col : @@ -25,7 +36,7 @@ def calc_prec( long precision and short precision in time level """ if is_alpha: - label = label - label.mean(level=0) + label = label - label.mean(level=date_col) if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") @@ -41,13 +52,13 @@ def calc_prec( short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) groupll = long.groupby(date_col) - ll_ration = groupll.apply(lambda x: x > 0) - ll_c = groupll.count() + l_dom = groupll.apply(lambda x: x > 0) + l_c = groupll.count() groups = short.groupby(date_col) - s_ration = groups.apply(lambda x: x < 0) + s_dom = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) + return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index 62e45c841..5a2eeb50a 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -11,8 +11,8 @@ from qlib.data.dataset.handler import DataHandlerLP import warnings -class HF_LGBModel(ModelFT): - """LightGBM Model""" +class HFLGBModel(ModelFT): + """LightGBM Model for high frequency prediction""" def __init__(self, loss="mse", **kwargs): if loss not in {"mse", "binary"}: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index c47b999f3..239527fa0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -169,8 +169,7 @@ class HFSignalRecord(SignalRecord): def generate(self): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) + long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -205,8 +204,9 @@ class HFSignalRecord(SignalRecord): self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl"), + self.get_path("long_short_r.pkl"), + self.get_path("long_avg_r.pkl"), ] - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From 1ca3c6a61c11cff9adf79b1657af555cf68a365a Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 01:29:59 +0800 Subject: [PATCH 23/77] add DataHandlerDL --- qlib/data/dataset/loader.py | 58 +++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 921bf01c5..faabe2c02 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -217,3 +217,61 @@ class StaticDataLoader(DataLoader): join=self.join, ) self._data.sort_index(inplace=True) + +class DataHandlerDL(DataLoader): + '''DataHandlerDL + DataHandler-based (D)ata (L)oader + It is designed to load multiple data from data handler + - If you just want to load data from single datahandler, you can write them in single data handler + ''' + + def __init__(self, handler_config:dict, fetch_config:dict = {}, is_group=False): + """ + Parameters + ---------- + handler_config : dict + handler_config will be used to describe the handlers + + .. code-block:: + + := { + "group_name1": + "group_name2": + } + or + := + := DataHandler Instance | DataHandler Config + + fetch_config : dict + fetch_config will be used to describe the different arguments of fetch method, such as squeeze, data_key, etc. + + is_group: bool + is_group will be used to describe whether the key of handler_config is group + + """ + if self.is_group: + self.handlers = { + grp: init_instance_by_config(config, accept_types=DataHandler) + for grp, config in handler_config.items() + } + else: + self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) + + self.is_group = is_group + self.fetch_config = fetch_config + + def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: + if instruments is not None: + LOG.warning(f"instruments[{instruments}] is ignored") + + if self.is_group: + df = pd.concat( + { + grp: dh.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + for grp, dh in self.handlers.items() + }, + axis=1, + ) + else: + df = self.handler.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + return df From b1a28358adb9b9e15abd09fe59f7ff4544e399ed Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 01:30:31 +0800 Subject: [PATCH 24/77] black format --- qlib/data/dataset/loader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index faabe2c02..884d15635 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -218,14 +218,15 @@ class StaticDataLoader(DataLoader): ) self._data.sort_index(inplace=True) + class DataHandlerDL(DataLoader): - '''DataHandlerDL + """DataHandlerDL DataHandler-based (D)ata (L)oader It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler - ''' + """ - def __init__(self, handler_config:dict, fetch_config:dict = {}, is_group=False): + def __init__(self, handler_config: dict, fetch_config: dict = {}, is_group=False): """ Parameters ---------- @@ -251,12 +252,11 @@ class DataHandlerDL(DataLoader): """ if self.is_group: self.handlers = { - grp: init_instance_by_config(config, accept_types=DataHandler) - for grp, config in handler_config.items() + grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items() } else: self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) - + self.is_group = is_group self.fetch_config = fetch_config From 1fcfe8e4ba6e655ba59ae95180c491ea3fe85c8e Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 01:37:17 +0800 Subject: [PATCH 25/77] add rolling process data --- examples/rolling_process_data/README.md | 2 ++ examples/rolling_process_data/workflow.py | 0 2 files changed, 2 insertions(+) create mode 100644 examples/rolling_process_data/README.md create mode 100644 examples/rolling_process_data/workflow.py diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md new file mode 100644 index 000000000..3f1c8768d --- /dev/null +++ b/examples/rolling_process_data/README.md @@ -0,0 +1,2 @@ +# Rolling Process Data + diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py new file mode 100644 index 000000000..e69de29bb From f6dc25b22982d5e80b4cd2f9c2fc823ed98d244b Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 16:14:22 +0800 Subject: [PATCH 26/77] update rolling process --- examples/highfreq/workflow.py | 1 - .../rolling_process_data/rolling_handler.py | 34 ++++ examples/rolling_process_data/workflow.py | 145 ++++++++++++++++++ qlib/data/dataset/handler.py | 2 +- qlib/data/dataset/loader.py | 21 +-- 5 files changed, 192 insertions(+), 11 deletions(-) create mode 100644 examples/rolling_process_data/rolling_handler.py diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 01de59c0e..c2ca36db3 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -32,7 +32,6 @@ class HighfreqWorkflow(object): SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} MARKET = "all" - BENCHMARK = "SH000300" start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" diff --git a/examples/rolling_process_data/rolling_handler.py b/examples/rolling_process_data/rolling_handler.py new file mode 100644 index 000000000..50a36f219 --- /dev/null +++ b/examples/rolling_process_data/rolling_handler.py @@ -0,0 +1,34 @@ +from qlib.data.dataset.handler import DataHandlerLP +from qlib.data.dataset.loader import DataLoaderDH +from qlib.contrib.data.handler import check_transform_proc + + +class RollingDataHandler(DataHandlerLP): + def __init__( + self, + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + data_loader_kwargs={} + ): + infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) + learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + + data_loader = { + "class": "DataLoaderDH", + "kwargs": { + **data_loader_kwargs + }, + } + + super().__init__( + instruments=None, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + ) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index e69de29bb..8581f149b 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import qlib +import pickle +import datetime +import pandas as pd +from qlib.config import REG_CN +from qlib.data.dataset.handler import DataHandlerLP +from qlib.contrib.data.handler import Alpha158 +from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.tests.data import GetData + +class RollingDataWorkflow(object): + + MARKET = "csi300" + + start_time = "2010-01-01" + end_time = "2019-12-31" + rolling_cnt = 5 + + def _init_qlib(self): + """initialize qlib""" + # use yahoo_cn_1min data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region=REG_CN) + + def _dump_pre_handler(self, path): + handler_config = { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": { + "start_time": start_time, + "end_time": end_time, + "instruments": MARKET, + }, + } + pre_handler = init_instance_by_config(handler_config) + pre_handler.to_pickle(path) + + def _load_pre_handler(self, path): + with open(path, "rb") as file_dataset: + pre_handler = pickle.load(file_dataset) + return pre_handler + + def rolling_process(self): + self._init_qlib() + self._dump_pre_handler("pre_handler.py") + pre_handler = self._load_pre_handler("pre_handler.py") + + init_start_time = datetime.datetime(2010,1,1) + init_end_time = datetime.datetime(2014,12,31) + init_fit_end_time = datetime.datetime(2012,12,31) + + dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "RollingDataHandler", + "module_path": "rolling_handler", + "kwargs": { + "start_time": init_start_time, + "end_time": init_start_time, + "fit_start_time": init_fit_start_time, + "fit_end_time": init_fit_end_time, + "data_loader_kwargs":{ + "handler_config": pre_handler, + } + }, + }, + "segments": { + "train": (init_start_time, init_fit_end_time), + "valid": (init_start_time, "2013-12-31"), + "test": (init_start_time, init_end_time), + }, + }, + } + + dataset = init_instance_by_config(dataset_config) + + for rolling_offset in range(rolling_cnt): + if rolling_offset: + dataset.init( + handler_kwargs={ + "init_type": DataHandlerLP.IT_FIT_IND, + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "train": ("2010-01-01", "2012-12-31"), + "valid": ("2013-01-01", "2013-12-31"), + "test": ("2014-01-01", "2014-12-31"), + }, + ) + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + market = "csi300" + benchmark = "SH000300" + + ################################### + # train model + ################################### + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + task = { + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, + } + + dataset = init_instance_by_config(task["dataset"]) + diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 050043ba6..f4795c566 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -16,7 +16,7 @@ from ...data import D from ...config import C from ...utils import parse_config, transform_end_date, init_instance_by_config from ...utils.serial import Serializable -from .utils import get_level_index, fetch_df_by_index +from .utils import fetch_df_by_index from pathlib import Path from .loader import DataLoader diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 884d15635..f88aaf05e 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -219,14 +219,14 @@ class StaticDataLoader(DataLoader): self._data.sort_index(inplace=True) -class DataHandlerDL(DataLoader): - """DataHandlerDL - DataHandler-based (D)ata (L)oader +class DataLoaderDH(DataLoader): + """DataLoaderDH + DataLoader based on (D)ata (H)andler It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler """ - def __init__(self, handler_config: dict, fetch_config: dict = {}, is_group=False): + def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): """ Parameters ---------- @@ -243,8 +243,8 @@ class DataHandlerDL(DataLoader): := := DataHandler Instance | DataHandler Config - fetch_config : dict - fetch_config will be used to describe the different arguments of fetch method, such as squeeze, data_key, etc. + fetch_kwargs : dict + fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc. is_group: bool is_group will be used to describe whether the key of handler_config is group @@ -258,7 +258,10 @@ class DataHandlerDL(DataLoader): self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) self.is_group = is_group - self.fetch_config = fetch_config + self.fetch_kwargs = { + "col_set":DataHandler.CS_RAW + } + self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs} def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: @@ -267,11 +270,11 @@ class DataHandlerDL(DataLoader): if self.is_group: df = pd.concat( { - grp: dh.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) for grp, dh in self.handlers.items() }, axis=1, ) else: - df = self.handler.fetch(slice(start_time, end_time), col_set=DataHandler.CS_RAW, **fetch_config) + df = self.handler.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) return df From 834f9bd9b860b3bcbb67d81d2c706797c748db39 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Thu, 25 Mar 2021 16:58:35 +0800 Subject: [PATCH 27/77] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3603818a8..e78ffe751 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu - Rank Label ![Rank Label](docs/_static/img/rank_label.png) --> + - [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results ## Building Customized Quant Research Workflow by Code The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code. From 4861552d281da094e932f3b11feab6bd21728139 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 25 Mar 2021 17:13:52 +0800 Subject: [PATCH 28/77] Update notebook --- examples/workflow_by_code.ipynb | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb index 5a992e339..1dda1c621 100644 --- a/examples/workflow_by_code.ipynb +++ b/examples/workflow_by_code.ipynb @@ -28,11 +28,17 @@ "import sys, site\n", "from pathlib import Path\n", "\n", + "################################# NOTE #################################\n", + "# Please be aware that if colab installs the latest numpy and pyqlib #\n", + "# in this cell, users should RESTART the runtime in order to run the #\n", + "# following cells successfully. #\n", + "########################################################################\n", "\n", "try:\n", " import qlib\n", "except ImportError:\n", " # install qlib\n", + " ! pip install --upgrade numpy\n", " ! pip install pyqlib\n", " # reload\n", " site.main()\n", @@ -238,9 +244,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "from qlib.contrib.report import analysis_model, analysis_position\n", @@ -359,7 +363,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.8.3" }, "toc": { "base_numbering": 1, @@ -377,4 +381,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file From 4ec300787efc87900db522145f43e20d52402bc1 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:54:52 +0800 Subject: [PATCH 29/77] update rolling workflow --- examples/rolling_process_data/workflow.py | 49 +++++++++++++---------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 8581f149b..62523aefd 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -3,8 +3,9 @@ import qlib import pickle -import datetime import pandas as pd + +from datetime import datetime from qlib.config import REG_CN from qlib.data.dataset.handler import DataHandlerLP from qlib.contrib.data.handler import Alpha158 @@ -14,7 +15,6 @@ from qlib.tests.data import GetData class RollingDataWorkflow(object): MARKET = "csi300" - start_time = "2010-01-01" end_time = "2019-12-31" rolling_cnt = 5 @@ -33,9 +33,9 @@ class RollingDataWorkflow(object): "class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": { - "start_time": start_time, - "end_time": end_time, - "instruments": MARKET, + "start_time": self.start_time, + "end_time": self.end_time, + "instruments": self.MARKET, }, } pre_handler = init_instance_by_config(handler_config) @@ -51,10 +51,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - init_start_time = datetime.datetime(2010,1,1) - init_end_time = datetime.datetime(2014,12,31) - init_fit_end_time = datetime.datetime(2012,12,31) - + train_start_time = (2010,1,1) + train_end_time = (2012,12,31) + valid_start_time = (2013,1,1) + valid_end_time = (2013,12,31) + test_start_time = (2014,1,1) + test_end_time = (2014,12,31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -63,19 +66,19 @@ class RollingDataWorkflow(object): "class": "RollingDataHandler", "module_path": "rolling_handler", "kwargs": { - "start_time": init_start_time, - "end_time": init_start_time, - "fit_start_time": init_fit_start_time, - "fit_end_time": init_fit_end_time, + "start_time": datetime(*train_start_time), + "end_time": datetime(*test_end_time), + "fit_start_time": datetime(*train_start_time), + "fit_end_time": datetime(*train_end_time), "data_loader_kwargs":{ "handler_config": pre_handler, } }, }, "segments": { - "train": (init_start_time, init_fit_end_time), - "valid": (init_start_time, "2013-12-31"), - "test": (init_start_time, init_end_time), + "train": (datetime(*train_start_time), datetime(*train_end_time)), + "valid": (datetime(*valid_start_time), datetime(*valid_end_time)), + "test": (datetime(*test_start_time), datetime(*test_end_time)), }, }, } @@ -86,17 +89,19 @@ class RollingDataWorkflow(object): if rolling_offset: dataset.init( handler_kwargs={ - "init_type": DataHandlerLP.IT_FIT_IND, - "start_time": "2021-01-19 00:00:00", - "end_time": "2021-01-25 16:00:00", + "init_type": DataHandlerLP.IT_FIT_SEQ, + "start_time": datetime(train_start_time[0] + 1, *train_start_time[1:]), + "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": ("2010-01-01", "2012-12-31"), - "valid": ("2013-01-01", "2013-12-31"), - "test": ("2014-01-01", "2014-12-31"), + "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), + "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), + "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), }, ) + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": From efe134e9f4f5445055f9c1cd30576bf5f6b42217 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:56:04 +0800 Subject: [PATCH 30/77] update workflow --- examples/rolling_process_data/rolling_handler.py | 8 +++----- examples/rolling_process_data/workflow.py | 2 +- qlib/data/dataset/loader.py | 4 +--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/rolling_process_data/rolling_handler.py b/examples/rolling_process_data/rolling_handler.py index 50a36f219..13b399afd 100644 --- a/examples/rolling_process_data/rolling_handler.py +++ b/examples/rolling_process_data/rolling_handler.py @@ -12,17 +12,15 @@ class RollingDataHandler(DataHandlerLP): learn_processors=[], fit_start_time=None, fit_end_time=None, - data_loader_kwargs={} + data_loader_kwargs={}, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) data_loader = { "class": "DataLoaderDH", - "kwargs": { - **data_loader_kwargs - }, - } + "kwargs": {**data_loader_kwargs}, + } super().__init__( instruments=None, diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 62523aefd..9b61af47e 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -101,7 +101,7 @@ class RollingDataWorkflow(object): ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + if __name__ == "__main__": diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index f88aaf05e..539b930ec 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -258,9 +258,7 @@ class DataLoaderDH(DataLoader): self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) self.is_group = is_group - self.fetch_kwargs = { - "col_set":DataHandler.CS_RAW - } + self.fetch_kwargs = {"col_set": DataHandler.CS_RAW} self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs} def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: From a04c6bd6c941027d1beab07d65be8712d41e2406 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:56:22 +0800 Subject: [PATCH 31/77] balck format --- examples/rolling_process_data/workflow.py | 43 ++++++++++++++--------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 9b61af47e..9dd4285da 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -12,11 +12,12 @@ from qlib.contrib.data.handler import Alpha158 from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData + class RollingDataWorkflow(object): MARKET = "csi300" start_time = "2010-01-01" - end_time = "2019-12-31" + end_time = "2019-12-31" rolling_cnt = 5 def _init_qlib(self): @@ -27,7 +28,7 @@ class RollingDataWorkflow(object): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - + def _dump_pre_handler(self, path): handler_config = { "class": "Alpha158", @@ -51,13 +52,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - train_start_time = (2010,1,1) - train_end_time = (2012,12,31) - valid_start_time = (2013,1,1) - valid_end_time = (2013,12,31) - test_start_time = (2014,1,1) - test_end_time = (2014,12,31) - + train_start_time = (2010, 1, 1) + train_end_time = (2012, 12, 31) + valid_start_time = (2013, 1, 1) + valid_end_time = (2013, 12, 31) + test_start_time = (2014, 1, 1) + test_end_time = (2014, 12, 31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -70,9 +71,9 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), - "data_loader_kwargs":{ + "data_loader_kwargs": { "handler_config": pre_handler, - } + }, }, }, "segments": { @@ -94,14 +95,23 @@ class RollingDataWorkflow(object): "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), - "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), - "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), + "train": ( + datetime(train_start_time[0] + 1, *train_start_time[1:]), + datetime(train_end_time[0], *train_end_time[1:]), + ), + "valid": ( + datetime(valid_start_time[0] + 1, *valid_start_time[1:]), + datetime(valid_end_time[0], *valid_end_time[1:]), + ), + "test": ( + datetime(test_start_time[0] + 1, *test_start_time[1:]), + datetime(test_end_time[0], *test_end_time[1:]), + ), }, ) - dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": @@ -147,4 +157,3 @@ if __name__ == "__main__": } dataset = init_instance_by_config(task["dataset"]) - From 68246b3b6d7037f3134ceb6e59aef869e96f1d8f Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:58:55 +0800 Subject: [PATCH 32/77] update workflow --- examples/rolling_process_data/workflow.py | 87 +++++------------------ 1 file changed, 18 insertions(+), 69 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 9dd4285da..2f48662bd 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import qlib +import fire import pickle import pandas as pd @@ -12,12 +13,11 @@ from qlib.contrib.data.handler import Alpha158 from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData - class RollingDataWorkflow(object): MARKET = "csi300" start_time = "2010-01-01" - end_time = "2019-12-31" + end_time = "2019-12-31" rolling_cnt = 5 def _init_qlib(self): @@ -28,7 +28,7 @@ class RollingDataWorkflow(object): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - + def _dump_pre_handler(self, path): handler_config = { "class": "Alpha158", @@ -52,13 +52,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - train_start_time = (2010, 1, 1) - train_end_time = (2012, 12, 31) - valid_start_time = (2013, 1, 1) - valid_end_time = (2013, 12, 31) - test_start_time = (2014, 1, 1) - test_end_time = (2014, 12, 31) - + train_start_time = (2010,1,1) + train_end_time = (2012,12,31) + valid_start_time = (2013,1,1) + valid_end_time = (2013,12,31) + test_start_time = (2014,1,1) + test_end_time = (2014,12,31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -71,9 +71,9 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), - "data_loader_kwargs": { + "data_loader_kwargs":{ "handler_config": pre_handler, - }, + } }, }, "segments": { @@ -95,65 +95,14 @@ class RollingDataWorkflow(object): "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": ( - datetime(train_start_time[0] + 1, *train_start_time[1:]), - datetime(train_end_time[0], *train_end_time[1:]), - ), - "valid": ( - datetime(valid_start_time[0] + 1, *valid_start_time[1:]), - datetime(valid_end_time[0], *valid_end_time[1:]), - ), - "test": ( - datetime(test_start_time[0] + 1, *test_start_time[1:]), - datetime(test_end_time[0], *test_end_time[1:]), - ), + "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), + "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), + "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), }, ) - dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": - - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) - - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - - dataset = init_instance_by_config(task["dataset"]) + fire.Fire(RollingDataWorkflow) From e119c8576c78f7729364358ce1a3515ca682177a Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 19:59:22 +0800 Subject: [PATCH 33/77] black format --- examples/rolling_process_data/workflow.py | 42 ++++++++++++++--------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 2f48662bd..d5f7fec10 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -13,11 +13,12 @@ from qlib.contrib.data.handler import Alpha158 from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData + class RollingDataWorkflow(object): MARKET = "csi300" start_time = "2010-01-01" - end_time = "2019-12-31" + end_time = "2019-12-31" rolling_cnt = 5 def _init_qlib(self): @@ -28,7 +29,7 @@ class RollingDataWorkflow(object): print(f"Qlib data is not found in {provider_uri}") GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - + def _dump_pre_handler(self, path): handler_config = { "class": "Alpha158", @@ -52,13 +53,13 @@ class RollingDataWorkflow(object): self._dump_pre_handler("pre_handler.py") pre_handler = self._load_pre_handler("pre_handler.py") - train_start_time = (2010,1,1) - train_end_time = (2012,12,31) - valid_start_time = (2013,1,1) - valid_end_time = (2013,12,31) - test_start_time = (2014,1,1) - test_end_time = (2014,12,31) - + train_start_time = (2010, 1, 1) + train_end_time = (2012, 12, 31) + valid_start_time = (2013, 1, 1) + valid_end_time = (2013, 12, 31) + test_start_time = (2014, 1, 1) + test_end_time = (2014, 12, 31) + dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", @@ -71,9 +72,9 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), - "data_loader_kwargs":{ + "data_loader_kwargs": { "handler_config": pre_handler, - } + }, }, }, "segments": { @@ -95,14 +96,23 @@ class RollingDataWorkflow(object): "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), }, segment_kwargs={ - "train": (datetime(train_start_time[0] + 1, *train_start_time[1:]), datetime(train_end_time[0], *train_end_time[1:])), - "valid": (datetime(valid_start_time[0] + 1, *valid_start_time[1:]), datetime(valid_end_time[0], *valid_end_time[1:])), - "test": (datetime(test_start_time[0] + 1, *test_start_time[1:]), datetime(test_end_time[0], *test_end_time[1:])), + "train": ( + datetime(train_start_time[0] + 1, *train_start_time[1:]), + datetime(train_end_time[0], *train_end_time[1:]), + ), + "valid": ( + datetime(valid_start_time[0] + 1, *valid_start_time[1:]), + datetime(valid_end_time[0], *valid_end_time[1:]), + ), + "test": ( + datetime(test_start_time[0] + 1, *test_start_time[1:]), + datetime(test_end_time[0], *test_end_time[1:]), + ), }, ) - dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) - + dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + if __name__ == "__main__": fire.Fire(RollingDataWorkflow) From 56eaacd931bf409c0f1719518296d99d11dd6330 Mon Sep 17 00:00:00 2001 From: LewenWang Date: Thu, 25 Mar 2021 20:34:45 +0800 Subject: [PATCH 34/77] debug --- qlib/contrib/model/pytorch_gru_ts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 2839b35e4..de5e280d0 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -126,8 +126,8 @@ class GRU(Model): num_layers=self.num_layers, dropout=self.dropout, ) - self.logger.info("model:\n{:}".format(self.gru_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model))) + self.logger.info("model:\n{:}".format(self.GRU_model)) + self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model))) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr) From 9cc3b18e4e9cd61f7745271a01d628063b1b48a3 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 20:36:07 +0800 Subject: [PATCH 35/77] fix but --- examples/rolling_process_data/README.md | 1 - examples/rolling_process_data/workflow.py | 19 ++++++++++++++++--- qlib/data/dataset/loader.py | 6 ++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md index 3f1c8768d..6a6af0d3d 100644 --- a/examples/rolling_process_data/README.md +++ b/examples/rolling_process_data/README.md @@ -1,2 +1 @@ # Rolling Process Data - diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index d5f7fec10..29b1c19f8 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -38,9 +38,12 @@ class RollingDataWorkflow(object): "start_time": self.start_time, "end_time": self.end_time, "instruments": self.MARKET, + "infer_processors": [], + "learn_processors": [], }, } pre_handler = init_instance_by_config(handler_config) + pre_handler.config(dump_all=True) pre_handler.to_pickle(path) def _load_pre_handler(self, path): @@ -50,8 +53,8 @@ class RollingDataWorkflow(object): def rolling_process(self): self._init_qlib() - self._dump_pre_handler("pre_handler.py") - pre_handler = self._load_pre_handler("pre_handler.py") + self._dump_pre_handler("pre_handler.pkl") + pre_handler = self._load_pre_handler("pre_handler.pkl") train_start_time = (2010, 1, 1) train_end_time = (2012, 12, 31) @@ -72,6 +75,13 @@ class RollingDataWorkflow(object): "end_time": datetime(*test_end_time), "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), + "infer_processors": [ + {"class":"RobustZScoreNorm", "kwargs": {"fields_group": "feature"}}, + ], + "learn_processors": [ + {"class": "DropnaLabel"}, + {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + ], "data_loader_kwargs": { "handler_config": pre_handler, }, @@ -87,7 +97,8 @@ class RollingDataWorkflow(object): dataset = init_instance_by_config(dataset_config) - for rolling_offset in range(rolling_cnt): + for rolling_offset in range(self.rolling_cnt): + print(f"===========rolling{rolling_offset} start===========") if rolling_offset: dataset.init( handler_kwargs={ @@ -112,6 +123,8 @@ class RollingDataWorkflow(object): ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + ## print or dump data + print(f"===========rolling{rolling_offset} end===========") if __name__ == "__main__": diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 539b930ec..1cda5c025 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -250,7 +250,9 @@ class DataLoaderDH(DataLoader): is_group will be used to describe whether the key of handler_config is group """ - if self.is_group: + from qlib.data.dataset.handler import DataHandler + + if is_group: self.handlers = { grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items() } @@ -274,5 +276,5 @@ class DataLoaderDH(DataLoader): axis=1, ) else: - df = self.handler.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) + df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) return df From d6ff764bb270017b74099205dcfb78ade161a9e7 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 20:36:45 +0800 Subject: [PATCH 36/77] black format --- examples/rolling_process_data/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 29b1c19f8..3b38faa31 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -76,7 +76,7 @@ class RollingDataWorkflow(object): "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), "infer_processors": [ - {"class":"RobustZScoreNorm", "kwargs": {"fields_group": "feature"}}, + {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ {"class": "DropnaLabel"}, From 194217fb07696530d5b575567c5bb664d479948d Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 21:47:17 +0800 Subject: [PATCH 37/77] fix bug --- examples/rolling_process_data/workflow.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 3b38faa31..719d93a1b 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -103,21 +103,21 @@ class RollingDataWorkflow(object): dataset.init( handler_kwargs={ "init_type": DataHandlerLP.IT_FIT_SEQ, - "start_time": datetime(train_start_time[0] + 1, *train_start_time[1:]), - "end_time": datetime(test_end_time[0] + 1, *test_end_time[1:]), + "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), + "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), }, segment_kwargs={ "train": ( - datetime(train_start_time[0] + 1, *train_start_time[1:]), - datetime(train_end_time[0], *train_end_time[1:]), + datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), + datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), ), "valid": ( - datetime(valid_start_time[0] + 1, *valid_start_time[1:]), - datetime(valid_end_time[0], *valid_end_time[1:]), + datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]), + datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]), ), "test": ( - datetime(test_start_time[0] + 1, *test_start_time[1:]), - datetime(test_end_time[0], *test_end_time[1:]), + datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]), + datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), ), }, ) From 5f60d18dfe2fa71d341ee7e8128f0f4c1f79c119 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 22:08:23 +0800 Subject: [PATCH 38/77] fix config_data bug --- examples/rolling_process_data/workflow.py | 4 ++++ qlib/data/dataset/__init__.py | 2 +- qlib/data/dataset/handler.py | 28 ++++++++++++++++++++--- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 719d93a1b..0be88dddc 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -98,6 +98,7 @@ class RollingDataWorkflow(object): dataset = init_instance_by_config(dataset_config) for rolling_offset in range(self.rolling_cnt): + print(f"===========rolling{rolling_offset} start===========") if rolling_offset: dataset.init( @@ -105,6 +106,8 @@ class RollingDataWorkflow(object): "init_type": DataHandlerLP.IT_FIT_SEQ, "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), + "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), + "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), }, segment_kwargs={ "train": ( @@ -123,6 +126,7 @@ class RollingDataWorkflow(object): ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) + print(dtrain, dvalid, dtest) ## print or dump data print(f"===========rolling{rolling_offset} end===========") diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 0f5d2baba..518b8eecd 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -98,7 +98,7 @@ class DatasetH(Dataset): raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}") kwargs_init = {} kwargs_conf_data = {} - conf_data_arg = {"instruments", "start_time", "end_time"} + conf_data_arg = {"instruments", "start_time", "end_time", "fit_start_time", "fit_end_time"} for k, v in handler_kwargs.items(): if k in conf_data_arg: kwargs_conf_data.update({k: v}) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index f4795c566..40db5e4f3 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -115,8 +115,7 @@ class DataHandler(Serializable): for k, v in kwargs.items(): if k in attr_list: setattr(self, k, v) - else: - raise KeyError("Such config is not supported.") + def init(self, enable_cache: bool = False): """ @@ -405,11 +404,34 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data + + def conf_data(self, **kwargs): + """ + configuration of data. + # what data to be loaded from data source + + This method will be used when loading pickled handler from dataset. + The data will be initialized with different time range. + + """ + attr_list = {"fit_start_time", "fit_end_time"} + for k, v in kwargs.items(): + if k in attr_list: + for infer_processor in self.infer_processors: + if getattr(infer_processor, k, None): + setattr(infer_processor, k, v) + + for learn_processor in self.learn_processors: + if getattr(learn_processor, k, None): + setattr(learn_processor, k, v) + + super().conf_data(**kwargs) + # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle - + def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): """ Initialize the data of Qlib From 4ee0240c2483383a28099d97e5688bce8ea030b1 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 25 Mar 2021 22:08:39 +0800 Subject: [PATCH 39/77] black format --- qlib/data/dataset/handler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 40db5e4f3..9aa05b9b9 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -116,7 +116,6 @@ class DataHandler(Serializable): if k in attr_list: setattr(self, k, v) - def init(self, enable_cache: bool = False): """ initialize the data. @@ -404,7 +403,6 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data - def conf_data(self, **kwargs): """ configuration of data. @@ -431,7 +429,7 @@ class DataHandlerLP(DataHandler): IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle - + def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): """ Initialize the data of Qlib From 9d04ae467618505d293df9bb0fa2f20004a6e00c Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 00:33:59 -0700 Subject: [PATCH 40/77] Add MultiSegRecord in contrib.workflow and decouple its tests from test_all_pipeline --- qlib/contrib/workflow/__init__.py | 4 ++ qlib/contrib/workflow/record_temp.py | 29 +++++++++ qlib/workflow/exp.py | 10 ++- qlib/workflow/expm.py | 10 ++- qlib/workflow/record_temp.py | 23 +++++-- tests/test_all_pipeline.py | 25 ++----- tests/test_contrib_workflow.py | 97 ++++++++++++++++++++++++++++ 7 files changed, 171 insertions(+), 27 deletions(-) create mode 100644 tests/test_contrib_workflow.py diff --git a/qlib/contrib/workflow/__init__.py b/qlib/contrib/workflow/__init__.py index e69de29bb..9945e179c 100644 --- a/qlib/contrib/workflow/__init__.py +++ b/qlib/contrib/workflow/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .record_temp import MultiSegRecord +from .record_temp import SignalMseRecord diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 3fdf0c281..4baa15faa 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -5,14 +5,43 @@ import re import pandas as pd from sklearn.metrics import mean_squared_error from pprint import pprint +from typing import Dict, Text, Any import numpy as np +from ...workflow.record_temp import RecordTemp from ...workflow.record_temp import SignalRecord +from ...data import dataset as qlib_dataset from ...log import get_module_logger logger = get_module_logger("workflow", "INFO") +class MultiSegRecord(RecordTemp): + """ + This is the multiple segments signal record class that generates the signal prediction. + This class inherits the ``RecordTemp`` class. + """ + + def __init__(self, model, dataset, recorder=None): + super().__init__(recorder=recorder) + if not isinstance(dataset, qlib_dataset.DatasetH): + raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset))) + self.model = model + self.dataset = dataset + + def generate(self, segments: Dict[Text, Any], save: bool = False): + # generate prediciton + for key, segment in segments.items(): + predics = self.model.predict(self.dataset, segment) + if isinstance(pred, pd.Series): + predics = predictions.to_frame("score") + # self.recorder.save_objects(**{"pred.pkl": pred}) + labels = self.dataset.prepare( + segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R + ) + # compute ic, rank_ic + + class SignalMseRecord(SignalRecord): """ This is the Signal MSE Record class that computes the mean squared error (MSE). diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 5ed4362de..0f420cec4 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -159,7 +159,10 @@ class Experiment: if create: recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) else: - recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False + recorder, is_new = ( + self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + False, + ) if is_new: self.active_recorder = recorder # start the recorder @@ -174,7 +177,10 @@ class Experiment: try: if recorder_id is None and recorder_name is None: recorder_name = self._default_rec_name - return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False + return ( + self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + False, + ) except ValueError: if recorder_name is None: recorder_name = self._default_rec_name diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 95cad4c6e..28d6d92c7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -159,7 +159,10 @@ class ExpManager: if create: exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) else: - exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + exp, is_new = ( + self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + False, + ) if is_new: self.active_experiment = exp # start the recorder @@ -172,7 +175,10 @@ class ExpManager: automatically create a new experiment based on the given id and name. """ try: - return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False + return ( + self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + False, + ) except ValueError: if experiment_name is None: experiment_name = self._default_exp_name diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..ed8039ac8 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -39,7 +39,13 @@ class RecordTemp: return "/".join(names) def __init__(self, recorder): - self.recorder = recorder + self._recorder = recorder + + @property + def recorder(self): + if self._recorder is None: + raise ValueError("This RecordTemp did not set recorder yet.") + return self._recorder def generate(self, **kwargs): """ @@ -248,11 +254,20 @@ class PortAnaRecord(SignalRecord): report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) report_normal = report_dict.get("report_df") positions_normal = report_dict.get("positions") - self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) - self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects( + **{"report_normal.pkl": report_normal}, + artifact_path=PortAnaRecord.get_path(), + ) + self.recorder.save_objects( + **{"positions_normal.pkl": positions_normal}, + artifact_path=PortAnaRecord.get_path(), + ) order_normal = report_dict.get("order_list") if order_normal: - self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects( + **{"order_normal.pkl": order_normal}, + artifact_path=PortAnaRecord.get_path(), + ) # analysis analysis = dict() diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 29d39179d..d34c1773a 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -6,24 +6,11 @@ import shutil import unittest from pathlib import Path -import numpy as np -import pandas as pd - import qlib -from qlib.config import REG_CN, C -from qlib.utils import drop_nan_by_y_index -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.contrib.workflow.record_temp import SignalMseRecord -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.config import C +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord -from qlib.tests.data import GetData from qlib.tests import TestAutoData @@ -166,8 +153,6 @@ def train_with_sigana(): ric = sar.load(sar.get_path("ric.pkl")) pred_score = sar.load("pred.pkl") - smr = SignalMseRecord(recorder) - smr.generate() uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path @@ -256,8 +241,10 @@ class TestAllFlow(TestAutoData): def suite(): _suite = unittest.TestSuite() - _suite.addTest(TestAllFlow("test_0_train")) - _suite.addTest(TestAllFlow("test_1_backtest")) + _suite.addTest(TestAllFlow("test_0_train_with_sigana")) + _suite.addTest(TestAllFlow("test_1_train")) + _suite.addTest(TestAllFlow("test_2_backtest")) + _suite.addTest(TestAllFlow("test_3_expmanager")) return _suite diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py new file mode 100644 index 000000000..92ed7e8d1 --- /dev/null +++ b/tests/test_contrib_workflow.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import shutil +import unittest +from pathlib import Path + +import qlib +from qlib.config import C +from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord +from qlib.utils import init_instance_by_config, flatten_dict +from qlib.workflow import R +from qlib.tests import TestAutoData + + +market = "csi300" +benchmark = "SH000300" + +################################### +# train model +################################### +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, +} + +task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, +} + + +def test_multiseg(): + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + with R.start(experiment_name="workflow"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + + # prediction + recorder = R.get_recorder() + sr = MultiSegRecord(model, dataset, recorder) + sr.generate(dict(valid="valid", test="test")) + + uri = R.get_uri() + + return uri + + +class TestAllFlow(TestAutoData): + def test_0_multiseg(self): + uri_path = test_multiseg() + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + + +def suite(): + _suite = unittest.TestSuite() + _suite.addTest(TestAllFlow("test_0_multiseg")) + return _suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite()) From 8a2e7b62af087f41792b84bc1e0dd2d9a1ee26cf Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 08:30:16 +0000 Subject: [PATCH 41/77] Add segment args for pred and refine MultiSegRecord --- qlib/contrib/model/gbdt.py | 4 ++-- qlib/contrib/model/linear.py | 4 ++-- qlib/contrib/model/xgboost.py | 4 ++-- qlib/contrib/workflow/record_temp.py | 30 +++++++++++++++++++--------- tests/test_contrib_workflow.py | 26 ++++++++++++++++++------ 5 files changed, 47 insertions(+), 21 deletions(-) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 058d9a0e3..e4ac48ed6 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -61,10 +61,10 @@ class LGBModel(ModelFT): evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] - def predict(self, dataset): + def predict(self, dataset, segment="test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(x_test.values), index=x_test.index) def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): diff --git a/qlib/contrib/model/linear.py b/qlib/contrib/model/linear.py index 0f9223737..269e788c5 100644 --- a/qlib/contrib/model/linear.py +++ b/qlib/contrib/model/linear.py @@ -84,8 +84,8 @@ class LinearModel(Model): self.coef_ = coef self.intercept_ = 0.0 - def predict(self, dataset): + def predict(self, dataset, segment="test"): if self.coef_ is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index) diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index ba2e5789b..6bfd2c799 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -57,8 +57,8 @@ class XGBModel(Model): evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] - def predict(self, dataset): + def predict(self, dataset, segment="test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature") return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 4baa15faa..863daee85 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import re import pandas as pd from sklearn.metrics import mean_squared_error -from pprint import pprint from typing import Dict, Text, Any import numpy as np +from ...contrib.eva.alpha import calc_ic from ...workflow.record_temp import RecordTemp from ...workflow.record_temp import SignalRecord from ...data import dataset as qlib_dataset @@ -30,16 +29,29 @@ class MultiSegRecord(RecordTemp): self.dataset = dataset def generate(self, segments: Dict[Text, Any], save: bool = False): - # generate prediciton for key, segment in segments.items(): predics = self.model.predict(self.dataset, segment) - if isinstance(pred, pd.Series): - predics = predictions.to_frame("score") - # self.recorder.save_objects(**{"pred.pkl": pred}) + if isinstance(predics, pd.Series): + predics = predics.to_frame("score") labels = self.dataset.prepare( - segments=segment, col_set="label", data_key=dataset.handler.DataHandlerLP.DK_R + segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R ) - # compute ic, rank_ic + # Compute the IC and Rank IC + ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0]) + results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()} + logger.info("--- Results for {:} ({:}) ---".format(key, segment)) + ic_x100, ric_x100 = ic * 100, ric * 100 + logger.info("IC: {:.4f}%".format(ic_x100.mean())) + logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std())) + logger.info("Rank IC: {:.4f}%".format(ric_x100.mean())) + logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std())) + + if save: + save_name = "results-{:}.pkl".format(key) + self.recorder.save_objects(**{save_name: results}) + logger.info( + "The record '{save_name}' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" + ) class SignalMseRecord(SignalRecord): @@ -67,7 +79,7 @@ class SignalMseRecord(SignalRecord): objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) self.recorder.save_objects(**objects, artifact_path=self.get_path()) - pprint(metrics) + logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics)) def list(self): paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index 92ed7e8d1..ccd3c6a90 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -63,32 +63,46 @@ task = { } -def test_multiseg(): +def train_multiseg(): model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) with R.start(experiment_name="workflow"): R.log_params(**flatten_dict(task)) model.fit(dataset) - - # prediction recorder = R.get_recorder() sr = MultiSegRecord(model, dataset, recorder) - sr.generate(dict(valid="valid", test="test")) - + sr.generate(dict(valid="valid", test="test"), True) uri = R.get_uri() + return uri + +def train_mse(): + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + with R.start(experiment_name="workflow"): + R.log_params(**flatten_dict(task)) + model.fit(dataset) + recorder = R.get_recorder() + sr = SignalMseRecord(recorder, model=model, dataset=dataset) + sr.generate() + uri = R.get_uri() return uri class TestAllFlow(TestAutoData): def test_0_multiseg(self): - uri_path = test_multiseg() + uri_path = train_multiseg() + shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) + + def test_1_mse(self): + uri_path = train_mse() shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) def suite(): _suite = unittest.TestSuite() _suite.addTest(TestAllFlow("test_0_multiseg")) + _suite.addTest(TestAllFlow("test_1_mse")) return _suite From 0386df7b16ce4480687a49af07a3a2fac3a0caad Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 10:39:28 +0000 Subject: [PATCH 42/77] Collect all contrib models in __init__ and add unit tests for init --- qlib/contrib/model/__init__.py | 39 ++++++++++++++++++++++++++ qlib/contrib/model/catboost_model.py | 5 ++-- qlib/contrib/model/double_ensemble.py | 10 +++++-- qlib/contrib/model/gbdt.py | 4 +-- qlib/contrib/model/linear.py | 4 +-- qlib/contrib/model/pytorch_alstm.py | 6 ++-- qlib/contrib/model/pytorch_alstm_ts.py | 5 ++-- qlib/contrib/model/pytorch_gats.py | 6 ++-- qlib/contrib/model/pytorch_gru.py | 5 ++-- qlib/contrib/model/pytorch_lstm.py | 9 ++---- qlib/contrib/model/pytorch_nn.py | 10 +++---- qlib/contrib/model/pytorch_sfm.py | 12 +++----- qlib/contrib/model/pytorch_tabnet.py | 7 ++--- qlib/contrib/model/xgboost.py | 6 ++-- qlib/data/dataset/processor.py | 0 qlib/model/base.py | 6 +++- tests/test_contrib_model.py | 27 ++++++++++++++++++ 17 files changed, 115 insertions(+), 46 deletions(-) mode change 100755 => 100644 qlib/data/dataset/processor.py create mode 100644 tests/test_contrib_model.py diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py index e69de29bb..09b0c929b 100644 --- a/qlib/contrib/model/__init__.py +++ b/qlib/contrib/model/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +try: + from .catboost_model import CatBoostModel +except ModuleNotFoundError: + CatBoostModel = None + print("Please install necessary libs for CatBoostModel.") +try: + from .double_ensemble import DEnsembleModel + from .gbdt import LGBModel +except ModuleNotFoundError: + DEnsembleModel, LGBModel = None, None + print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.") +try: + from .xgboost import XGBModel +except ModuleNotFoundError: + XGBModel = None + print("Please install necessary libs for XGBModel, such as xgboost.") +try: + from .linear import LinearModel +except ModuleNotFoundError: + LinearModel = None + print("Please install necessary libs for LinearModel, such as scipy and sklearn.") +# import pytorch models +try: + from .pytorch_alstm import ALSTM + from .pytorch_gats import GATs + from .pytorch_gru import GRU + from .pytorch_lstm import LSTM + from .pytorch_nn import DNNModelPytorch + from .pytorch_tabnet import TabnetModel + from .pytorch_sfm import SFM_Model + + pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model) +except ModuleNotFoundError: + pytorch_classes = () + print("Please install necessary libs for PyTorch models.") + +all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index d57c32b70..98b9b9c2d 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +from typing import Text, Union from catboost import Pool, CatBoost from catboost.utils import get_gpu_device_count @@ -62,10 +63,10 @@ class CatBoostModel(Model): evals_result["train"] = list(evals_result["learn"].values())[0] evals_result["valid"] = list(evals_result["validation"].values())[0] - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(x_test.values), index=x_test.index) diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index 541f74e99..4b267a2b0 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -4,7 +4,7 @@ import lightgbm as lgb import numpy as np import pandas as pd - +from typing import Text, Union from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -40,6 +40,10 @@ class DEnsembleModel(Model): self.bins_sr = bins_sr self.bins_fs = bins_fs self.decay = decay + if sample_ratios is None: # the default values for sample_ratios + sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] + if sub_weights is None: # the default values for sub_weights + sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] if not len(sample_ratios) == bins_fs: raise ValueError("The length of sample_ratios should be equal to bins_fs.") self.sample_ratios = sample_ratios @@ -228,10 +232,10 @@ class DEnsembleModel(Model): raise ValueError("not implemented yet") return loss_curve - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.ensemble is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) for i_sub, submodel in enumerate(self.ensemble): feat_sub = self.sub_features[i_sub] diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index e4ac48ed6..463cf8f4f 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import lightgbm as lgb - +from typing import Text, Union from ...model.base import ModelFT from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -61,7 +61,7 @@ class LGBModel(ModelFT): evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] - def predict(self, dataset, segment="test"): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/linear.py b/qlib/contrib/model/linear.py index 269e788c5..f16acc1ec 100644 --- a/qlib/contrib/model/linear.py +++ b/qlib/contrib/model/linear.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd - +from typing import Text, Union from scipy.optimize import nnls from sklearn.linear_model import LinearRegression, Ridge, Lasso @@ -84,7 +84,7 @@ class LinearModel(Model): self.coef_ = coef self.intercept_ = 0.0 - def predict(self, dataset, segment="test"): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.coef_ is None: raise ValueError("model is not fitted yet!") x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index a149272da..ed706be86 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -8,9 +8,9 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( - unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index, @@ -273,11 +273,11 @@ class ALSTM(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.ALSTM_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index c38727b9e..3cd7ec280 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( unpack_archive_with_buffer, @@ -264,11 +265,11 @@ class ALSTM(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) self.ALSTM_model.eval() diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 53afb5404..71edda76e 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( unpack_archive_with_buffer, @@ -83,7 +84,6 @@ class GATs(Model): self.with_pretrain = with_pretrain self.model_path = model_path self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") - self.use_gpu = torch.cuda.is_available() self.seed = seed self.logger.info( @@ -310,11 +310,11 @@ class GATs(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature") index = x_test.index self.GAT_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 5eba33595..da2161653 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( unpack_archive_with_buffer, @@ -273,11 +274,11 @@ class GRU(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.gru_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 636ef6e3a..bafd83ea6 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( unpack_archive_with_buffer, @@ -268,11 +269,11 @@ class LSTM(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.lstm_model.eval() x_values = x_test.values @@ -280,17 +281,13 @@ class LSTM(Model): preds = [] for begin in range(sample_num)[:: self.batch_size]: - if sample_num - begin < self.batch_size: end = sample_num else: end = begin + self.batch_size - x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) - with torch.no_grad(): pred = self.lstm_model(x_batch).detach().cpu().numpy() - preds.append(pred) return pd.Series(np.concatenate(preds), index=index) diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index caf34b330..4dc02cc0a 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -8,6 +8,7 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union from sklearn.metrics import roc_auc_score, mean_squared_error import torch @@ -48,8 +49,8 @@ class DNNModelPytorch(Model): def __init__( self, - input_dim, - output_dim, + input_dim=360, + output_dim=1, layers=(256,), lr=0.001, max_steps=300, @@ -271,13 +272,12 @@ class DNNModelPytorch(Model): else: raise NotImplementedError("loss {} is not supported!".format(loss_type)) - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test_pd = dataset.prepare("test", col_set="feature") + x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) x_test = torch.from_numpy(x_test_pd.values).float().to(self.device) self.dnn_model.eval() - with torch.no_grad(): preds = self.dnn_model(x_test).detach().cpu().numpy() return pd.Series(np.squeeze(preds), index=x_test_pd.index) diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index db3e8bb12..4eb89bdda 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -7,10 +7,9 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index, ) @@ -442,11 +441,11 @@ class SFM(Model): raise ValueError("unknown metric `%s`" % self.metric) - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.sfm_model.eval() x_values = x_test.values @@ -459,10 +458,7 @@ class SFM(Model): else: end = begin + self.batch_size - x_batch = torch.from_numpy(x_values[begin:end]).float() - - if self.device != "cpu": - x_batch = x_batch.to(self.device) + x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) with torch.no_grad(): pred = self.sfm_model(x_batch).detach().cpu().numpy() diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 450e6f5d1..b772b60d9 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -6,10 +6,9 @@ from __future__ import print_function import os import numpy as np import pandas as pd +from typing import Text, Union import copy from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index, ) @@ -217,11 +216,11 @@ class TabnetModel(Model): if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.tabnet_model.eval() x_values = torch.from_numpy(x_test.values) diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index 6bfd2c799..cbba14678 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import xgboost as xgb - +from typing import Text, Union from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP @@ -57,8 +57,8 @@ class XGBModel(Model): evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] - def predict(self, dataset, segment="test"): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index) diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py old mode 100755 new mode 100644 diff --git a/qlib/model/base.py b/qlib/model/base.py index 3708298d5..1ac8f2fc9 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import abc +from typing import Text, Union from ..utils.serial import Serializable from ..data.dataset import Dataset @@ -59,7 +60,7 @@ class Model(BaseModel): raise NotImplementedError() @abc.abstractmethod - def predict(self, dataset: Dataset) -> object: + def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object: """give prediction given Dataset Parameters @@ -67,6 +68,9 @@ class Model(BaseModel): dataset : Dataset dataset will generate the processed dataset from model training. + segment : Text or slice + dataset will use this segment to prepare data. (default=test) + Returns ------- Prediction results with certain type such as `pandas.Series`. diff --git a/tests/test_contrib_model.py b/tests/test_contrib_model.py new file mode 100644 index 000000000..a82a3042e --- /dev/null +++ b/tests/test_contrib_model.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +from qlib.contrib.model import all_model_classes + + +class TestAllFlow(unittest.TestCase): + def test_0_initialize(self): + num = 0 + for model_class in all_model_classes: + if model_class is not None: + model = model_class() + num += 1 + print("There are {:}/{:} valid models in total.".format(num, len(all_model_classes))) + + +def suite(): + _suite = unittest.TestSuite() + _suite.addTest(TestAllFlow("test_0_initialize")) + return _suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite()) From f809f0a0636ca7baeb8e7e98c5a8b387096e7217 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 10:50:25 +0000 Subject: [PATCH 43/77] Remove un-used imports --- qlib/contrib/model/pytorch_alstm.py | 6 +----- qlib/contrib/model/pytorch_alstm_ts.py | 7 +------ qlib/contrib/model/pytorch_gats.py | 7 +------ qlib/contrib/model/pytorch_gats_ts.py | 7 +------ qlib/contrib/model/pytorch_gru.py | 7 +------ qlib/contrib/model/pytorch_gru_ts.py | 7 +------ qlib/contrib/model/pytorch_lstm.py | 7 +------ qlib/contrib/model/pytorch_lstm_ts.py | 7 +------ qlib/contrib/model/pytorch_nn.py | 2 +- qlib/contrib/model/pytorch_sfm.py | 5 +---- qlib/contrib/model/pytorch_tabnet.py | 5 +---- 11 files changed, 11 insertions(+), 56 deletions(-) diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index ed706be86..4fe2b2714 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -10,11 +10,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 3cd7ec280..f1aa8227c 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -10,12 +10,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 71edda76e..493bf120f 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -10,12 +10,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch import torch.nn as nn diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index f02bf1e47..5f9961b0b 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -9,12 +9,7 @@ import os import numpy as np import pandas as pd import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch import torch.nn as nn diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index da2161653..552815d39 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -10,12 +10,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index de5e280d0..c094a3e3c 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -9,12 +9,7 @@ import os import numpy as np import pandas as pd import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index bafd83ea6..0ecfc2083 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -10,12 +10,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index c978e84c7..1f97bd5b1 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -9,12 +9,7 @@ import os import numpy as np import pandas as pd import copy -from ...utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 4dc02cc0a..15ee7ef71 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -19,7 +19,7 @@ from .pytorch_utils import count_parameters from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP -from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index +from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path from ...log import get_module_logger from ...workflow import R diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index 4eb89bdda..cf65c2662 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -9,10 +9,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b772b60d9..b05d9a026 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -8,10 +8,7 @@ import numpy as np import pandas as pd from typing import Text, Union import copy -from ...utils import ( - get_or_create_path, - drop_nan_by_y_index, -) +from ...utils import get_or_create_path from ...log import get_module_logger import torch From 4b663049781ff9bc022a5e095772888965d27c91 Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 29 Mar 2021 11:18:33 +0800 Subject: [PATCH 44/77] Fix us_index collector --- scripts/data_collector/index.py | 7 ++++++- scripts/data_collector/us_index/collector.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py index 300e6b625..82a230e37 100644 --- a/scripts/data_collector/index.py +++ b/scripts/data_collector/index.py @@ -114,6 +114,8 @@ class IndexBase: $ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data """ df = self.get_new_companies() + if df is None or df.empty: + raise ValueError(f"get new companies error: {self.index_name}") df = df.drop_duplicates([self.SYMBOL_FIELD_NAME]) df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv( self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None @@ -184,7 +186,10 @@ class IndexBase: logger.info(f"start parse {self.index_name.lower()} companies.....") instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD] changers_df = self.get_changes() - new_df = self.get_new_companies().copy() + new_df = self.get_new_companies() + if new_df is None or new_df.empty: + raise ValueError(f"get new companies error: {self.index_name}") + new_df = new_df.copy() logger.info("parse history companies by changes......") for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)): if _row.type == self.ADD: diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py index 0641437e0..371668330 100644 --- a/scripts/data_collector/us_index/collector.py +++ b/scripts/data_collector/us_index/collector.py @@ -35,7 +35,7 @@ WIKI_INDEX_NAME_MAP = { class WIKIIndex(IndexBase): # NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix # https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows - INST_PREFIX = "_" + INST_PREFIX = "" def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): super(WIKIIndex, self).__init__( @@ -123,7 +123,7 @@ class NASDAQ100Index(WIKIIndex): MAX_WORKERS = 16 def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: - if not (set(df.columns) - {"Company", "Ticker"}): + if len(df) >= 100 and "Ticker" in df.columns: return df.loc[:, ["Ticker"]].copy() @property From 968930e85f4958d16dfc2c5740c02f5c91745b97 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 29 Mar 2021 04:46:38 +0000 Subject: [PATCH 45/77] Fix print issue --- qlib/contrib/workflow/record_temp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 863daee85..12792fbcb 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -50,7 +50,9 @@ class MultiSegRecord(RecordTemp): save_name = "results-{:}.pkl".format(key) self.recorder.save_objects(**{save_name: results}) logger.info( - "The record '{save_name}' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" + "The record '{:}' has been saved as the artifact of the Experiment {:}".format( + save_name, self.recorder.experiment_id + ) ) From 31bc85bf867ba2161c638b819b41e3cb7e863ce1 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 29 Mar 2021 19:49:30 +0800 Subject: [PATCH 46/77] restructure data layer config & setup --- examples/highfreq/highfreq_processor.py | 7 ++ examples/highfreq/workflow.py | 33 ++++--- qlib/data/dataset/__init__.py | 116 ++++++++++++++---------- qlib/data/dataset/handler.py | 46 +++++----- qlib/data/dataset/loader.py | 1 - qlib/data/dataset/processor.py | 22 +++++ 6 files changed, 138 insertions(+), 87 deletions(-) diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index f0ab0dec2..4ec8f3dd2 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -70,3 +70,10 @@ class HighFreqNorm(Processor): columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features + + def config(fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index c2ca36db3..0b48b971f 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -31,7 +31,7 @@ class HighfreqWorkflow(object): SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} - MARKET = "all" + MARKET = "csi300" start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" @@ -145,35 +145,40 @@ class HighfreqWorkflow(object): self._prepare_calender_cache() ##=============reinit dataset============= - dataset.init( + dataset.config( + handler_kwargs={ + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segments={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, + ) + dataset.setup_data( handler_kwargs={ "init_type": DataHandlerLP.IT_LS, - "start_time": "2021-01-19 00:00:00", - "end_time": "2021-01-25 16:00:00", - }, - segment_kwargs={ - "test": ( - "2021-01-19 00:00:00", - "2021-01-25 16:00:00", - ), }, ) - dataset_backtest.init( + dataset_backtest.config( handler_kwargs={ "start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00", }, - segment_kwargs={ + segments={ "test": ( "2021-01-19 00:00:00", "2021-01-25 16:00:00", ), }, ) + dataset_backtest.setup_data(handler_kwargs={}) ##=============get data============= - xtest = dataset.prepare(["test"]) - backtest_test = dataset_backtest.prepare(["test"]) + xtest, = dataset.prepare(["test"]) + backtest_test, = dataset_backtest.prepare(["test"]) print(xtest, backtest_test) return diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 518b8eecd..aa90cee2f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -20,17 +20,25 @@ class Dataset(Serializable): """ init is designed to finish following steps: + - init instance + + - config the state of the dataset(info to prepare the data) + - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. + - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - - initialize the state of the dataset(info to prepare the data) - - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. - The data could specify the info to caculate the essential data for preparation """ self.setup_data(*args, **kwargs) super().__init__() + def config(self, *arg, **kwargs): + """ + config is designed to configure and parameters that cannot be learned from the data + """ + super().config(*arg, **kwargs) + def setup_data(self, *args, **kwargs): """ Setup the data. @@ -39,7 +47,7 @@ class Dataset(Serializable): - User have a Dataset object with learned status on disk. - - User load the Dataset object from the disk(Note the init function is skiped). + - User load the Dataset object from the disk. - User call `setup_data` to load new data. @@ -76,44 +84,7 @@ class DatasetH(Dataset): - The processing is related to data split. """ - def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None): - """ - Initialize the DatasetH - - Parameters - ---------- - handler_kwargs : dict - Config of DataHanlder, which could include the following arguments: - - - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. - - - arguments of DataHandler.init, such as 'enable_cache', etc. - - segment_kwargs : dict - Config of segments which is same as 'segments' in DatasetH.setup_data - - """ - if handler_kwargs: - if not isinstance(handler_kwargs, dict): - raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}") - kwargs_init = {} - kwargs_conf_data = {} - conf_data_arg = {"instruments", "start_time", "end_time", "fit_start_time", "fit_end_time"} - for k, v in handler_kwargs.items(): - if k in conf_data_arg: - kwargs_conf_data.update({k: v}) - else: - kwargs_init.update({k: v}) - - self.handler.conf_data(**kwargs_conf_data) - self.handler.init(**kwargs_init) - - if segment_kwargs: - if not isinstance(segment_kwargs, dict): - raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}") - self.segments = segment_kwargs.copy() - - def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]): + def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs): """ Setup the underlying data. @@ -144,6 +115,52 @@ class DatasetH(Dataset): """ self.handler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + super().__init__(**kwargs) + + def config(self, handler_kwargs:dict = None, segments:dict = None, **kwargs): + """ + Initialize the DatasetH + + Parameters + ---------- + handler_kwargs : dict + Config of DataHanlder, which could include the following arguments: + + - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. + + kwargs : dict + Config of DatasetH, such as + + - segments : dict + Config of segments which is same as 'segments' in self.__init__ + + """ + super().config(**kwargs) + if handler_kwargs is not None: + self.handler.config(**handler_kwargs) + if segments is not None: + self.segments = segments.copy() + + + + def setup_data(self, handler_kwargs: dict = None, **kwargs): + """ + Setup the Data + + Parameters + ---------- + handler_kwargs : dict + init arguments of DataHanlder, which could include the following arguments: + + - init_type : Init Type of Handler + + - enable_cache : wheter to enable cache + + """ + super().setup_data(**kwargs) + if handler_kwargs is not None: + self.handler.setup_data(**handler_kwargs) + def __repr__(self): return "{name}(handler={handler}, segments={segments})".format( @@ -433,16 +450,21 @@ class TSDatasetH(DatasetH): - The dimension of a batch of data """ - def __init__(self, step_len=30, *args, **kwargs): + def __init__(self, step_len=30, **kwargs): self.step_len = step_len - super().__init__(*args, **kwargs) + super().__init__(**kwargs) - def setup_data(self, *args, **kwargs): - super().setup_data(*args, **kwargs) + def config(self, step_len=None, **kwargs): + super().config(**kwargs) + if step_len: + self.step_len = step_len + + def setup_data(self, **kwargs): + super().setup_data(**kwargs) cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) - # Get the datatime index for building timestamp self.cal = cal + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: # Dataset decide how to slice data(Get more data for timeseries). diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 9aa05b9b9..712cd6232 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -6,6 +6,7 @@ import abc import bisect import logging import warnings +from inspect import getfullargspec from typing import Union, Tuple, List, Iterator, Optional import pandas as pd @@ -99,10 +100,10 @@ class DataHandler(Serializable): self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): - self.init() + self.setup_data() super().__init__() - def conf_data(self, **kwargs): + def config(self, instruments=None, start_time=None, end_time=None, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -111,14 +112,17 @@ class DataHandler(Serializable): The data will be initialized with different time range. """ - attr_list = {"instruments", "start_time", "end_time"} - for k, v in kwargs.items(): - if k in attr_list: - setattr(self, k, v) - - def init(self, enable_cache: bool = False): + super().config(**kwargs) + if instruments: + self.instruments = instruments + if start_time: + self.start_time = start_time + if end_time: + self.end_time = end_time + + def setup_data(self, enable_cache: bool = False): """ - initialize the data. + Set Up the data. In case of running intialization for multiple time, it will do nothing for the second time. It is responsible for maintaining following variable @@ -403,7 +407,7 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data - def conf_data(self, **kwargs): + def config(self, processors_kwargs:dict = None, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -412,27 +416,19 @@ class DataHandlerLP(DataHandler): The data will be initialized with different time range. """ - attr_list = {"fit_start_time", "fit_end_time"} - for k, v in kwargs.items(): - if k in attr_list: - for infer_processor in self.infer_processors: - if getattr(infer_processor, k, None): - setattr(infer_processor, k, v) - - for learn_processor in self.learn_processors: - if getattr(learn_processor, k, None): - setattr(learn_processor, k, v) - - super().conf_data(**kwargs) + super().config(**kwargs) + if processors_kwargs is not None: + for processor in self.get_all_processors(): + processor.config(**processor_kwargs) # init type IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle - def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): + def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Initialize the data of Qlib + Set up the data of Qlib Parameters ---------- @@ -447,7 +443,7 @@ class DataHandlerLP(DataHandler): when we call `init` next time """ # init raw data - super().init(enable_cache=enable_cache) + super().setup_data(**kwargs) with TimeInspector.logt("fit & process data"): if init_type == DataHandlerLP.IT_FIT_IND: diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 1cda5c025..a58bca5e8 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -53,7 +53,6 @@ class DataLoader(abc.ABC): """ pass - class DLWParser(DataLoader): """ (D)ata(L)oader (W)ith (P)arser for features and names diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 5a06f66be..e14e85831 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -72,6 +72,9 @@ class Processor(Serializable): """ return True + def config(**kwargs): + super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None)) + class DropnaProcessor(Processor): def __init__(self, fields_group=None): @@ -192,6 +195,12 @@ class MinMaxNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df + def config(fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class ZScoreNorm(Processor): """ZScore Normalization""" @@ -220,6 +229,13 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df + + def config(fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class RobustZScoreNorm(Processor): @@ -257,6 +273,12 @@ class RobustZScoreNorm(Processor): df.clip(-3, 3, inplace=True) return df + def config(fit_start_time=None, fit_end_time=None, **kwargs): + if fit_start_time: + self.fit_start_time = fit_start_time + if fit_end_time: + self.fit_end_time = fit_end_time + super().config(**kwargs) class CSZScoreNorm(Processor): """Cross Sectional ZScore Normalization""" From fb7f84f31e4e3b6a6e76cf496d97b6a62fe2fe04 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 29 Mar 2021 20:15:42 +0800 Subject: [PATCH 47/77] fix ubg --- examples/highfreq/highfreq_processor.py | 2 +- examples/highfreq/workflow.py | 2 +- examples/rolling_process_data/workflow.py | 14 +++++++++----- qlib/data/dataset/handler.py | 4 ++-- qlib/data/dataset/processor.py | 8 ++++---- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index 4ec8f3dd2..6ed68ff38 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -71,7 +71,7 @@ class HighFreqNorm(Processor): ).sort_index() return df_new_features - def config(fit_start_time=None, fit_end_time=None, **kwargs): + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time if fit_end_time: diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 0b48b971f..97762f182 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -31,7 +31,7 @@ class HighfreqWorkflow(object): SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} - MARKET = "csi300" + MARKET = "all" start_time = "2020-09-15 00:00:00" end_time = "2021-01-18 16:00:00" diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 0be88dddc..ffdd8329a 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -101,15 +101,16 @@ class RollingDataWorkflow(object): print(f"===========rolling{rolling_offset} start===========") if rolling_offset: - dataset.init( + dataset.config( handler_kwargs={ - "init_type": DataHandlerLP.IT_FIT_SEQ, "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), - "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), - "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), + "processor_kwargs":{ + "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), + "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), + }, }, - segment_kwargs={ + segments={ "train": ( datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), @@ -124,6 +125,9 @@ class RollingDataWorkflow(object): ), }, ) + dataset.setup_data( + handler_kwargs={"init_type": DataHandlerLP.IT_FIT_SEQ,} + ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) print(dtrain, dvalid, dtest) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 712cd6232..4adef23a0 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -407,7 +407,7 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data - def config(self, processors_kwargs:dict = None, **kwargs): + def config(self, processor_kwargs:dict = None, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -417,7 +417,7 @@ class DataHandlerLP(DataHandler): """ super().config(**kwargs) - if processors_kwargs is not None: + if processor_kwargs is not None: for processor in self.get_all_processors(): processor.config(**processor_kwargs) diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index e14e85831..5be178c5c 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -72,7 +72,7 @@ class Processor(Serializable): """ return True - def config(**kwargs): + def config(self, **kwargs): super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None)) @@ -195,7 +195,7 @@ class MinMaxNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(fit_start_time=None, fit_end_time=None, **kwargs): + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time if fit_end_time: @@ -230,7 +230,7 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(fit_start_time=None, fit_end_time=None, **kwargs): + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time if fit_end_time: @@ -273,7 +273,7 @@ class RobustZScoreNorm(Processor): df.clip(-3, 3, inplace=True) return df - def config(fit_start_time=None, fit_end_time=None, **kwargs): + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time if fit_end_time: From 8743576f7238003530ae55e78fa50554d8d6ba33 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 29 Mar 2021 20:16:00 +0800 Subject: [PATCH 48/77] black format --- examples/highfreq/highfreq_processor.py | 2 +- examples/highfreq/workflow.py | 4 ++-- examples/rolling_process_data/workflow.py | 6 ++++-- qlib/data/dataset/__init__.py | 14 +++++--------- qlib/data/dataset/handler.py | 4 ++-- qlib/data/dataset/loader.py | 1 + qlib/data/dataset/processor.py | 4 +++- 7 files changed, 18 insertions(+), 17 deletions(-) diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index 6ed68ff38..d843c6ac0 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -70,7 +70,7 @@ class HighFreqNorm(Processor): columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features - + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 97762f182..94c9b689f 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -177,8 +177,8 @@ class HighfreqWorkflow(object): dataset_backtest.setup_data(handler_kwargs={}) ##=============get data============= - xtest, = dataset.prepare(["test"]) - backtest_test, = dataset_backtest.prepare(["test"]) + (xtest,) = dataset.prepare(["test"]) + (backtest_test,) = dataset_backtest.prepare(["test"]) print(xtest, backtest_test) return diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index ffdd8329a..02f43889d 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -105,7 +105,7 @@ class RollingDataWorkflow(object): handler_kwargs={ "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), - "processor_kwargs":{ + "processor_kwargs": { "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), }, @@ -126,7 +126,9 @@ class RollingDataWorkflow(object): }, ) dataset.setup_data( - handler_kwargs={"init_type": DataHandlerLP.IT_FIT_SEQ,} + handler_kwargs={ + "init_type": DataHandlerLP.IT_FIT_SEQ, + } ) dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"]) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index aa90cee2f..d8a9e0209 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -35,7 +35,7 @@ class Dataset(Serializable): def config(self, *arg, **kwargs): """ - config is designed to configure and parameters that cannot be learned from the data + config is designed to configure and parameters that cannot be learned from the data """ super().config(*arg, **kwargs) @@ -117,7 +117,7 @@ class DatasetH(Dataset): self.segments = segments.copy() super().__init__(**kwargs) - def config(self, handler_kwargs:dict = None, segments:dict = None, **kwargs): + def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs): """ Initialize the DatasetH @@ -130,7 +130,7 @@ class DatasetH(Dataset): kwargs : dict Config of DatasetH, such as - + - segments : dict Config of segments which is same as 'segments' in self.__init__ @@ -141,8 +141,6 @@ class DatasetH(Dataset): if segments is not None: self.segments = segments.copy() - - def setup_data(self, handler_kwargs: dict = None, **kwargs): """ Setup the Data @@ -151,16 +149,15 @@ class DatasetH(Dataset): ---------- handler_kwargs : dict init arguments of DataHanlder, which could include the following arguments: - + - init_type : Init Type of Handler - + - enable_cache : wheter to enable cache """ super().setup_data(**kwargs) if handler_kwargs is not None: self.handler.setup_data(**handler_kwargs) - def __repr__(self): return "{name}(handler={handler}, segments={segments})".format( @@ -464,7 +461,6 @@ class TSDatasetH(DatasetH): cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() cal = sorted(cal) self.cal = cal - def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: # Dataset decide how to slice data(Get more data for timeseries). diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 4adef23a0..2190deeb1 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -119,7 +119,7 @@ class DataHandler(Serializable): self.start_time = start_time if end_time: self.end_time = end_time - + def setup_data(self, enable_cache: bool = False): """ Set Up the data. @@ -407,7 +407,7 @@ class DataHandlerLP(DataHandler): if self.drop_raw: del self._data - def config(self, processor_kwargs:dict = None, **kwargs): + def config(self, processor_kwargs: dict = None, **kwargs): """ configuration of data. # what data to be loaded from data source diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index a58bca5e8..1cda5c025 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -53,6 +53,7 @@ class DataLoader(abc.ABC): """ pass + class DLWParser(DataLoader): """ (D)ata(L)oader (W)ith (P)arser for features and names diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 5be178c5c..d25d36c88 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -202,6 +202,7 @@ class MinMaxNorm(Processor): self.fit_end_time = fit_end_time super().config(**kwargs) + class ZScoreNorm(Processor): """ZScore Normalization""" @@ -229,7 +230,7 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - + def config(self, fit_start_time=None, fit_end_time=None, **kwargs): if fit_start_time: self.fit_start_time = fit_start_time @@ -280,6 +281,7 @@ class RobustZScoreNorm(Processor): self.fit_end_time = fit_end_time super().config(**kwargs) + class CSZScoreNorm(Processor): """Cross Sectional ZScore Normalization""" From d18c3674974dfa3593424418e53d167247dadf74 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 29 Mar 2021 20:34:36 +0800 Subject: [PATCH 49/77] update README --- examples/rolling_process_data/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md index 6a6af0d3d..b04f5ed7f 100644 --- a/examples/rolling_process_data/README.md +++ b/examples/rolling_process_data/README.md @@ -1 +1,17 @@ # Rolling Process Data + +This workflow is an example for `Rolling Process Data`. + +## Background + +When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will also change, and the processor's learnable state (such as standard deviation, mean, etc.) will also be changed. + +In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the sliding window. + + +### Run the Code + +Run the example by running the following command: +```bash + python workflow.py rolling_process +``` \ No newline at end of file From 1074284666113389cbcb6c5707f59e5c69f07f99 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 29 Mar 2021 20:38:09 +0800 Subject: [PATCH 50/77] fix docstring --- qlib/data/dataset/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index d8a9e0209..668ea833b 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -20,9 +20,7 @@ class Dataset(Serializable): """ init is designed to finish following steps: - - init instance - - - config the state of the dataset(info to prepare the data) + - init the sub instance and the state of the dataset(info to prepare the data) - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. - setup data From 136830bc2bf8281838d96c22fb0cdd45e93ae16b Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 30 Mar 2021 00:38:15 +0800 Subject: [PATCH 51/77] update comments --- examples/highfreq/highfreq_processor.py | 7 ----- examples/highfreq/workflow.py | 6 ++--- examples/rolling_process_data/workflow.py | 2 +- qlib/data/dataset/__init__.py | 27 ++++++++++---------- qlib/data/dataset/handler.py | 17 ++++++++----- qlib/data/dataset/loader.py | 2 +- qlib/data/dataset/processor.py | 31 +++++++---------------- 7 files changed, 38 insertions(+), 54 deletions(-) diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index d843c6ac0..f0ab0dec2 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -70,10 +70,3 @@ class HighFreqNorm(Processor): columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features - - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 94c9b689f..5660ab2e9 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -27,7 +27,7 @@ from qlib.tests.data import GetData from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut -class HighfreqWorkflow(object): +class HighfreqWorkflow: SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} @@ -177,8 +177,8 @@ class HighfreqWorkflow(object): dataset_backtest.setup_data(handler_kwargs={}) ##=============get data============= - (xtest,) = dataset.prepare(["test"]) - (backtest_test,) = dataset_backtest.prepare(["test"]) + xtest = dataset.prepare("test") + backtest_test = dataset_backtest.prepare("test") print(xtest, backtest_test) return diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 02f43889d..5757aaa87 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData -class RollingDataWorkflow(object): +class RollingDataWorkflow: MARKET = "csi300" start_time = "2010-01-01" diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 668ea833b..b3eaac7a3 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional from ...utils import init_instance_by_config, np_ffill from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP +from copy import deepcopy from inspect import getfullargspec import pandas as pd import numpy as np @@ -16,7 +17,7 @@ class Dataset(Serializable): Preparing data for model training and inferencing. """ - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): """ init is designed to finish following steps: @@ -28,16 +29,16 @@ class Dataset(Serializable): The data could specify the info to caculate the essential data for preparation """ - self.setup_data(*args, **kwargs) + self.setup_data(**kwargs) super().__init__() - def config(self, *arg, **kwargs): + def config(self, **kwargs): """ config is designed to configure and parameters that cannot be learned from the data """ - super().config(*arg, **kwargs) + super().config(**kwargs) - def setup_data(self, *args, **kwargs): + def setup_data(self, **kwargs): """ Setup the data. @@ -53,7 +54,7 @@ class Dataset(Serializable): """ pass - def prepare(self, *args, **kwargs) -> object: + def prepare(self, **kwargs) -> object: """ The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.) The parameters should specify the scope for the prepared data @@ -115,7 +116,7 @@ class DatasetH(Dataset): self.segments = segments.copy() super().__init__(**kwargs) - def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs): + def config(self, handler_kwargs: dict = None, **kwargs): """ Initialize the DatasetH @@ -133,11 +134,11 @@ class DatasetH(Dataset): Config of segments which is same as 'segments' in self.__init__ """ - super().config(**kwargs) if handler_kwargs is not None: self.handler.config(**handler_kwargs) - if segments is not None: - self.segments = segments.copy() + if "segments" in kwargs: + self.segments = deepcopy(kwargs.pop("segments")) + super().config(**kwargs) def setup_data(self, handler_kwargs: dict = None, **kwargs): """ @@ -449,10 +450,10 @@ class TSDatasetH(DatasetH): self.step_len = step_len super().__init__(**kwargs) - def config(self, step_len=None, **kwargs): + def config(self, **kwargs): + if "step_len" in kwargs: + self.step_len = kwargs.pop("step_len") super().config(**kwargs) - if step_len: - self.step_len = step_len def setup_data(self, **kwargs): super().setup_data(**kwargs) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2190deeb1..7fb7090d2 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -103,7 +103,7 @@ class DataHandler(Serializable): self.setup_data() super().__init__() - def config(self, instruments=None, start_time=None, end_time=None, **kwargs): + def config(self, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -112,13 +112,16 @@ class DataHandler(Serializable): The data will be initialized with different time range. """ + attr_list = {"instruments", "start_time", "end_time"} + for k, v in kwargs.items(): + if k in attr_list: + setattr(self, k, v) + + for attr in attr_list: + if attr in kwargs: + kwargs.pop(attr) + super().config(**kwargs) - if instruments: - self.instruments = instruments - if start_time: - self.start_time = start_time - if end_time: - self.end_time = end_time def setup_data(self, enable_cache: bool = False): """ diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 1cda5c025..58aca1d4f 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -261,7 +261,7 @@ class DataLoaderDH(DataLoader): self.is_group = is_group self.fetch_kwargs = {"col_set": DataHandler.CS_RAW} - self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs} + self.fetch_kwargs.update(fetch_kwargs) def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index d25d36c88..8f69a5dff 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -73,7 +73,15 @@ class Processor(Serializable): return True def config(self, **kwargs): - super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None)) + attr_list = {"fit_start_time", "fit_end_time"} + for k, v in kwargs.items(): + if k in attr_list and getattr(self, k, None) is not None: + setattr(self, k, v) + + for attr in attr_list: + if attr in kwargs: + kwargs.pop(attr) + super().config(**kwargs) class DropnaProcessor(Processor): @@ -195,13 +203,6 @@ class MinMaxNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class ZScoreNorm(Processor): """ZScore Normalization""" @@ -231,13 +232,6 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class RobustZScoreNorm(Processor): """Robust ZScore Normalization @@ -274,13 +268,6 @@ class RobustZScoreNorm(Processor): df.clip(-3, 3, inplace=True) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class CSZScoreNorm(Processor): """Cross Sectional ZScore Normalization""" From f8da79b802d617234f6ae20bea2ae2bc771c39a9 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 30 Mar 2021 00:54:00 +0800 Subject: [PATCH 52/77] fix readme --- examples/rolling_process_data/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md index b04f5ed7f..c84eaac20 100644 --- a/examples/rolling_process_data/README.md +++ b/examples/rolling_process_data/README.md @@ -9,7 +9,7 @@ When rolling train the models, data also needs to be generated in the different In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the sliding window. -### Run the Code +## Run the Code Run the example by running the following command: ```bash From 023603479c5e451671d2c68fcec65574ec847fe7 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 30 Mar 2021 01:00:12 +0800 Subject: [PATCH 53/77] fix readme --- examples/rolling_process_data/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/rolling_process_data/README.md b/examples/rolling_process_data/README.md index c84eaac20..315fe2eed 100644 --- a/examples/rolling_process_data/README.md +++ b/examples/rolling_process_data/README.md @@ -4,9 +4,9 @@ This workflow is an example for `Rolling Process Data`. ## Background -When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will also change, and the processor's learnable state (such as standard deviation, mean, etc.) will also be changed. +When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change. -In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the sliding window. +In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window. ## Run the Code From 7a2203f116bd79338481ffe439ad389b247c0e03 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 30 Mar 2021 11:03:07 +0800 Subject: [PATCH 54/77] update comments --- qlib/data/dataset/handler.py | 5 ++--- qlib/data/dataset/processor.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 7fb7090d2..201d2459d 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -125,8 +125,7 @@ class DataHandler(Serializable): def setup_data(self, enable_cache: bool = False): """ - Set Up the data. - In case of running intialization for multiple time, it will do nothing for the second time. + Set Up the data in case of running intialization for multiple time It is responsible for maintaining following variable 1) self._data @@ -431,7 +430,7 @@ class DataHandlerLP(DataHandler): def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Set up the data of Qlib + Set up the data in case of running intialization for multiple time Parameters ---------- diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 8f69a5dff..e035f5624 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -75,7 +75,7 @@ class Processor(Serializable): def config(self, **kwargs): attr_list = {"fit_start_time", "fit_end_time"} for k, v in kwargs.items(): - if k in attr_list and getattr(self, k, None) is not None: + if k in attr_list and hasattr(self, k): setattr(self, k, v) for attr in attr_list: From b6df11b6b45763fcfdae1197f2cf8e56d64804a0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 30 Mar 2021 14:41:56 +0800 Subject: [PATCH 55/77] Modify get_exp & get_recorder api --- qlib/workflow/__init__.py | 8 ++++---- qlib/workflow/exp.py | 12 +++++++----- qlib/workflow/expm.py | 11 ++++++----- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 678ae99a8..a03665626 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -202,13 +202,13 @@ class QlibRecorder: - no id or name specified, return the active experiment. - - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active. + - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. - If `active experiment` not exists: - no id or name specified, create a default experiment, and the experiment is set to be active. - - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be active. + - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment. - Else If '`create`' is False: @@ -260,7 +260,7 @@ class QlibRecorder: ------- An experiment instance with given id or name. """ - return self.exp_manager.get_exp(experiment_id, experiment_name, create) + return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False) def delete_exp(self, experiment_id=None, experiment_name=None): """ @@ -358,7 +358,7 @@ class QlibRecorder: A recorder instance. """ return self.get_exp(experiment_name=experiment_name, create=False).get_recorder( - recorder_id, recorder_name, create=False + recorder_id, recorder_name, create=False, start=False ) def delete_recorder(self, recorder_id=None, recorder_name=None): diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 0f420cec4..dd73f7f52 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -107,24 +107,24 @@ class Experiment: """ raise NotImplementedError(f"Please implement the `delete_recorder` method.") - def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True): + def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False): """ Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the specific recorder. When user does not provide recorder id or name, the method will try to return the current active recorder. The `create` argument determines whether the method will automatically create a new recorder - according to user's specification if the recorder hasn't been created before + according to user's specification if the recorder hasn't been created before. * If `create` is True: * If `active recorder` exists: * no id or name specified, return the active recorder. - * if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active. + * if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active. * If `active recorder` not exists: * no id or name specified, create a new recorder. - * if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active. + * if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active. * Else If `create` is False: @@ -146,6 +146,8 @@ class Experiment: the name of the recorder to be deleted. create : boolean create the recorder if it hasn't been created before. + start : boolean + start the new recorder if one is created. Returns ------- @@ -163,7 +165,7 @@ class Experiment: self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False, ) - if is_new: + if is_new and start: self.active_recorder = recorder # start the recorder self.active_recorder.start_run() diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 28d6d92c7..5275e57d7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -102,10 +102,9 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `search_records` method.") - def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True): + def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False): """ Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment. - The returned experiment will be active. When user specify experiment id and name, the method will try to return the specific experiment. When user does not provide recorder id or name, the method will try to return the current active experiment. @@ -117,12 +116,12 @@ class ExpManager: * If `active experiment` exists: * no id or name specified, return the active experiment. - * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active. + * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active. * If `active experiment` not exists: * no id or name specified, create a default experiment. - * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active. + * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active. * Else If `create` is False: @@ -144,6 +143,8 @@ class ExpManager: name of the experiment to return. create : boolean create the experiment it if hasn't been created before. + start : boolean + start the new experiment if one is created. Returns ------- @@ -163,7 +164,7 @@ class ExpManager: self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False, ) - if is_new: + if is_new and start: self.active_experiment = exp # start the recorder self.active_experiment.start() From bed1175e2404ccd4d711bb71aff9577c8449c6a9 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Tue, 30 Mar 2021 19:29:17 +0800 Subject: [PATCH 56/77] update dataset --- .../highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index c21ef1da3..45c59c670 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -1,7 +1,7 @@ qlib_init: - provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + provider_uri: "~/.qlib/qlib_data/cn_data_1min" region: cn -market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +market: &market 'csi300' start_time: &start_time "2020-09-15 00:00:00" end_time: &end_time "2021-01-18 16:00:00" train_end_time: &train_end_time "2020-11-15 16:00:00" From fe190dec4b6670a8e0d5410545d7bb8a13304157 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 14:40:28 +0800 Subject: [PATCH 57/77] update readme --- examples/highfreq/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md index 30c2e19db..c07d8a2a0 100644 --- a/examples/highfreq/README.md +++ b/examples/highfreq/README.md @@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows. Run the example by running the following command: ```bash python workflow.py dump_and_load_dataset -``` \ No newline at end of file +``` + +## Benchmarks Performance +### Signal Test +Here are the results of signal test for benchmark models. We will keep updating benchmark models in future. +| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe | +|---|---|---|---|---|---|---|---|---|---| +| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 | From 941c980d06371b83cf54eef8e84b0614104eb5d4 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 17:35:19 +0800 Subject: [PATCH 58/77] update tabnet --- examples/benchmarks/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index f1e7437fa..c3d965d85 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 | +| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 | ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | @@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 | | GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 | +| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 | - The selected 20 features are based on the feature importance of a lightgbm-based model. - The base model of DoubleEnsemble is LGBM. From 848d953226cedc782f8949838698801458b1a829 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 09:58:55 +0800 Subject: [PATCH 59/77] Update qlib logger --- qlib/log.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 126acb9d2..ed050f6c9 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,6 +11,26 @@ from contextlib import contextmanager from .config import C +class QlibLogger(Logger,meta=): + ''' + Customized logger for Qlib. + ''' + def __init__(self, module_name): + self.module_name = module_name + self.level = 0 + + @property + def logger(self): + logger = logging.getLogger(self.module_name) + logger.setLevel(self.level) + return logger + + def setLevel(self, level): + self.level = level + + def __getattr__(self, name): + return self.logger.__getattribute__(name) + def get_module_logger(module_name, level: Optional[int] = None): """ @@ -27,7 +47,7 @@ def get_module_logger(module_name, level: Optional[int] = None): module_name = "qlib.{}".format(module_name) # Get logger. - module_logger = logging.getLogger(module_name) + module_logger = QlibLogger(module_name) module_logger.setLevel(level) return module_logger From 78bb8882cd4f23e20a14d69682b54cdd24a3e200 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 12:00:18 +0800 Subject: [PATCH 60/77] Format --- qlib/log.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index ed050f6c9..017e8e339 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,10 +11,12 @@ from contextlib import contextmanager from .config import C -class QlibLogger(Logger,meta=): - ''' + +class QlibLogger: + """ Customized logger for Qlib. - ''' + """ + def __init__(self, module_name): self.module_name = module_name self.level = 0 @@ -27,10 +29,10 @@ class QlibLogger(Logger,meta=): def setLevel(self, level): self.level = level - + def __getattr__(self, name): return self.logger.__getattribute__(name) - + def get_module_logger(module_name, level: Optional[int] = None): """ From f4bfe8e6197aa52bfb759c3981346953f8306f41 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 14:35:05 +0800 Subject: [PATCH 61/77] First trial of adding docstring --- qlib/log.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 017e8e339..c7d269f4d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -12,7 +12,23 @@ from contextlib import contextmanager from .config import C -class QlibLogger: +class MetaLogger(type): + def __init__(self, name, bases, dic): + super().__init__(name, bases, dic) + + def __new__(cls, name, bases, dict): + wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict.update(dict) + wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + return type.__new__(cls, name, bases, wrapper_dict) + + def __call__(cls, *args, **kwargs): + obj = cls.__new__(cls) + cls.__init__(cls, *args, **kwargs) + return obj + + +class QlibLogger(metaclass=MetaLogger): """ Customized logger for Qlib. """ From 4ebf68479416932d8e28fdd4af289655e54a254f Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 15:35:11 +0800 Subject: [PATCH 62/77] Update workflow logging --- qlib/log.py | 4 ++-- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 4 ++-- qlib/workflow/record_temp.py | 4 ++-- qlib/workflow/recorder.py | 4 ++-- qlib/workflow/utils.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index c7d269f4d..4ecdceef2 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -17,9 +17,9 @@ class MetaLogger(type): super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) def __call__(cls, *args, **kwargs): diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index dd73f7f52..7b3d1f507 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from pathlib import Path from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Experiment: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5275e57d7..590790c9e 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -4,7 +4,7 @@ import mlflow from mlflow.exceptions import MlflowException from mlflow.entities import ViewType -import os +import os, logging from pathlib import Path from contextlib import contextmanager from typing import Optional, Text @@ -14,7 +14,7 @@ from ..config import C from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class ExpManager: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index dee327f64..5732c95a9 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import re +import re, logging import pandas as pd from pathlib import Path from pprint import pprint @@ -16,7 +16,7 @@ from ..utils import flatten_dict from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class RecordTemp: diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5915e58da..b9b2fd1b3 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path from datetime import datetime from ..utils.objm import FileManager from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Recorder: diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 33d251dd8..596ff0927 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys, traceback, signal, atexit +import sys, traceback, signal, atexit, logging from . import R from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) # function to handle the experiment when unusual program ending occurs From cbf1fa721ed85f0d2e89ff19f9ec0e08af2339c2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Sat, 17 Apr 2021 15:47:49 +0800 Subject: [PATCH 63/77] Update --- qlib/contrib/workflow/record_temp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 12792fbcb..bedf89105 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import logging import pandas as pd +import numpy as np from sklearn.metrics import mean_squared_error from typing import Dict, Text, Any -import numpy as np from ...contrib.eva.alpha import calc_ic from ...workflow.record_temp import RecordTemp @@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord from ...data import dataset as qlib_dataset from ...log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class MultiSegRecord(RecordTemp): From 6a05d4e2559f1917e6411478426cab6c4f6eaa78 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 19 Apr 2021 11:36:00 +0800 Subject: [PATCH 64/77] Enable IDEs docstrings --- qlib/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 4ecdceef2..8b123d05d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -50,7 +50,7 @@ class QlibLogger(metaclass=MetaLogger): return self.logger.__getattribute__(name) -def get_module_logger(module_name, level: Optional[int] = None): +def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger: """ Get a logger for a specific module. From aafaff45d2b0d2740d83b2651a8887f51011037b Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:13:36 +0800 Subject: [PATCH 65/77] Update doc --- qlib/contrib/backtest/backtest.py | 7 +++++-- qlib/contrib/report/analysis_position/cumulative_return.py | 2 +- qlib/contrib/report/analysis_position/rank_label.py | 2 +- qlib/contrib/report/analysis_position/report.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index b87d6afe3..909948c25 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -15,7 +15,8 @@ LOG = get_module_logger("backtest") def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters + """ + Parameters ---------- pred : pandas.DataFrame predict should has index and one `score` column @@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy + """ + Update the account and strategy + Parameters ---------- trade_account : Account() diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index abb68ea60..00985a17c 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -214,7 +214,7 @@ def cumulative_return_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.cumulative_return_graph(positions, report_normal_df, features_df) + qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df) Graph desc: diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index 72a358adc..77743b10c 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -94,7 +94,7 @@ def rank_label_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) + qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index f82e654c4..6b83f0734 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, report_normal_df, _ = backtest(pred_df, strategy, **bparas) - qcr.report_graph(report_normal_df) + qcr.analysis_position.report_graph(report_normal_df) :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**. From 8adfafa6aa6f76591ae2af537f9d8ad91ccb6c43 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:17:25 +0800 Subject: [PATCH 66/77] Black format --- qlib/contrib/backtest/backtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 909948c25..fc30065fd 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -127,7 +127,7 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): """ Update the account and strategy - + Parameters ---------- trade_account : Account() From fbff4c271a7e74f2f0b4770912abf2fb01a9354b Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 00:38:45 +0800 Subject: [PATCH 67/77] Remove redundant methods in meta --- qlib/log.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 8b123d05d..3b3362d5b 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,8 +13,6 @@ from .config import C class MetaLogger(type): - def __init__(self, name, bases, dic): - super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() @@ -22,11 +20,6 @@ class MetaLogger(type): wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) - def __call__(cls, *args, **kwargs): - obj = cls.__new__(cls) - cls.__init__(cls, *args, **kwargs) - return obj - class QlibLogger(metaclass=MetaLogger): """ From e410caaa8fb315de7898035986ec7cca58384bf0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 10:08:12 +0800 Subject: [PATCH 68/77] Simplify meta class --- qlib/log.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 3b3362d5b..5888b3841 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,11 +13,10 @@ from .config import C class MetaLogger(type): - def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() + wrapper_dict = logging.Logger.__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ + wrapper_dict["__doc__"] = logging.Logger.__doc__ return type.__new__(cls, name, bases, wrapper_dict) From e15ea06122bd570706ac8b6d3ab6b96b5ee64edb Mon Sep 17 00:00:00 2001 From: zhupr Date: Sun, 25 Apr 2021 23:50:29 +0800 Subject: [PATCH 69/77] Fix ClientProvider not supporting LocalInstrumentProvider && online using the latest python-socketio --- qlib/data/data.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..cea2f42eb 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -1016,7 +1016,8 @@ class ClientProvider(BaseProvider): self.logger = get_module_logger(self.__class__.__name__) if isinstance(Cal, ClientCalendarProvider): Cal.set_conn(self.client) - Inst.set_conn(self.client) + if isinstance(Inst, ClientInstrumentProvider): + Inst.set_conn(self.client) if hasattr(DatasetD, "provider"): DatasetD.provider.set_conn(self.client) else: diff --git a/setup.py b/setup.py index 83cf6e1b6..747d885f4 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ REQUIRED = [ "scipy>=1.0.0", "requests>=2.18.0", "sacred>=0.7.4", - "python-socketio==3.1.2", + "python-socketio", "redis>=3.0.1", "python-redis-lock>=3.3.1", "schedule>=0.6.0", From 5a7eecabeefdf5218a4a4ea1db5ed94343df6c42 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 27 Apr 2021 04:04:43 +0000 Subject: [PATCH 70/77] black formating (black is upgraded in github) --- examples/benchmarks/TFT/data_formatters/base.py | 2 +- qlib/contrib/backtest/position.py | 2 +- qlib/contrib/report/graph.py | 2 +- qlib/data/dataset/processor.py | 4 ++-- qlib/model/base.py | 4 ++-- qlib/portfolio/optimizer/base.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py index c68a192ba..aa1c0dc82 100644 --- a/examples/benchmarks/TFT/data_formatters/base.py +++ b/examples/benchmarks/TFT/data_formatters/base.py @@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC): return -1, -1 def get_column_definition(self): - """"Returns formatted column definition in order expected by the TFT.""" + """Returns formatted column definition in order expected by the TFT.""" column_definition = self._column_definition diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..97abc2a56 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -128,7 +128,7 @@ class Position: return self.position["cash"] def get_stock_amount_dict(self): - """generate stock amount dict {stock_id : amount of stock} """ + """generate stock amount dict {stock_id : amount of stock}""" d = {} stock_list = self.get_stock_list() for stock_code in stock_list: diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 677e767ee..2d4f546e8 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path class BaseGraph: - """""" + """ """ _name = None diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index e035f5624..7635a4127 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -129,7 +129,7 @@ class FilterCol(Processor): class TanhProcess(Processor): - """ Use tanh to process noise data""" + """Use tanh to process noise data""" def __call__(self, df): def tanh_denoise(data): @@ -144,7 +144,7 @@ class TanhProcess(Processor): class ProcessInf(Processor): - """Process infinity """ + """Process infinity""" def __call__(self, df): def replace_inf(data): diff --git a/qlib/model/base.py b/qlib/model/base.py index 1ac8f2fc9..12caf5f73 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): @abc.abstractmethod def predict(self, *args, **kwargs) -> object: - """ Make predictions after modeling things """ + """Make predictions after modeling things""" pass def __call__(self, *args, **kwargs) -> object: - """ leverage Python syntactic sugar to make the models' behaviors like functions """ + """leverage Python syntactic sugar to make the models' behaviors like functions""" return self.predict(*args, **kwargs) diff --git a/qlib/portfolio/optimizer/base.py b/qlib/portfolio/optimizer/base.py index 502443869..e3f692014 100644 --- a/qlib/portfolio/optimizer/base.py +++ b/qlib/portfolio/optimizer/base.py @@ -5,9 +5,9 @@ import abc class BaseOptimizer(abc.ABC): - """ Construct portfolio with a optimization related method """ + """Construct portfolio with a optimization related method""" @abc.abstractmethod def __call__(self, *args, **kwargs) -> object: - """ Generate a optimized portfolio allocation """ + """Generate a optimized portfolio allocation""" pass From eab19de080e2b2b1de93cdce7704c6535f2b2ced Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 27 Apr 2021 16:56:07 +0800 Subject: [PATCH 71/77] Support start exp with given exp & recorder id --- qlib/workflow/__init__.py | 18 +++++++++++++++--- qlib/workflow/exp.py | 8 +++++--- qlib/workflow/expm.py | 12 ++++++++++-- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index a03665626..7cb1cf5cb 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -23,7 +23,9 @@ class QlibRecorder: @contextmanager def start( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +47,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str name of the experiment one wants to start. + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -57,7 +63,7 @@ class QlibRecorder: resume : bool whether to resume the specific recorder with given name under the given experiment. """ - run = self.start_exp(experiment_name, recorder_name, uri, resume) + run = self.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) try: yield run except Exception as e: @@ -65,7 +71,9 @@ class QlibRecorder: raise e self.end_exp(Recorder.STATUS_FI) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False): + def start_exp( + self, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually and the status of the recorder may not be handled properly. Here is the example code: @@ -79,8 +87,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str the name of the experiment to be started + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -93,7 +105,7 @@ class QlibRecorder: ------- An experiment instance being started. """ - return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume) + return self.exp_manager.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) def end_exp(self, recorder_status=Recorder.STATUS_FI): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 7b3d1f507..0a7e0a5a9 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -39,12 +39,14 @@ class Experiment: output["recorders"] = list(recorders.keys()) return output - def start(self, recorder_name=None, resume=False): + def start(self, recorder_id=None, recorder_name=None, resume=False): """ Start the experiment and set it to be active. This method will also start a new recorder. Parameters ---------- + recorder_id : str + the id of the recorder to be created. recorder_name : str the name of the recorder to be created. resume : bool @@ -238,14 +240,14 @@ class MLflowExperiment(Experiment): def __repr__(self): return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) - def start(self, recorder_name=None, resume=False): + def start(self, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") # Get or create recorder if recorder_name is None: recorder_name = self._default_rec_name # resume the recorder if resume: - recorder, _ = self._get_or_create_rec(recorder_name=recorder_name) + recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) # create a new recorder else: recorder = self.create_recorder(recorder_name) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 590790c9e..5549bb9bf 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -33,7 +33,9 @@ class ExpManager: def start_exp( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +47,12 @@ class ExpManager: Parameters ---------- + experiment_id : str + id of the active experiment. experiment_name : str name of the active experiment. + recorder_id : str + id of the recorder to be started. recorder_name : str name of the recorder to be started. uri : str @@ -298,7 +304,9 @@ class MLflowExpManager(ExpManager): def start_exp( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -308,11 +316,11 @@ class MLflowExpManager(ExpManager): # Create experiment if experiment_name is None: experiment_name = self._default_exp_name - experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) + experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_name, resume) + self.active_experiment.start(recorder_id, recorder_name, resume) return self.active_experiment From 8b8d21107c7f6dd6f6e6db371f4591179a4ad616 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 27 Apr 2021 21:20:47 +0800 Subject: [PATCH 72/77] Add future trading date collector --- qlib/data/data.py | 3 + scripts/data_collector/contrib/README.md | 24 +++++ .../contrib/future_trading_date_collector.py | 87 +++++++++++++++++++ .../data_collector/contrib/requirements.txt | 5 ++ scripts/data_collector/utils.py | 37 ++++++++ scripts/data_collector/yahoo/collector.py | 25 ++---- 6 files changed, 165 insertions(+), 16 deletions(-) create mode 100644 scripts/data_collector/contrib/README.md create mode 100644 scripts/data_collector/contrib/future_trading_date_collector.py create mode 100644 scripts/data_collector/contrib/requirements.txt diff --git a/qlib/data/data.py b/qlib/data/data.py index cea2f42eb..c2638e234 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider): # if future calendar not exists, return current calendar if not os.path.exists(fname): get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) diff --git a/scripts/data_collector/contrib/README.md b/scripts/data_collector/contrib/README.md new file mode 100644 index 000000000..011ff56e6 --- /dev/null +++ b/scripts/data_collector/contrib/README.md @@ -0,0 +1,24 @@ +# Get future trading days + +> `D.calendar(future=True)` will be used + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. +python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day +``` + +## Parameters + +- qlib_dir: qlib data directory +- freq: value from [`day`, `1min`], default `day` + + + diff --git a/scripts/data_collector/contrib/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector.py new file mode 100644 index 000000000..4da62d465 --- /dev/null +++ b/scripts/data_collector/contrib/future_trading_date_collector.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from typing import List +from pathlib import Path + +import fire +import numpy as np +import pandas as pd +from loguru import logger + +# get data from baostock +import baostock as bs + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + + +from data_collector.utils import generate_minutes_calendar_from_daily + + +def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame: + calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt") + if not calendar_path.exists(): + return pd.DataFrame() + return pd.read_csv(calendar_path, header=None) + + +def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"): + calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt")) + + np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8") + logger.info(f"write future calendars success: {calendar_path}") + + +def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]: + print(freq) + if freq == "day": + return date_list + elif freq == "1min": + date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist() + return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list)) + else: + raise ValueError(f"Unsupported freq: {freq}") + + +def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"): + """get future calendar + + Parameters + ---------- + qlib_dir: str or Path + qlib data directory + freq: str + value from ["day", "1min"], by default day + """ + qlib_dir = Path(qlib_dir).expanduser().resolve() + if not qlib_dir.exists(): + raise FileNotFoundError(str(qlib_dir)) + + lg = bs.login() + if lg.error_code != "0": + logger.error(f"login error: {lg.error_msg}") + return + # read daily calendar + daily_calendar = read_calendar_from_qlib(qlib_dir) + end_year = pd.Timestamp.now().year + if daily_calendar.empty: + start_year = pd.Timestamp.now().year + else: + start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year + rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31") + data_list = [] + while (rs.error_code == "0") & rs.next(): + _row_data = rs.get_row_data() + if int(_row_data[1]) == 1: + data_list.append(_row_data[0]) + data_list = sorted(data_list) + date_list = generate_qlib_calendar(data_list, freq=freq) + write_calendar_to_qlib(qlib_dir, date_list, freq=freq) + bs.logout() + logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31") + + +if __name__ == "__main__": + fire.Fire(future_calendar_collector) diff --git a/scripts/data_collector/contrib/requirements.txt b/scripts/data_collector/contrib/requirements.txt new file mode 100644 index 000000000..92dcb2374 --- /dev/null +++ b/scripts/data_collector/contrib/requirements.txt @@ -0,0 +1,5 @@ +baostock +fire +numpy +pandas +loguru diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index e8c9b9dc4..3f4539612 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -10,7 +10,9 @@ import random import requests import functools from pathlib import Path +from typing import Iterable, Tuple +import numpy as np import pandas as pd from lxml import etree from loguru import logger @@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh return res +def generate_minutes_calendar_from_daily( + calendars: Iterable, + freq: str = "1min", + am_range: Tuple[str, str] = ("09:30:00", "11:29:00"), + pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"), +) -> pd.Index: + """generate minutes calendar + + Parameters + ---------- + calendars: Iterable + daily calendar + freq: str + by default 1min + am_range: Tuple[str, str] + AM Time Range, by default China-Stock: ("09:30:00", "11:29:00") + pm_range: Tuple[str, str] + PM Time Range, by default China-Stock: ("13:00:00", "14:59:00") + + """ + daily_format: str = "%Y-%m-%d" + res = [] + for _day in calendars: + for _range in [am_range, pm_range]: + res.append( + pd.date_range( + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}", + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}", + freq=freq, + ) + ) + + return pd.Index(sorted(set(np.hstack(res)))) + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f0e110694..a6e06613e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols +from data_collector.utils import ( + get_calendar_list, + get_hs_stock_symbols, + get_us_stock_symbols, + generate_minutes_calendar_from_daily, +) INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" @@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC): return calendar_list_1d def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index: - res = [] - daily_format = self.DAILY_FORMAT - am_range = self.AM_RANGE - pm_range = self.PM_RANGE - for _day in calendars: - for _range in [am_range, pm_range]: - res.append( - pd.date_range( - f"{_day.strftime(daily_format)} {_range[0]}", - f"{_day.strftime(daily_format)} {_range[1]}", - freq="1min", - ) - ) - - return pd.Index(sorted(set(np.hstack(res)))) + return generate_minutes_calendar_from_daily( + calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE + ) def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor From f58c61a2e0c313074729da6715d30d58e1503e69 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 29 Apr 2021 16:54:51 +0800 Subject: [PATCH 73/77] Fix logger pickling error --- qlib/log.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/qlib/log.py b/qlib/log.py index 5888b3841..1d604e0c0 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -17,6 +17,7 @@ class MetaLogger(type): wrapper_dict = logging.Logger.__dict__.copy() wrapper_dict.update(dict) wrapper_dict["__doc__"] = logging.Logger.__doc__ + del wrapper_dict["__reduce__"] # make Logger object can be pickled return type.__new__(cls, name, bases, wrapper_dict) @@ -29,6 +30,15 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + vars(self).update(state) + + def __reduce__(self): + return (QlibLogger, (self.module_name,)) + @property def logger(self): logger = logging.getLogger(self.module_name) From ca92cb980ca9a49d9c41f98e5f2c2c6941a8a1ae Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 29 Apr 2021 22:40:52 +0800 Subject: [PATCH 74/77] Update meta logger --- qlib/log.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 1d604e0c0..19331f5d5 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -15,10 +15,11 @@ from .config import C class MetaLogger(type): def __new__(cls, name, bases, dict): wrapper_dict = logging.Logger.__dict__.copy() - wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.Logger.__doc__ - del wrapper_dict["__reduce__"] # make Logger object can be pickled - return type.__new__(cls, name, bases, wrapper_dict) + for key in wrapper_dict: + if key not in dict and key != "__reduce__": + dict[key] = wrapper_dict[key] + dict["__doc__"] = logging.Logger.__doc__ + return type.__new__(cls, name, bases, dict) class QlibLogger(metaclass=MetaLogger): From 51b649ec395f4a80e96dd88b51ebdd8d2a192db2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 13:13:05 +0800 Subject: [PATCH 75/77] Update QlibLogger --- qlib/log.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 19331f5d5..d095d571a 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -31,12 +31,6 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 - def __getstate__(self): - return vars(self) - - def __setstate__(self, state): - vars(self).update(state) - def __reduce__(self): return (QlibLogger, (self.module_name,)) @@ -50,6 +44,9 @@ class QlibLogger(metaclass=MetaLogger): self.level = level def __getattr__(self, name): + # During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error. + if name in {"__setstate__"}: + raise AttributeError return self.logger.__getattribute__(name) From 694ae3402766e582a6c067de807a997f1a9719c4 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 13:27:19 +0800 Subject: [PATCH 76/77] Update api --- qlib/workflow/__init__.py | 21 ++++++++++++++++++--- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 4 +++- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 7cb1cf5cb..8135bab60 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -23,6 +23,7 @@ class QlibRecorder: @contextmanager def start( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -63,7 +64,14 @@ class QlibRecorder: resume : bool whether to resume the specific recorder with given name under the given experiment. """ - run = self.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) + run = self.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) try: yield run except Exception as e: @@ -72,7 +80,7 @@ class QlibRecorder: self.end_exp(Recorder.STATUS_FI) def start_exp( - self, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually @@ -105,7 +113,14 @@ class QlibRecorder: ------- An experiment instance being started. """ - return self.exp_manager.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) + return self.exp_manager.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) def end_exp(self, recorder_status=Recorder.STATUS_FI): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 0a7e0a5a9..467c7c3f4 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -39,7 +39,7 @@ class Experiment: output["recorders"] = list(recorders.keys()) return output - def start(self, recorder_id=None, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): """ Start the experiment and set it to be active. This method will also start a new recorder. @@ -240,7 +240,7 @@ class MLflowExperiment(Experiment): def __repr__(self): return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) - def start(self, recorder_id=None, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") # Get or create recorder if recorder_name is None: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5549bb9bf..04cc3bcb7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -33,6 +33,7 @@ class ExpManager: def start_exp( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -304,6 +305,7 @@ class MLflowExpManager(ExpManager): def start_exp( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -320,7 +322,7 @@ class MLflowExpManager(ExpManager): # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_id, recorder_name, resume) + self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume) return self.active_experiment From 5eb9dfff166b79cdd2e00bc0ff7430f266db46b0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 15:28:37 +0800 Subject: [PATCH 77/77] Remove redundant --- qlib/log.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index d095d571a..e714bc15a 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -18,7 +18,6 @@ class MetaLogger(type): for key in wrapper_dict: if key not in dict and key != "__reduce__": dict[key] = wrapper_dict[key] - dict["__doc__"] = logging.Logger.__doc__ return type.__new__(cls, name, bases, dict) @@ -31,9 +30,6 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 - def __reduce__(self): - return (QlibLogger, (self.module_name,)) - @property def logger(self): logger = logging.getLogger(self.module_name)