mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Merge branch 'nested_decision_exe' of https://github.com/microsoft/qlib into rl-dummy
This commit is contained in:
@@ -263,11 +263,11 @@ class Account:
|
||||
elif atomic is False and inner_order_indicators is None:
|
||||
raise ValueError("inner_order_indicators is necessary in unatomic executor")
|
||||
|
||||
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
if generate_report:
|
||||
# report is portfolio related analysis
|
||||
# TODO: `update_bar_count` and `update_current` should placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
|
||||
@@ -14,7 +14,7 @@ from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data
|
||||
from ..log import get_module_logger
|
||||
from .order import Order
|
||||
from .order import Order, OrderDir, OrderHelper
|
||||
|
||||
|
||||
class Exchange:
|
||||
@@ -164,10 +164,6 @@ class Exchange:
|
||||
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
# update quote: pd.DataFrame to dict, for search use
|
||||
if get_level_index(quote_df, level="datetime") == 1:
|
||||
quote_df = quote_df.swaplevel().sort_index()
|
||||
|
||||
quote_dict = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val
|
||||
@@ -530,3 +526,9 @@ class Exchange:
|
||||
raise NotImplementedError("order type {} error".format(order.type))
|
||||
|
||||
return trade_val, trade_cost
|
||||
|
||||
def get_order_helper(self) -> OrderHelper:
|
||||
if not hasattr(self, "_order_helper"):
|
||||
# cache to avoid recreate the same instance
|
||||
self._order_helper = OrderHelper(self)
|
||||
return self._order_helper
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
# Licensed under the MIT License.
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import pandas as pd
|
||||
@@ -15,6 +17,12 @@ from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
# Order direction
|
||||
SELL = 0
|
||||
BUY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Order:
|
||||
"""
|
||||
@@ -32,19 +40,98 @@ class Order:
|
||||
|
||||
stock_id: str
|
||||
amount: float
|
||||
|
||||
# The interval of the order which belongs to (NOTE: this is not the expected order dealing range time)
|
||||
start_time: pd.Timestamp
|
||||
end_time: pd.Timestamp
|
||||
|
||||
direction: int
|
||||
factor: float
|
||||
deal_amount: Optional[float] = None
|
||||
SELL: ClassVar[int] = 0
|
||||
BUY: ClassVar[int] = 1
|
||||
deal_amount: float = field(init=False)
|
||||
|
||||
# FIXME:
|
||||
# for compatible now.
|
||||
# Plese remove them in the future
|
||||
SELL: ClassVar[OrderDir] = OrderDir.SELL
|
||||
BUY: ClassVar[OrderDir] = OrderDir.BUY
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
|
||||
@staticmethod
|
||||
def parse_dir(direction: Union[str, int, OrderDir]) -> OrderDir:
|
||||
if isinstance(direction, OrderDir):
|
||||
return direction
|
||||
elif isinstance(direction, int):
|
||||
return OrderDir(direction)
|
||||
elif isinstance(direction, str):
|
||||
dl = direction.lower()
|
||||
if dl.strip() == "sell":
|
||||
return OrderDir.SELL
|
||||
elif dl.strip() == "buy":
|
||||
return OrderDir.BUY
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
|
||||
class OrderHelper:
|
||||
"""
|
||||
Motivation
|
||||
- Make generating order easier
|
||||
- User may have no knowledge about the adjust-factor information about the system.
|
||||
- It involves to much interaction with the exchange when generating orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange: Exchange):
|
||||
self.exchange = exchange
|
||||
|
||||
def create(
|
||||
self,
|
||||
code: str,
|
||||
amount: float,
|
||||
direction: OrderDir,
|
||||
start_time: Union[str, pd.Timestamp],
|
||||
end_time: Union[str, pd.Timestamp],
|
||||
) -> Order:
|
||||
"""
|
||||
help to create a order
|
||||
|
||||
# TODO: create order for unadjusted amount order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code : str
|
||||
the id of the instrument
|
||||
amount : float
|
||||
**adjusted trading amount**
|
||||
direction : OrderDir
|
||||
trading direction
|
||||
start_time : Union[str, pd.Timestamp]
|
||||
The interval of the order which belongs to
|
||||
end_time : Union[str, pd.Timestamp]
|
||||
The interval of the order which belongs to
|
||||
|
||||
Returns
|
||||
-------
|
||||
Order:
|
||||
The created order
|
||||
"""
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.exchange.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
|
||||
@@ -408,7 +408,7 @@ class InfPosition(BasePosition):
|
||||
"""
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
""" Updating state is meaningless for InfPosition """
|
||||
"""Updating state is meaningless for InfPosition"""
|
||||
return True
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from typing import IO, List, Tuple, Union
|
||||
from qlib.data.dataset.utils import convert_index_format
|
||||
|
||||
from qlib.utils import lazy_sort_index
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.exchange import Exchange, OrderHelper
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
from qlib.utils.file import get_io_object
|
||||
|
||||
|
||||
def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
@@ -423,7 +427,6 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
|
||||
@@ -515,7 +518,6 @@ class ACStrategy(BaseStrategy):
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["volatility"]
|
||||
self.signal = {}
|
||||
|
||||
@@ -656,6 +658,9 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
index_range : Tuple
|
||||
the intra day time index range of the orders
|
||||
the left and right is closed.
|
||||
|
||||
If you want to get the index_range in intra-day
|
||||
- `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range easier
|
||||
# TODO: this is a index_range level limitation. We'll implement a more detailed limitation later.
|
||||
sample_ratio : float
|
||||
the ratio of all orders are sampled
|
||||
@@ -687,7 +692,9 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
if step_time_start in self.volume_df:
|
||||
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
|
||||
order_list.append(
|
||||
self.common_infra.get("trade_exchange").create_order(
|
||||
self.common_infra.get("trade_exchange")
|
||||
.get_order_helper()
|
||||
.create(
|
||||
code=stock_id,
|
||||
amount=volume * self.volume_ratio,
|
||||
start_time=step_time_start,
|
||||
@@ -696,3 +703,53 @@ class RandomOrderStrategy(BaseStrategy):
|
||||
)
|
||||
)
|
||||
return TradeDecisionWO(order_list, self, self.index_range)
|
||||
|
||||
|
||||
class FileOrderStrategy(BaseStrategy):
|
||||
"""
|
||||
Motivtaion:
|
||||
- This class provides an interface for user to read orders from csv files.
|
||||
- It is supposed to be used in
|
||||
"""
|
||||
|
||||
def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
with get_io_object(file) as f:
|
||||
self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
|
||||
|
||||
self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
|
||||
self.order_df = self.order_df.set_index(["datetime", "instrument"])
|
||||
|
||||
# make sure the datetime is the first level for fast indexing
|
||||
self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime"))
|
||||
self.index_range = index_range
|
||||
|
||||
def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
execute_result will be ignored in FileOrderStrategy
|
||||
"""
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
tc = self.trade_calendar
|
||||
step = tc.get_trade_step()
|
||||
start, end = tc.get_step_time(step)
|
||||
# CONVERSION: the bar is indexed by the time
|
||||
try:
|
||||
df = self.order_df.loc(axis=0)[start]
|
||||
except KeyError:
|
||||
return TradeDecisionWO([], self)
|
||||
else:
|
||||
order_list = []
|
||||
for idx, row in df.iterrows():
|
||||
order_list.append(
|
||||
oh.create(
|
||||
code=idx,
|
||||
amount=row["amount"],
|
||||
direction=Order.parse_dir(row["direction"]),
|
||||
start_time=start,
|
||||
end_time=end,
|
||||
)
|
||||
)
|
||||
return TradeDecisionWO(order_list, self, self.index_range)
|
||||
|
||||
@@ -17,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import fetch_df_by_index
|
||||
from .utils import fetch_df_by_index, fetch_df_by_col
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -152,14 +152,6 @@ class DataHandler(Serializable):
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
CS_RAW = "__raw" # return raw data with multi-level index column
|
||||
|
||||
def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
|
||||
if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
|
||||
return df
|
||||
elif col_set == self.CS_ALL:
|
||||
return df.droplevel(axis=1, level=0)
|
||||
else:
|
||||
return df.loc(axis=1)[col_set]
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
@@ -183,7 +175,7 @@ class DataHandler(Serializable):
|
||||
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
|
||||
if cal_set == CS_RAW:
|
||||
if col_set == CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
|
||||
- if isinstance(col_set, List[str]):
|
||||
@@ -205,23 +197,34 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
from .storage import HasingStockStorage
|
||||
|
||||
data_storage = self._data
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
else:
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
raise ValueError("proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
df = df.squeeze()
|
||||
data_df = data_df.squeeze()
|
||||
# squeeze index
|
||||
if isinstance(selector, (str, pd.Timestamp)):
|
||||
df = df.reset_index(level=level, drop=True)
|
||||
return df
|
||||
data_df = data_df.reset_index(level=level, drop=True)
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=CS_ALL) -> list:
|
||||
"""
|
||||
@@ -238,7 +241,7 @@ class DataHandler(Serializable):
|
||||
list of column names
|
||||
"""
|
||||
df = self._data.head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
|
||||
@@ -519,14 +522,29 @@ class DataHandlerLP(DataHandler):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
from .storage import HasingStockStorage
|
||||
|
||||
data_storage = self._get_df_by_key(data_key)
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
else:
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
elif isinstance(data_storage, HasingStockStorage):
|
||||
if proc_func is not None:
|
||||
raise ValueError("proc_func is not supported by the HasingStockStorage")
|
||||
data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
return data_df
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
@@ -545,5 +563,5 @@ class DataHandlerLP(DataHandler):
|
||||
list of column names
|
||||
"""
|
||||
df = self._get_df_by_key(data_key).head()
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
@@ -310,3 +310,12 @@ class CSZFillna(Processor):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
|
||||
return df
|
||||
|
||||
|
||||
class HashStockFormat(Processor):
|
||||
"""Process the storage of from df into hasing stock format"""
|
||||
|
||||
def __call__(self, df: pd.DataFrame):
|
||||
from .storage import HasingStockStorage
|
||||
|
||||
return HasingStockStorage.from_df(df)
|
||||
|
||||
107
qlib/data/dataset/storage.py
Normal file
107
qlib/data/dataset/storage.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Tuple, Union, List, Callable
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
|
||||
class BaseHandlerStorage:
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""fetch data from the data storage
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
col_set : Union[str, List[str]]
|
||||
- if isinstance(col_set, str):
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
if col_set == DataHandler.CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
- if isinstance(col_set, List[str]):
|
||||
select several sets of meaningful columns, the returned data has multiple level
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError("fetch is method not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def from_df(df: pd.DataFrame):
|
||||
raise NotImplementedError("from_df method is not implemented!")
|
||||
|
||||
|
||||
class HasingStockStorage(BaseHandlerStorage):
|
||||
def __init__(self, df):
|
||||
self.hash_df = dict()
|
||||
self.stock_level = get_level_index(df, "instrument")
|
||||
for k, v in df.groupby(level="instrument"):
|
||||
self.hash_df[k] = v
|
||||
self.columns = df.columns
|
||||
|
||||
@staticmethod
|
||||
def from_df(df):
|
||||
return HasingStockStorage(df)
|
||||
|
||||
def _fetch_hash_df_by_stock(self, selector, level):
|
||||
stock_selector = slice(None)
|
||||
|
||||
if level is None:
|
||||
if isinstance(selector, tuple) and self.stock_level < len(selector):
|
||||
stock_selector = selector[self.stock_level]
|
||||
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
||||
stock_selector = selector
|
||||
elif level == "instrument" or level == self.stock_level:
|
||||
if isinstance(selector, tuple):
|
||||
stock_selector = selector[0]
|
||||
elif isinstance(selector, (list, str)):
|
||||
stock_selector = selector
|
||||
|
||||
if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
|
||||
raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
|
||||
|
||||
if stock_selector == slice(None):
|
||||
return self.hash_df
|
||||
|
||||
if isinstance(stock_selector, str):
|
||||
stock_selector = [stock_selector]
|
||||
|
||||
select_dict = dict()
|
||||
for each_stock in sorted(stock_selector):
|
||||
if each_stock in self.hash_df:
|
||||
select_dict[each_stock] = self.hash_df[each_stock]
|
||||
return select_dict
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = DataHandler.CS_ALL,
|
||||
fetch_orig: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
|
||||
for _index, stock_df in enumerate(fetch_stock_df_list):
|
||||
fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
|
||||
fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
|
||||
fetch_stock_df_list[_index] = fetch_index_df
|
||||
if len(fetch_stock_df_list) == 0:
|
||||
index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
|
||||
return pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
|
||||
)
|
||||
elif len(fetch_stock_df_list) == 1:
|
||||
return fetch_stock_df_list[0]
|
||||
else:
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Union
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
@@ -72,12 +75,26 @@ def fetch_df_by_index(
|
||||
]
|
||||
|
||||
|
||||
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
|
||||
from .handler import DataHandler
|
||||
|
||||
if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
|
||||
return df
|
||||
elif col_set == DataHandler.CS_ALL:
|
||||
return df.droplevel(axis=1, level=0)
|
||||
else:
|
||||
return df.loc(axis=1)[col_set]
|
||||
|
||||
|
||||
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
|
||||
"""
|
||||
Convert the format of df.MultiIndex according to the following rules:
|
||||
- If `level` is the first level of df.MultiIndex, do nothing
|
||||
- If `level` is the second level of df.MultiIndex, swap the level of index.
|
||||
|
||||
NOTE:
|
||||
the number of levels of df.MultiIndex should be 2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : Union[pd.DataFrame, pd.Series]
|
||||
|
||||
37
qlib/utils/file.py
Normal file
37
qlib/utils/file.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TODO: move file related utils into this module
|
||||
import contextlib
|
||||
from typing import IO, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO:
|
||||
"""
|
||||
providing a easy interface to get an IO object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : Union[IO, str, Path]
|
||||
a object representing the file
|
||||
|
||||
Returns
|
||||
-------
|
||||
IO:
|
||||
a IO-like object
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
"""
|
||||
if isinstance(file, IO):
|
||||
yield file
|
||||
else:
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
if not isinstance(file, Path):
|
||||
raise NotImplementedError(f"This type[{type(file)}] of input is not supported")
|
||||
with file.open(*args, **kwargs) as f:
|
||||
yield f
|
||||
86
tests/backtest/test_file_strategy.py
Normal file
86
tests/backtest/test_file_strategy.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
from qlib.backtest import backtest, order
|
||||
from qlib.tests import TestAutoData
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
class FileStrTest(TestAutoData):
|
||||
|
||||
TEST_INST = "SH600519"
|
||||
|
||||
EXAMPLE_FILE = DIRNAME / "order_example.csv"
|
||||
|
||||
def _gen_orders(self) -> pd.DataFrame:
|
||||
headers = [
|
||||
"datetime",
|
||||
"instrument",
|
||||
"amount",
|
||||
"direction",
|
||||
]
|
||||
orders = [
|
||||
["20200102", self.TEST_INST, "1000", "sell"],
|
||||
["20200103", self.TEST_INST, "1000", "buy"],
|
||||
["20200106", self.TEST_INST, "1000", "sell"],
|
||||
]
|
||||
return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])
|
||||
|
||||
def test_file_str(self):
|
||||
|
||||
orders = self._gen_orders()
|
||||
print(orders)
|
||||
orders.to_csv(self.EXAMPLE_FILE)
|
||||
|
||||
orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
|
||||
|
||||
strategy_config = {
|
||||
"class": "FileOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {"file": self.EXAMPLE_FILE},
|
||||
}
|
||||
|
||||
freq = "day"
|
||||
start_time = "2020-01-01"
|
||||
end_time = "2020-01-16"
|
||||
codes = [self.TEST_INST]
|
||||
|
||||
backtest_config = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"account": 100000000,
|
||||
"benchmark": None, # benchmark is not required here for trading
|
||||
"exchange_kwargs": {
|
||||
"freq": freq,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": codes,
|
||||
},
|
||||
# "pos_type": "InfPosition" # Position with infinitive position
|
||||
}
|
||||
executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq,
|
||||
"generate_report": False,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
|
||||
self.EXAMPLE_FILE.unlink()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
114
tests/test_handler_storage.py
Normal file
114
tests/test_handler_storage.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from qlib.data import D
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
|
||||
class TestHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi300",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
drop_raw=True,
|
||||
):
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"freq": "day",
|
||||
"config": self.get_feature_config(),
|
||||
"swap_level": False,
|
||||
},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
drop_raw=drop_raw,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
|
||||
names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
|
||||
return fields, names
|
||||
|
||||
|
||||
class TestHandlerStorage(TestAutoData):
|
||||
|
||||
market = "all"
|
||||
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2020-12-31"
|
||||
train_end_time = "2015-12-31"
|
||||
test_start_time = "2016-01-01"
|
||||
|
||||
data_handler_kwargs = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
def test_handler_storage(self):
|
||||
# init data handler
|
||||
data_handler = TestHandler(**self.data_handler_kwargs)
|
||||
|
||||
# init data handler with hasing storage
|
||||
data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"])
|
||||
|
||||
fetch_start_time = "2019-01-01"
|
||||
fetch_end_time = "2019-12-31"
|
||||
instruments = D.instruments(market=self.market)
|
||||
instruments = D.list_instruments(
|
||||
instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True
|
||||
)
|
||||
|
||||
with TimeInspector.logt("random fetch with DataFrame Storage"):
|
||||
|
||||
# single stock
|
||||
for i in range(100):
|
||||
random_index = np.random.randint(len(instruments), size=1)[0]
|
||||
fetch_stock = instruments[random_index]
|
||||
data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
|
||||
|
||||
# multi stocks
|
||||
for i in range(100):
|
||||
random_indexs = np.random.randint(len(instruments), size=5)
|
||||
fetch_stocks = [instruments[_index] for _index in random_indexs]
|
||||
data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
|
||||
|
||||
with TimeInspector.logt("random fetch with HasingStock Storage"):
|
||||
|
||||
# single stock
|
||||
for i in range(100):
|
||||
random_index = np.random.randint(len(instruments), size=1)[0]
|
||||
fetch_stock = instruments[random_index]
|
||||
data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
|
||||
|
||||
# multi stocks
|
||||
for i in range(100):
|
||||
random_indexs = np.random.randint(len(instruments), size=5)
|
||||
fetch_stocks = [instruments[_index] for _index in random_indexs]
|
||||
data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user