mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Merge branch 'online_srv_wd' into online_srv
This commit is contained in:
@@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml") -> Path:
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
@@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path:
|
||||
FileNotFoundError:
|
||||
If project path is not found
|
||||
"""
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
@@ -199,7 +200,7 @@ def auto_init(**kwargs):
|
||||
"""
|
||||
|
||||
try:
|
||||
pp = get_project_path()
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
@@ -50,6 +50,9 @@ class DataHandler(Serializable):
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Tips for improving the performance of datahandler
|
||||
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -257,6 +260,10 @@ class DataHandler(Serializable):
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
Tips to improving the performance of data handler
|
||||
- To reduce the memory cost
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
"""
|
||||
|
||||
# data key
|
||||
|
||||
@@ -36,7 +36,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.save_objects(task=task_config) # keep the original format and datatype
|
||||
|
||||
artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://"
|
||||
dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from qlib.workflow import R
|
||||
import pandas as pd
|
||||
from tqdm.auto import tqdm
|
||||
from typing import Union
|
||||
from typing import Callable
|
||||
|
||||
@@ -113,4 +114,4 @@ class RollingCollector(TaskCollector):
|
||||
pred = pred.sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
return reduce_group
|
||||
return reduce_group
|
||||
|
||||
@@ -96,8 +96,8 @@ class ModelUpdater:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def rec_filter_func(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
def record_filter(record):
|
||||
task_config = record.load_object("task")
|
||||
if task_config["model"]["class"]=="LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user