1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00
Files
qlib/tests/rl/test_data_queue.py
Yuge Zhang 9a40fd3cdc Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init

* aux info

* Reward config

* update

* simple

* update saoe init

* update simulator and seed

* minor

* minor

* update sim

* checkpoint

* obs

* Update interpreter

* init qlib simulator

* checkpoint

* Refine codebase

* checkpoint

* checkpoint

* Add one test

* More tests

* Simulator checkpoint

* checkpoint

* First-step tested

* Checkpoint

* Update data_queue API

* Checkpoint

* Update test

* Move files

* Checkpoint

* Single-quote -> double-quote

* Fix finite env tests

* Tested with mypy

* PEP 574

* No call for env done

* Update finite env docs

* Fix csv writer

* Refine tester

* Update logger

* Add another logger test

* Checkpoint

* Add network sanity test

* steps per episode is not correct

* Cleanup code, ready for PR

* Reformat with black

* Fix pylint for py37

* Fix lint

* Fix lint

* Fix flake

* update mypy command

* mypy

* Update exclude pattern

* Use pyproject.toml

* test

* .

* .

* Refactor pipeline

* .

* defaults run bash

* .

* Revert and skip follow_imports

* Fix toml issue

* fix mypy

* .

* .

* .

* Fix install

* Minor fix

* Fix test

* Fix test

* Remove requirements

* Revert

* fix tests

* Fix lint

* .

* .

* .

* .

* .

* update install from source command

* .

* Fix data download

* .

* .

* .

* .

* .

* .

* Fix py37

* Ignore tests on non-linux

* resolve comments

* fix tests

* resolve comments

* Fix some typos

* style updates

* More comments

* fix dummy

* add warning

* Align precision on some systems

* Added some impl notes

Co-authored-by: Young <afe.young@gmail.com>
2022-05-21 18:19:24 +08:00

89 lines
2.3 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import multiprocessing
import time
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from qlib.rl.utils.data_queue import DataQueue
class DummyDataset(Dataset):
    """Map-style dataset whose item ``index`` is a random ``(index + 1) x 4`` DataFrame.

    Every item has a distinct number of rows (1 .. ``length``), so the tests can
    check that all items were delivered by counting distinct item lengths.
    Cell values are random integers in ``[0, 100)`` with columns ``A``-``D``.
    """

    def __init__(self, length):
        self.length = length

    def __getitem__(self, index):
        # Use an explicit IndexError rather than ``assert``. It is not stripped
        # under ``python -O``, and it is the exception the sequence protocol
        # expects for an illegal index.
        if not 0 <= index < self.length:
            raise IndexError(f"index {index} out of range for DummyDataset of length {self.length}")
        return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD"))

    def __len__(self):
        return self.length
def _worker(dataloader, collector):
    """Iterate ``dataloader`` to exhaustion and ``put`` the length of each item into ``collector``.

    Args:
        dataloader: an iterable of sized items. In this file it is a torch
            ``DataLoader`` or a shared ``DataQueue``.
        collector: any object with a ``put`` method, e.g. a ``multiprocessing.Queue``.
    """
    for data in dataloader:
        collector.put(len(data))
def _queue_to_list(queue, timeout=1.0):
    """Drain ``queue`` into a list, preserving arrival order.

    ``Queue.empty()`` must not be used as the loop condition for a
    ``multiprocessing.Queue``. Objects that were just ``put`` are written to the
    underlying pipe by a background feeder thread, so ``empty()`` may report
    True while items are still in flight, and some of them would be lost.
    Instead, keep calling ``get`` and stop only once no item has arrived for
    ``timeout`` seconds.

    Args:
        queue: a ``queue.Queue`` or ``multiprocessing.Queue``-like object.
        timeout: seconds to wait for another item before treating the queue as drained.

    Returns:
        A list of every item retrieved, in the order ``get`` returned them.
    """
    # This resolves the stdlib ``queue`` module through the import system, not the
    # ``queue`` parameter. ``multiprocessing.Queue.get`` raises this same exception.
    from queue import Empty

    result = []
    while True:
        try:
            item = queue.get(timeout=timeout)
        except Empty:
            return result
        result.append(item)
def test_pytorch_dataloader():
    """An unbatched, single-worker ``DataLoader`` should deliver all 100 distinct item lengths."""
    loader = DataLoader(DummyDataset(100), batch_size=None, num_workers=1)
    sink = multiprocessing.Queue()
    _worker(loader, sink)
    observed_lengths = set(_queue_to_list(sink))
    assert len(observed_lengths) == 100
def test_multiprocess_shared_dataloader():
    """Three processes consuming one shared ``DataQueue`` should collectively see all 100 item lengths."""
    with DataQueue(DummyDataset(100), producer_num_workers=1) as shared_queue:
        lengths = multiprocessing.Queue()
        consumers = []
        for _ in range(3):
            consumer = multiprocessing.Process(target=_worker, args=(shared_queue, lengths))
            consumers.append(consumer)
            consumer.start()
        for consumer in consumers:
            consumer.join()
        assert len(set(_queue_to_list(lengths))) == 100
def _exit_finite():
    """Child-process target: raise ``ValueError`` while a finite ``DataQueue`` (4 producers) is open."""
    dataset = DummyDataset(100)
    with DataQueue(dataset, producer_num_workers=4):
        time.sleep(3)
        raise ValueError


def test_exit_on_crash_finite():
    """A process that crashes inside a finite ``DataQueue`` context must still terminate.

    The crash is staged in a child process. For background on exit handling in
    multiprocessing children, see
    https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess

    The target is a module-level function so that it is picklable under the
    "spawn" start method as well as "fork". The join is bounded so that a
    cleanup regression fails the test instead of hanging the suite forever.
    """
    join_timeout = 60  # seconds; the child only sleeps ~3s before crashing
    process = multiprocessing.Process(target=_exit_finite)
    process.start()
    process.join(join_timeout)
    hung = process.is_alive()
    if hung:
        # Do not leak the stuck child into the rest of the test session.
        process.terminate()
        process.join()
    assert not hung, f"process did not exit within {join_timeout}s after crashing inside a finite DataQueue"
def _exit_infinite():
    """Child-process target: raise ``ValueError`` while an infinitely repeating ``DataQueue`` is open."""
    dataset = DummyDataset(100)
    with DataQueue(dataset, repeat=-1, queue_maxsize=100):
        time.sleep(3)
        raise ValueError


def test_exit_on_crash_infinite():
    """A process that crashes inside a ``repeat=-1`` ``DataQueue`` context must still terminate.

    The target is a module-level function so that it is picklable under the
    "spawn" start method as well as "fork". The join is bounded so that a
    cleanup regression fails the test instead of hanging the suite forever.
    """
    join_timeout = 60  # seconds; the child only sleeps ~3s before crashing
    process = multiprocessing.Process(target=_exit_infinite)
    process.start()
    process.join(join_timeout)
    hung = process.is_alive()
    if hung:
        # Do not leak the stuck child into the rest of the test session.
        process.terminate()
        process.join()
    assert not hung, f"process did not exit within {join_timeout}s after crashing inside an infinite DataQueue"
if __name__ == "__main__":
    # Running this file as a script executes only the shared-DataQueue test.
    test_multiprocess_shared_dataloader()