mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix
This commit is contained in:
@@ -28,7 +28,8 @@ from qlib.utils import exists_qlib_data
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_path = str(Path(os.getcwd()).resolve() / "run_all_model_records")
|
||||
exp_folder_name = "run_all_model_records"
|
||||
exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
|
||||
exp_manager = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -253,7 +254,7 @@ def run(times=1, models=None, exclude=False):
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} '{exp_manager}'"
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import sys, os
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pandas as pd
|
||||
import ruamel.yaml as yaml
|
||||
from qlib.config import C
|
||||
from qlib.model.trainer import task_train
|
||||
|
||||
|
||||
@@ -41,7 +42,7 @@ def sys_config(config, config_path):
|
||||
|
||||
|
||||
# worflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", exp_manager=None):
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
|
||||
@@ -50,10 +51,9 @@ def workflow(config_path, experiment_name="workflow", exp_manager=None):
|
||||
|
||||
provider_uri = config.get("provider_uri")
|
||||
region = config.get("region")
|
||||
if exp_manager:
|
||||
qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager)
|
||||
else:
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
exp_manager = C["exp_manager"]
|
||||
exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder
|
||||
qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager)
|
||||
|
||||
task_train(config, experiment_name=experiment_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user