1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix duplicate log (#1661)

* fix duplicate log

* fix unit test

* fix log

* fix_duplicate_log

* fix_duplicate_log

* add comments

---------

Co-authored-by: Linlang <Lv.Linlang@hotmail.com>
This commit is contained in:
YQ Tsui
2024-12-09 15:45:31 +08:00
committed by GitHub
parent b604fe56b3
commit 431f574967
2 changed files with 22 additions and 13 deletions

View File

@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import logging
import re
from typing import Any, Tuple
@@ -69,6 +69,10 @@ class AnyPolicy(BasePolicy):
def test_simple_env_logger(caplog):
set_log_with_config(C.logging_config)
# In order for caplog to capture log messages, we configure it here:
# allow logs from the qlib logger to be passed to the parent logger.
C.logging_config["loggers"]["qlib"]["propagate"] = True
logging.config.dictConfig(C.logging_config)
for venv_cls_name in ["dummy", "shmem", "subproc"]:
writer = ConsoleWriter()
csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -80,13 +84,12 @@ def test_simple_env_logger(caplog):
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert output_file.columns.tolist() == ["reward", "a", "c"]
assert len(output_file) >= 30
line_counter = 0
for line in caplog.text.splitlines():
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3
@@ -137,15 +140,17 @@ class RandomFivePolicy(BasePolicy):
def test_logger_with_env_wrapper():
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
env_wrapper_factory = lambda: EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)
# loglevel can be debug here because metrics can all dump into csv
def env_wrapper_factory():
return EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)
# loglevel can be debugged here because metrics can all dump into csv
# otherwise, csv writer might crash
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
@@ -155,7 +160,7 @@ def test_logger_with_env_wrapper():
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(output_df) == 20
# obs has a increasing trend
# obs has an increasing trend
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
assert (output_df["test_a"] == 233).all()
assert (output_df["test_b"] == 200).all()