mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update run_all_model and black format
This commit is contained in:
@@ -23,7 +23,6 @@ from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
@@ -40,6 +39,7 @@ exp_manager = {
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
@@ -92,7 +92,8 @@ def create_env():
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd):
|
||||
def execute(cmd, wait_when_err=False):
|
||||
print("Running CMD:", cmd)
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
@@ -102,6 +103,8 @@ def execute(cmd):
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
if wait_when_err:
|
||||
input("Press Enter to Continue")
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
@@ -184,7 +187,15 @@ def gen_and_save_md_table(metrics, dataset):
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
|
||||
@@ -200,6 +211,13 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
Usage:
|
||||
-------
|
||||
@@ -240,32 +258,36 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/microsoft/qlib#egg=pyqlib"
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
@@ -274,6 +296,8 @@ def run(times=1, models=None, dataset="Alpha360", exclude=False):
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
|
||||
@@ -24,6 +24,7 @@ from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
@@ -150,8 +151,8 @@ class LocalformerModel(Model):
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i: i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i: i + self.batch_size]]).float().to(self.device)
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
|
||||
@@ -154,6 +154,7 @@ class LocalformerModel(Model):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -23,6 +23,7 @@ from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
|
||||
|
||||
|
||||
@@ -149,8 +150,8 @@ class TransformerModel(Model):
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i: i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i: i + self.batch_size]]).float().to(self.device)
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(feature)
|
||||
|
||||
Reference in New Issue
Block a user