1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add some misc features. (#1816)

* Normal mod

* Black linting

* Linting
This commit is contained in:
you-n-g
2024-06-26 18:34:00 +08:00
committed by GitHub
parent cde80206e4
commit 5190332c7e
15 changed files with 290 additions and 76 deletions

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -35,6 +36,10 @@ class DDGDABench(DDGDA):
if __name__ == "__main__":
GetData().qlib_data(exists_skip=True)
auto_init()
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
fire.Fire(DDGDABench)

View File

@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
@@ -31,6 +32,10 @@ class RollingBenchmark(Rolling):
if __name__ == "__main__":
GetData().qlib_data(exists_skip=True)
auto_init()
kwargs = {}
if os.environ.get("PROVIDER_URI", "") == "":
GetData().qlib_data(exists_skip=True)
else:
kwargs["provider_uri"] = os.environ["PROVIDER_URI"]
auto_init(**kwargs)
fire.Fire(RollingBenchmark)

View File

@@ -243,7 +243,7 @@ class MetaDatasetDS(MetaTaskDataset):
trunc_days: int = None,
rolling_ext_days: int = 0,
exp_name: Union[str, InternalData],
segments: Union[Dict[Text, Tuple], float],
segments: Union[Dict[Text, Tuple], float, str],
hist_step_n: int = 10,
task_mode: str = MetaTask.PROC_MODE_FULL,
fill_method: str = "max",
@@ -271,12 +271,16 @@ class MetaDatasetDS(MetaTaskDataset):
- str: the name of the experiment to store the performance of data
- InternalData: a prepared internal data
segments: Union[Dict[Text, Tuple], float, str]
the segments to divide data
both left and right
if the segment is a Dict
the segments to divide data
both left and right are included
if segments is a float:
the float represents the percentage of data for training
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
hist_step_n: int
length of historical steps for the meta infomation
Number of steps of the data similarity information
task_mode : str
Please refer to the docs of MetaTask
"""
@@ -383,10 +387,30 @@ class MetaDatasetDS(MetaTaskDataset):
if isinstance(self.segments, float):
train_task_n = int(len(self.meta_task_l) * self.segments)
if segment == "train":
return self.meta_task_l[:train_task_n]
train_tasks = self.meta_task_l[:train_task_n]
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
return train_tasks
elif segment == "test":
return self.meta_task_l[train_task_n:]
test_tasks = self.meta_task_l[train_task_n:]
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
return test_tasks
else:
raise NotImplementedError(f"This type of input is not supported")
elif isinstance(self.segments, str):
train_tasks = []
test_tasks = []
for t in self.meta_task_l:
test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1]
if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments):
train_tasks.append(t)
else:
test_tasks.append(t)
get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}")
get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}")
if segment == "train":
return train_tasks
elif segment == "test":
return test_tasks
raise NotImplementedError(f"This type of input is not supported")
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -53,7 +53,12 @@ class MetaModelDS(MetaTaskModel):
max_epoch=100,
seed=43,
alpha=0.0,
loss_skip_thresh=50,
):
"""
loss_skip_thresh: int
The threshold for skipping the loss calculation of a day: if the number of items on a day is less than it, that day is skipped.
"""
self.step = step
self.hist_step_n = hist_step_n
self.clip_method = clip_method
@@ -63,6 +68,7 @@ class MetaModelDS(MetaTaskModel):
self.max_epoch = max_epoch
self.fitted = False
self.alpha = alpha
self.loss_skip_thresh = loss_skip_thresh
torch.manual_seed(seed)
def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
@@ -88,12 +94,14 @@ class MetaModelDS(MetaTaskModel):
criterion = nn.MSELoss()
loss = criterion(pred, meta_input["y_test"])
elif self.criterion == "ic_loss":
criterion = ICLoss()
criterion = ICLoss(self.loss_skip_thresh)
try:
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])
except ValueError as e:
get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss")
continue
else:
raise ValueError(f"Unknown criterion: {self.criterion}")
assert not np.isnan(loss.detach().item()), "NaN loss!"

View File

@@ -10,7 +10,11 @@ from qlib.log import get_module_logger
class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50):
def __init__(self, skip_size=50):
super().__init__()
self.skip_size = skip_size
def forward(self, pred, y, idx):
"""forward.
FIXME:
- Sometimes it will be slightly different from the result of `pandas.corr()`
@@ -33,7 +37,7 @@ class ICLoss(nn.Module):
skip_n = 0
for start_i, end_i in zip(diff_point, diff_point[1:]):
pred_focus = pred[start_i:end_i] # TODO: just for fake
if pred_focus.shape[0] < skip_size:
if pred_focus.shape[0] < self.skip_size:
# skip some days which have very small amount of stock.
skip_n += 1
continue
@@ -50,6 +54,7 @@ class ICLoss(nn.Module):
)
ic_all += ic_day
if len(diff_point) - 1 - skip_n <= 0:
__import__("ipdb").set_trace()
raise ValueError("No enough data for calculating IC")
if skip_n > 0:
get_module_logger("ICLoss").info(

View File

@@ -63,6 +63,7 @@ class LinearModel(Model):
df_train = pd.concat([df_train, df_valid])
except KeyError:
get_module_logger("LinearModel").info("include_valid=True, but valid does not exist")
df_train = df_train.dropna()
if df_train.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
if reweighter is not None:

View File

@@ -1,25 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import copy
from typing import Text, Union
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from qlib.workflow import R
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...log import get_module_logger
from ...model.base import Model
from ...utils import get_or_create_path
from .pytorch_utils import count_parameters
class GRU(Model):
@@ -212,16 +212,31 @@ class GRU(Model):
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
if df_train.empty or df_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
# prepare training and validation data
dfs = {
k: dataset.prepare(
k,
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
for k in ["train", "valid"]
if k in dataset.segments
}
df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame())
# check if training data is empty
if df_train.empty:
raise ValueError("Empty training data from dataset, please check your dataset config.")
df_train = df_train.dropna()
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# check if validation data is provided
if not df_valid.empty:
df_valid = df_valid.dropna()
x_valid, y_valid = df_valid["feature"], df_valid["label"]
else:
x_valid, y_valid = None, None
save_path = get_or_create_path(save_path)
stop_steps = 0
@@ -235,32 +250,42 @@ class GRU(Model):
self.logger.info("training...")
self.fitted = True
best_param = copy.deepcopy(self.gru_model.state_dict())
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
# evaluate on validation data if provided
if x_valid is not None and y_valid is not None:
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.gru_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.gru_model.load_state_dict(best_param)
torch.save(best_param, save_path)
# Logging
rec = R.get_recorder()
for k, v_l in evals_result.items():
for i, v in enumerate(v_l):
rec.log_metrics(step=i, **{k: v})
if self.use_gpu:
torch.cuda.empty_cache()
@@ -292,6 +317,7 @@ class GRU(Model):
class GRUModel(nn.Module):
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
super().__init__()

View File

@@ -1,5 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Here we have a comprehensive set of analysis classes.
Here is an example.
.. code-block:: python
from qlib.contrib.report.data.ana import FeaMeanStd
fa = FeaMeanStd(ret_df)
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
"""
import pandas as pd
import numpy as np
from qlib.contrib.report.data.base import FeaAnalyser
@@ -152,6 +164,7 @@ class FeaSkewTurt(NumFeaAnalyser):
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("kurt")
right_ax.grid(None) # set the grid to None to avoid two layer of grid
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()
@@ -171,12 +184,15 @@ class FeaMeanStd(NumFeaAnalyser):
ax.set_xlabel("")
ax.set_ylabel("mean")
ax.legend()
ax.tick_params(axis="x", rotation=90)
right_ax = ax.twinx()
self._std[col].plot(ax=right_ax, label="std", color="green")
right_ax.set_xlabel("")
right_ax.set_ylabel("std")
right_ax.tick_params(axis="x", rotation=90)
right_ax.grid(None) # set the grid to None to avoid two layer of grid
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = right_ax.get_legend_handles_labels()

View File

@@ -14,6 +14,24 @@ from qlib.contrib.report.utils import sub_fig_generator
class FeaAnalyser:
def __init__(self, dataset: pd.DataFrame):
"""
Parameters
----------
dataset : pd.DataFrame
We often have multiple columns for dataset. Each column corresponds to one sub figure.
There will be a datetime column in the index levels.
Aggregation will be used for more summarized metrics over time.
Here is an example of data:
.. code-block::
return
datetime instrument
2007-02-06 equity_tpx 0.010087
equity_spx 0.000786
"""
self._dataset = dataset
with TimeInspector.logt("calc_stat_values"):
self.calc_stat_values()

View File

@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import pandas as pd
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
"""sub_fig_generator.
it will return a generator, each row contains <col_n> sub graph
@@ -13,7 +13,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
Parameters
----------
sub_fs :
sub_figsize :
the figure size of each subgraph in <col_n> * <row_n> subgraphs
col_n :
the number of subgraphs in each row; it will generate a new graph after generating <col_n> subgraphs.
@@ -33,7 +33,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
while True:
fig, axes = plt.subplots(
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey
)
plt.subplots_adjust(wspace=wspace, hspace=hspace)
axes = axes.reshape(row_n, col_n)

View File

@@ -73,8 +73,8 @@ class Rolling:
The horizon of the prediction target.
This is used to override the prediction horizon of the file.
h_path : Optional[str]
the dumped data handler;
It may come from other data source. It will override the data handler in the config.
It is other data source that is dumped as a handler. It will override the data handler section in the config.
If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True`
test_end : Optional[str]
the test end for the data. It is typically used together with the handler
You can do the same thing with task_ext_conf in a more complicated way
@@ -119,7 +119,7 @@ class Rolling:
with self.conf_path.open("r") as f:
return yaml.safe_load(f)
def _replace_hanler_with_cache(self, task: dict):
def _replace_handler_with_cache(self, task: dict):
"""
Due to the data processing part in original rolling is slow. So we have to
This class tries to add more feature
@@ -159,13 +159,20 @@ class Rolling:
# - get horizon automatically from the expression!!!!
raise NotImplementedError(f"This type of input is not supported")
else:
self.logger.info("The prediction horizon is overrided")
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
if enable_handler_cache and self.h_path is not None:
self.logger.info("Fail to override the horizon due to data handler cache")
else:
self.logger.info("The prediction horizon is overrided")
if isinstance(task["dataset"]["kwargs"]["handler"], dict):
task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [
"Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1)
]
else:
self.logger.warning("Try to automatically configure the lablel but failed.")
if enable_handler_cache:
task = self._replace_hanler_with_cache(task)
if self.h_path is not None or enable_handler_cache:
# if we already have provided data source or we want to create one
task = self._replace_handler_with_cache(task)
task = self._update_start_end_time(task)
if self.task_ext_conf is not None:
@@ -173,6 +180,16 @@ class Rolling:
self.logger.info(task)
return task
def run_basic_task(self):
    """
    Train the basic task once, without any rolling.

    This is meant for fast testing such as model tuning: the task returned by
    `basic_task` is printed and then handed, as a single-element list, to a
    `TrainerR` created with `self.exp_name` as its experiment name.
    """
    basic = self.basic_task()
    print(basic)
    # The trainer works on a list of tasks, so wrap the single task.
    TrainerR(experiment_name=self.exp_name)([basic])
def get_task_list(self) -> List[dict]:
"""return a batch of tasks for rolling."""
task = self.basic_task()

View File

@@ -80,6 +80,11 @@ class DDGDA(Rolling):
sim_task_model: UTIL_MODEL_TYPE = "gbdt",
meta_1st_train_end: Optional[str] = None,
alpha: float = 0.01,
loss_skip_thresh: int = 50,
fea_imp_n: Optional[int] = 30,
meta_data_proc: Optional[str] = "V01",
segments: Union[float, str] = 0.62,
hist_step_n: int = 30,
working_dir: Optional[Union[str, Path]] = None,
**kwargs,
):
@@ -94,6 +99,15 @@ class DDGDA(Rolling):
alpha: float
Setting the L2 regularization for ridge
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..)
loss_skip_thresh: int
The thresh to skip the loss calculation for each day. If the number of item is less than it, it will skip the loss on that day.
meta_data_proc : Optional[str]
How we process the meta dataset for learning meta model.
segments : Union[float, str]
if segments is a float:
The ratio of training data in the meta task dataset
if segments is a string:
it will try its best to put its data in training and ensure that the date `segments` is in the test set
"""
# NOTE:
# the horizon must match the meaning in the base task template
@@ -104,14 +118,22 @@ class DDGDA(Rolling):
super().__init__(**kwargs)
self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir)
self.proxy_hd = self.working_dir / "handler_proxy.pkl"
self.fea_imp_n = fea_imp_n
self.meta_data_proc = meta_data_proc
self.loss_skip_thresh = loss_skip_thresh
self.segments = segments
self.hist_step_n = hist_step_n
def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE):
"""
Some tasks are used for special purposes.
Based on the original task, we need to do some extra things.
For example:
- GBDT for calculating feature importance
- Linear or GBDT for calculating similarity
- Dataset (well processed) that is aligned to Linear for meta learning
So we may need to change the dataset and model for the special purpose while the other settings remain the same.
"""
# NOTE: here is just for aligning with previous implementation
# It is not necessary for the current implementation
@@ -119,12 +141,16 @@ class DDGDA(Rolling):
if astype == "gbdt":
task["model"] = LGBM_MODEL
if isinstance(handler, dict):
# We don't need preprocessing when using GBDT model
for k in ["infer_processors", "learn_processors"]:
if k in handler.setdefault("kwargs", {}):
handler["kwargs"].pop(k)
elif astype == "linear":
task["model"] = LINEAR_MODEL
handler["kwargs"].update(PROC_ARGS)
if isinstance(handler, dict):
handler["kwargs"].update(PROC_ARGS)
else:
self.logger.warning("The handler can't be adjusted.")
else:
raise ValueError(f"astype not supported: {astype}")
return task
@@ -155,12 +181,15 @@ class DDGDA(Rolling):
The meta model will be trained upon the proxy forecasting model.
This dataset is for the proxy forecasting model.
"""
topk = 30
fi = self._get_feature_importance()
col_selected = fi.nlargest(topk)
# NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation.
# In previous version. The data for proxy model is using sim_task_model's way for processing
task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model)
task = replace_task_handler_with_cache(task, self.working_dir)
# if self.meta_data_proc is not None:
# else:
# # Otherwise, we don't need futher processing
# task = self.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -168,12 +197,18 @@ class DDGDA(Rolling):
feature_df = prep_ds["feature"]
label_df = prep_ds["label"]
feature_selected = feature_df.loc[:, col_selected.index]
if self.fea_imp_n is not None:
fi = self._get_feature_importance()
col_selected = fi.nlargest(self.fea_imp_n)
feature_selected = feature_df.loc[:, col_selected.index]
else:
feature_selected = feature_df
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
if self.meta_data_proc == "V01":
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
lambda df: (df - df.mean()).div(df.std())
)
feature_selected = feature_selected.fillna(0.0)
df_all = {
"label": label_df.reindex(feature_selected.index),
@@ -223,7 +258,10 @@ class DDGDA(Rolling):
# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
# the train_start for training meta model does not necessarily align with final rolling
# NOTE:
# - The train_start for training meta model does not necessarily align with final rolling
# But please select a proper time to make sure the final rolling tasks are not leaked into the training data.
# - The test_start is automatically aligned to the next day of test_end. Validation is ignored.
train_start = "2008-01-01" if self.train_start is None else self.train_start
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
@@ -249,9 +287,9 @@ class DDGDA(Rolling):
kwargs = dict(
task_tpl=proxy_forecast_model_task,
step=self.step,
segments=0.62, # keep test period consistent with the dataset yaml
segments=self.segments, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=30,
hist_step_n=self.hist_step_n,
fill_method=fill_method,
rolling_ext_days=0,
)
@@ -268,7 +306,13 @@ class DDGDA(Rolling):
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
step=self.step,
hist_step_n=kwargs["hist_step_n"],
lr=0.001,
max_epoch=30,
seed=43,
alpha=self.alpha,
loss_skip_thresh=self.loss_skip_thresh,
)
mm.fit(md)
R.save_objects(model=mm)

View File

@@ -51,3 +51,6 @@ class MetaTask:
Return the **processed** meta_info
"""
return self.meta_info
def __repr__(self):
    """Return a debugging string built from the `task` and `meta_info` attributes."""
    return "MetaTask(task={}, meta_info={})".format(self.task, self.meta_info)

View File

@@ -161,7 +161,13 @@ def init_instance_by_config(
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc
# To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data
path = pr.path
if pr.netloc != "":
path = path.lstrip("/")
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
with open(os.path.normpath(pr_path), "rb") as f:
return pickle.load(f)
else:

View File

@@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import sys
import os
from pathlib import Path
import sys
import fire
from jinja2 import Template, meta
import ruamel.yaml as yaml
import qlib
import fire
import ruamel.yaml as yaml
from qlib.config import C
from qlib.model.trainer import task_train
from qlib.utils.data import update_config
from qlib.log import get_module_logger
from qlib.model.trainer import task_train
from qlib.utils import set_log_with_config
from qlib.utils.data import update_config
set_log_with_config(C.logging_config)
logger = get_module_logger("qrun", logging.INFO)
@@ -47,6 +49,39 @@ def sys_config(config, config_path):
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
def render_template(config_path: str) -> str:
    """
    Render a configuration template with values taken from the environment.

    The file content is treated as a Jinja2 template. Among the variables the
    template uses without declaring them, only those that exist in
    ``os.environ`` are put into the rendering context; the others are left
    out of the context.

    Parameters
    ----------
    config_path : str
        configuration path

    Returns
    -------
    str
        the rendered content
    """
    with open(config_path) as fp:
        raw_text = fp.read()

    # Build the template first; its environment is reused for parsing below.
    tpl = Template(raw_text)

    # Collect the names that the template refers to but never declares.
    undeclared = meta.find_undeclared_variables(tpl.environment.parse(raw_text))

    # Only names that are actually set in the environment go into the context.
    context = {}
    for name in undeclared:
        if name in os.environ:
            context[name] = os.getenv(name, "")
    logger.info(f"Render the template with the context: {context}")

    return tpl.render(context)
# workflow handler function
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
"""
@@ -67,8 +102,9 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
market: csi300
"""
with open(config_path) as fp:
config = yaml.safe_load(fp)
# Render the template
rendered_yaml = render_template(config_path)
config = yaml.safe_load(rendered_yaml)
base_config_path = config.get("BASE_CONFIG_PATH", None)
if base_config_path: