mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Modify cli
This commit is contained in:
@@ -14,7 +14,7 @@ from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
# worflow handler function
|
||||
def workflow(config_path):
|
||||
def workflow(config_path, experiment_name="workflow"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
|
||||
@@ -26,12 +26,13 @@ def workflow(config_path):
|
||||
dataset = init_instance_by_config(config.get("task")["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_paramters(**flatten_dict(task))
|
||||
with R.start(experiment_name=experiment_name):
|
||||
# train model
|
||||
R.log_params(**flatten_dict(config.get("task")))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# generate records
|
||||
# generate records: prediction, backtest, and analysis
|
||||
for record in config.get("task")["record"]:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
|
||||
@@ -108,9 +108,7 @@ class CheckBin:
|
||||
return self.COMPARE_ERROR
|
||||
|
||||
def check(self):
|
||||
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data
|
||||
|
||||
"""
|
||||
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
|
||||
logger.info("start check......")
|
||||
|
||||
error_list = []
|
||||
|
||||
@@ -55,7 +55,9 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
|
||||
def qlib_data(
|
||||
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
|
||||
1
setup.py
1
setup.py
@@ -57,6 +57,7 @@ REQUIRED = [
|
||||
"tornado",
|
||||
"joblib>=0.17.0",
|
||||
"fire>=0.3.1",
|
||||
"ruamel.yaml>=0.16.12",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
|
||||
@@ -149,7 +149,9 @@ class TestAllFlow(unittest.TestCase):
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def test_0_train(self):
|
||||
|
||||
@@ -75,7 +75,9 @@ class TestDumpData(unittest.TestCase):
|
||||
|
||||
def test_4_dump_features_simple(self):
|
||||
stock = self.STOCK_NAMES[0]
|
||||
dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
|
||||
dump_data = DumpDataFix(
|
||||
csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
|
||||
)
|
||||
dump_data.dump()
|
||||
|
||||
df = D.features([stock], self.QLIB_FIELDS)
|
||||
|
||||
Reference in New Issue
Block a user