1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Merge branch 'online_srv_wd' into online_srv

This commit is contained in:
Young
2021-03-19 07:49:16 +00:00
5 changed files with 16 additions and 7 deletions

View File

@@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
init(default_conf, **config)
def get_project_path(config_name="config.yaml") -> Path:
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
If users are building a project that follows the pattern below.
- Qlib is a sub folder in project path
@@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path:
FileNotFoundError:
If project path is not found
"""
cur_path = Path(__file__).absolute().resolve()
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
@@ -199,7 +200,7 @@ def auto_init(**kwargs):
"""
try:
pp = get_project_path()
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:

View File

@@ -50,6 +50,9 @@ class DataHandler(Serializable):
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Tips for improving the performance of datahandler
- Fetching data with `col_set=CS_RAW` will return the raw data and may prevent pandas from copying the data when calling `loc`
"""
def __init__(
@@ -257,6 +260,10 @@ class DataHandler(Serializable):
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
Tips for improving the performance of the data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the raw data in place;
"""
# data key

View File

@@ -36,7 +36,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.save_objects(task=task_config) # keep the original format and datatype
artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://"
dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"])

View File

@@ -1,5 +1,6 @@
from qlib.workflow import R
import pandas as pd
from tqdm.auto import tqdm
from typing import Union
from typing import Callable
@@ -113,4 +114,4 @@ class RollingCollector(TaskCollector):
pred = pred.sort_index()
reduce_group[k] = pred
return reduce_group
return reduce_group

View File

@@ -96,8 +96,8 @@ class ModelUpdater:
.. code-block:: python
def rec_filter_func(recorder):
task_config = recorder.load_object("task")
def record_filter(record):
task_config = record.load_object("task")
if task_config["model"]["class"]=="LGBModel":
return True
return False