mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
136 lines
5.3 KiB
Python
136 lines
5.3 KiB
Python
import re
|
|
import os
|
|
import argparse
|
|
import yaml
|
|
from executor import Executor
|
|
import warnings
|
|
import redis
|
|
import subprocess
|
|
|
|
# Suppress every warning category in this interpreter (child processes are
# not affected); presumably meant to quiet library noise in training logs.
warnings.simplefilter("ignore")
|
|
|
|
from util import merge_dicts
|
|
|
|
# PyYAML's built-in (YAML 1.1) float resolver only accepts floats that contain
# a decimal point and, when an exponent is present, an explicitly signed one,
# so values like ``1e-3`` or ``1.0e5`` would otherwise load as strings.
# This pattern widens it: for numbers with a leading digit the exponent sign is
# optional, and an exponent without a decimal point is accepted.
_FLOAT_PATTERN = re.compile(
    r"""^(?:
     [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
    |[-+]?\.(?:inf|Inf|INF)
    |\.(?:nan|NaN|NAN))$""",
    re.X,
)

# ``loader`` is an alias of ``yaml.FullLoader`` (not a subclass), so the
# resolver registered here applies to every ``yaml.FullLoader`` use in this
# process. The last argument lists the characters a float scalar may start with.
loader = yaml.FullLoader
loader.add_implicit_resolver("tag:yaml.org,2002:float", _FLOAT_PATTERN, list("-+0123456789."))
|
|
|
|
|
|
def get_full_config(config, dir_name):
    """Resolve the ``base`` inheritance chain of a config dict.

    While ``config`` contains a ``base`` key, that key is popped and treated as
    the path of another YAML file, relative to ``dir_name``. The file is loaded
    and combined with the current config via ``merge_dicts(base_config, config)``.
    If the merged result still has a ``base`` key (presumably carried over from
    the loaded file -- depends on ``merge_dicts``), it is resolved relative to
    that file's own directory, so chains of bases are followed.

    Args:
        config: Parsed config dict. Its ``base`` key is removed in place.
        dir_name: Directory against which the first ``base`` path is resolved.

    Returns:
        The merged config dict (``config`` itself when it has no ``base``).
    """
    while "base" in config:
        base_path = os.path.normpath(os.path.join(dir_name, config.pop("base")))
        # A nested ``base`` is relative to the file that declares it.
        dir_name = os.path.dirname(base_path)
        with open(base_path, "r") as f:
            # Parse the opened file. Passing the path string to yaml.load would
            # just return that string, not the base config's contents.
            # ``loader`` is the FullLoader carrying the extended float resolver.
            base_config = yaml.load(f, Loader=loader)
        config = merge_dicts(base_config, config)
    return config
|
|
|
|
|
|
def run(config):
    """Persist ``config`` and execute the task it describes.

    The config is dumped to ``<log_dir>/config.yml`` (``log_dir`` is created if
    needed). Then an ``Executor`` is built with every config entry as a keyword
    argument, and ``config["task"]`` selects the action:

    * ``"train"``: ``executor.train(**config["optim"])``
    * ``"eval"``: ``executor.eval`` on ``config["test_paths"]["order_dir"]``,
      with ``save_res=True`` and ``logdir=<log_dir>/test/``.

    Args:
        config: Fully resolved config dict; must contain ``log_dir`` and ``task``.

    Returns:
        Whatever the selected ``Executor`` method returns.

    Raises:
        NotImplementedError: If ``config["task"]`` is neither ``"train"`` nor ``"eval"``.
    """
    log_dir = config["log_dir"]
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail if
    # another process created the directory between the two calls.
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, "config.yml"), "w") as f:
        yaml.dump(config, f)

    executor = Executor(**config)
    task = config["task"]
    if task == "train":
        return executor.train(**config["optim"])
    if task == "eval":
        # Kept as plain concatenation with a trailing "/": Executor.eval may
        # rely on that exact form (unverified from here).
        return executor.eval(config["test_paths"]["order_dir"], save_res=True, logdir=log_dir + "/test/")
    raise NotImplementedError(f"Unsupported task: {task!r}")
|
|
|
|
|
|
def _log_prefix():
    """Return the directory that relative ``log_dir`` values are joined onto.

    Uses ``OUTPUT_DIR`` from the environment when set, otherwise ``"../log"``.
    """
    return os.environ.get("OUTPUT_DIR", "../log")


def _load_config_list(config_dir):
    """Load every YAML document in ``<config_dir>/configs.yml`` as a list (one per trial)."""
    with open(os.path.join(config_dir, "configs.yml")) as f:
        return list(yaml.load_all(f, Loader=loader))


