mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init * aux info * Reward config * update * simple * update saoe init * update simulator and seed * minor * minor * update sim * checkpoint * obs * Update interpreter * init qlib simulator * checkpoint * Refine codebase * checkpoint * checkpoint * Add one test * More tests * Simulator checkpoint * checkpoint * First-step tested * Checkpoint * Update data_queue API * Checkpoint * Update test * Move files * Checkpoint * Single-quote -> double-quote * Fix finite env tests * Tested with mypy * pep-574 * No call for env done * Update finite env docs * Fix csv writer * Refine tester * Update logger * Add another logger test * Checkpoint * Add network sanity test * steps per episode is not correct * Cleanup code, ready for PR * Reformat with black * Fix pylint for py37 * Fix lint * Fix lint * Fix flake * update mypy command * mypy * Update exclude pattern * Use pyproject.toml * test * . * . * Refactor pipeline * . * defaults run bash * . * Revert and skip follow_imports * Fix toml issue * fix mypy * . * . * . * Fix install * Minor fix * Fix test * Fix test * Remove requirements * Revert * fix tests * Fix lint * . * . * . * . * . * update install from source command * . * Fix data download * . * . * . * . * . * . * Fix py37 * Ignore tests on non-linux * resolve comments * fix tests * resolve comments * some typo * style updates * More comments * fix dummy * add warning * Align precision in some system * Added some impl notes Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
88
tests/rl/test_data_queue.py
Normal file
88
tests/rl/test_data_queue.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from qlib.rl.utils.data_queue import DataQueue
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
    """Map-style dataset whose item ``index`` is a random frame with ``index + 1`` rows.

    Every item is a ``pandas.DataFrame`` of random integers drawn from
    ``[0, 100)`` with the four columns ``A``, ``B``, ``C`` and ``D``.
    Because the number of rows differs for every index, ``len(item)``
    identifies which index produced the item.

    Parameters
    ----------
    length
        Number of items the dataset reports through ``len()``.
    """

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Sanity check: only indices inside the declared length are expected.
        assert 0 <= index < self.length
        column_names = list("ABCD")
        n_rows = index + 1
        values = np.random.randint(0, 100, size=(n_rows, 4))
        return pd.DataFrame(values, columns=column_names)
|
||||
|
||||
|
||||
def _worker(dataloader, collector):
|
||||
# for i in range(3):
|
||||
for i, data in enumerate(dataloader):
|
||||
collector.put(len(data))
|
||||
|
||||
|
||||
def _queue_to_list(queue):
|
||||
result = []
|
||||
while not queue.empty():
|
||||
result.append(queue.get())
|
||||
return result
|
||||
|
||||
|
||||
def test_pytorch_dataloader():
    """Baseline: a plain single-worker ``DataLoader`` delivers all 100 items.

    ``batch_size=None`` turns off automatic batching, so each yielded object
    is one dataset item. Item sizes are unique per index, hence 100 distinct
    sizes must be collected.
    """
    n_items = 100
    loader = DataLoader(DummyDataset(n_items), batch_size=None, num_workers=1)
    collected = multiprocessing.Queue()

    _worker(loader, collected)

    distinct_sizes = set(_queue_to_list(collected))
    assert len(distinct_sizes) == n_items
|
||||
|
||||
|
||||
def test_multiprocess_shared_dataloader():
    """Three consumer processes sharing one ``DataQueue`` see all 100 item sizes.

    Each consumer runs ``_worker`` over the same ``DataQueue`` and pushes item
    sizes into a shared ``multiprocessing.Queue``. After all consumers have
    exited, the union of reported sizes must contain 100 distinct values.
    """
    n_items = 100
    n_consumers = 3

    with DataQueue(DummyDataset(n_items), producer_num_workers=1) as data_queue:
        collected = multiprocessing.Queue()
        consumers = []
        for _ in range(n_consumers):
            consumer = multiprocessing.Process(target=_worker, args=(data_queue, collected))
            consumers.append(consumer)
            consumer.start()

        # Wait for every consumer to finish before draining the results.
        for consumer in consumers:
            consumer.join()

        distinct_sizes = set(_queue_to_list(collected))
        assert len(distinct_sizes) == n_items
|
||||
|
||||
|
||||
def test_exit_on_crash_finite():
    """A child that crashes inside a finite ``DataQueue`` context must still exit.

    The child opens a ``DataQueue`` with four producer workers, sleeps for
    three seconds and then raises ``ValueError`` inside the ``with`` block.
    The test only waits for the child to terminate; it makes no assertion on
    the exit code, so a failure shows up as a hang.

    NOTE(review): the target is a nested function, which cannot be pickled,
    so this presumably relies on the ``fork`` start method -- confirm.
    """

    def _exit_finite():
        with DataQueue(DummyDataset(100), producer_num_workers=4):
            time.sleep(3)
            raise ValueError

        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess

    crashing_child = multiprocessing.Process(target=_exit_finite)
    crashing_child.start()
    crashing_child.join()
|
||||
|
||||
|
||||
def test_exit_on_crash_infinite():
    """A child that crashes inside an endlessly repeating ``DataQueue`` must still exit.

    The child opens a ``DataQueue`` with ``repeat=-1`` and
    ``queue_maxsize=100``, sleeps for three seconds and then raises
    ``ValueError`` inside the ``with`` block. The test only waits for the
    child to terminate; it makes no assertion on the exit code, so a failure
    shows up as a hang.

    NOTE(review): the target is a nested function, which cannot be pickled,
    so this presumably relies on the ``fork`` start method -- confirm.
    """

    def _exit_infinite():
        with DataQueue(DummyDataset(100), repeat=-1, queue_maxsize=100):
            time.sleep(3)
            raise ValueError

    crashing_child = multiprocessing.Process(target=_exit_infinite)
    crashing_child.start()
    crashing_child.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Convenience entry point: run the shared-queue test directly, without pytest.
    test_multiprocess_shared_dataloader()
|
||||
Reference in New Issue
Block a user