1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Modify cli

This commit is contained in:
Jactus
2020-11-20 10:59:37 +08:00
parent 0f433571f6
commit 547697ddc6
6 changed files with 16 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
# worflow handler function
def workflow(config_path):
def workflow(config_path, experiment_name="workflow"):
with open(config_path) as fp:
config = yaml.load(fp, Loader=yaml.Loader)
@@ -26,12 +26,13 @@ def workflow(config_path):
dataset = init_instance_by_config(config.get("task")["dataset"])
# start exp
with R.start(experiment_name="workflow"):
R.log_paramters(**flatten_dict(task))
with R.start(experiment_name=experiment_name):
# train model
R.log_params(**flatten_dict(config.get("task")))
model.fit(dataset)
recorder = R.get_recorder()
# generate records
# generate records: prediction, backtest, and analysis
for record in config.get("task")["record"]:
if record["class"] == SignalRecord.__name__:
srconf = {"model": model, "dataset": dataset, "recorder": recorder}

View File

@@ -108,9 +108,7 @@ class CheckBin:
return self.COMPARE_ERROR
def check(self):
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data
"""
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
logger.info("start check......")
error_list = []

View File

@@ -55,7 +55,9 @@ class GetData:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
def qlib_data(
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
):
"""download cn qlib data from remote
Parameters

View File

@@ -57,6 +57,7 @@ REQUIRED = [
"tornado",
"joblib>=0.17.0",
"fire>=0.3.1",
"ruamel.yaml>=0.16.12",
]
# Numpy include

View File

@@ -149,7 +149,9 @@ class TestAllFlow(unittest.TestCase):
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def test_0_train(self):

View File

@@ -75,7 +75,9 @@ class TestDumpData(unittest.TestCase):
def test_4_dump_features_simple(self):
stock = self.STOCK_NAMES[0]
dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
dump_data = DumpDataFix(
csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
)
dump_data.dump()
df = D.features([stock], self.QLIB_FIELDS)