def _run_indexed_trial(config_dir, index):
    """Run trial ``index`` of ``<config_dir>/configs.yml`` in the current process.

    ``log_dir`` is taken from ``PT_OUTPUT_DIR`` when that variable is set;
    otherwise the config's own ``log_dir`` is placed under ``_log_prefix()``.
    """
    config = _load_config_list(config_dir)[index]
    # Resolve "base" inheritance before touching log_dir (as _run_config_file
    # does) so a log_dir defined only in a base file is still found.
    config = get_full_config(config, config_dir)
    if "PT_OUTPUT_DIR" in os.environ:
        config["log_dir"] = os.environ["PT_OUTPUT_DIR"]
    else:
        config["log_dir"] = os.path.join(_log_prefix(), config["log_dir"])
    run(config)


def _run_trial_queue(config_dir, exp_name, config_arg):
    """Drain a Redis-backed queue of trial indices, running each in a child process.

    The list key ``exp_name`` holds the indices still to run, and the string key
    ``f"{exp_name}_{i}"`` holds the status of trial ``i`` ("Pending", "Running",
    "Finished" or "Failed"). Each trial is executed by re-invoking this script
    with ``--config config_arg --index i``.
    """
    import sys

    redis_server = redis.Redis(
        host=os.environ["REDIS_SERVER"],
        port=int(os.environ["REDIS_PORT"]),
        db=0,
        encoding="utf-8",  # ``charset`` is the deprecated alias of ``encoding``
        decode_responses=True,
    )
    config_num = len(_load_config_list(config_dir))

    # Redis deletes a list once its last element is popped, so "key missing" and
    # "queue empty" are the same state. When the queue is empty, (re)queue only
    # trials that were never scheduled or that failed; finished trials and
    # trials still running elsewhere are left alone. For a brand-new experiment
    # no status keys exist, so every trial is queued.
    # NOTE(review): not atomic -- workers starting at the same moment may both enqueue.
    if redis_server.llen(exp_name) == 0:
        for i in range(config_num):
            if redis_server.get(f"{exp_name}_{i}") in (None, "Failed"):
                redis_server.rpush(exp_name, i)
                redis_server.set(f"{exp_name}_{i}", "Pending")

    print(f"Starting..., {redis_server.llen(exp_name)} trails to run")
    # Re-invoke with this interpreter and this script's absolute path rather
    # than whatever "python" / "main.py" resolve to via PATH and the CWD.
    script = os.path.abspath(__file__)
    while True:
        index = redis_server.lpop(exp_name)
        if index is None:
            print("All done")
            break
        index = int(index)
        status_key = f"{exp_name}_{index}"
        redis_server.set(status_key, "Running")
        print(f"Trail_{index} is running")
        try:
            res = subprocess.run([sys.executable, script, "--config", config_arg, "--index", str(index)])
        except KeyboardInterrupt:
            redis_server.set(status_key, "Failed")
            print(f"Trail_{index} has failed, {redis_server.llen(exp_name)} trails to run")
            break
        if res.returncode == 0:
            redis_server.set(status_key, "Finished")
            print(f"Finish running one trail, {redis_server.llen(exp_name)} trails to run")
        else:
            redis_server.set(status_key, "Failed")
            print(f"Trail_{index} has failed, {redis_server.llen(exp_name)} trails to run")


def _run_config_file(config_path):
    """Load a single YAML config file, resolve its bases and ``log_dir``, and run it."""
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=loader)
    config = get_full_config(config, os.path.dirname(config_path))
    config["log_dir"] = os.path.join(_log_prefix(), config["log_dir"])
    run(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="config file or config directory, relative to $EXP_PATH",
    )
    parser.add_argument(
        "-n",
        "--index",
        type=int,
        default=None,
        help="index of the trial in <config dir>/configs.yml to run in this process",
    )
    args = parser.parse_args()

    print(os.cpu_count())

    EXP_PATH = os.environ["EXP_PATH"]
    config_path = os.path.normpath(os.path.join(EXP_PATH, args.config))
    EXP_NAME = os.path.relpath(config_path, EXP_PATH)

    if os.path.isdir(config_path):
        if args.index is not None:
            _run_indexed_trial(config_path, args.index)
        else:
            _run_trial_queue(config_path, EXP_NAME, args.config)
    elif os.path.isfile(config_path):
        # Explicit check instead of ``assert``, which ``python -O`` strips.
        if not config_path.endswith(".yml"):
            parser.error("Config file should be an yaml file")
        _run_config_file(config_path)
    else:
        print("The config path should be a relative path from EXP_PATH")
|