mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* rl init * aux info * Reward config * update * simple * update saoe init * update simulator and seed * minor * minor * update sim * checkpoint * obs * Update interpreter * init qlib simulator * checkpoint * Refine codebase * checkpoint * checkpoint * Add one test * More tests * Simulator checkpoint * checkpoint * First-step tested * Checkpoint * Update data_queue API * Checkpoint * Update test * Move files * Checkpoint * Single-quote -> double-quote * Fix finite env tests * Tested with mypy * pep-574 * No call for env done * Update finite env docs * Fix csv writer * Refine tester * Update logger * Add another logger test * Checkpoint * Add network sanity test * steps per episode is not correct * Cleanup code, ready for PR * Reformat with black * Fix pylint for py37 * Fix lint * Fix lint * Fix flake * update mypy command * mypy * Update exclude pattern * Use pyproject.toml * test * . * . * Refactor pipeline * . * defaults run bash * . * Revert and skip follow_imports * Fix toml issue * fix mypy * . * . * . * Fix install * Minor fix * Fix test * Fix test * Remove requirements * Revert * fix tests * Fix lint * . * . * . * . * . * update install from source command * . * Fix data download * . * . * . * . * . * . * Fix py37 * Ignore tests on non-linux * resolve comments * fix tests * resolve comments * some typo * style updates * More comments * fix dummy * add warning * Align precision in some system * Added some impl notes Co-authored-by: Young <afe.young@gmail.com>
89 lines
2.3 KiB
Python
89 lines
2.3 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import multiprocessing
|
|
import time
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from qlib.rl.utils.data_queue import DataQueue
|
|
|
|
|
|
class DummyDataset(Dataset):
    """Map-style dataset whose ``index``-th item is a random frame with ``index + 1`` rows.

    Every item has a different number of rows, so a consumer that reports
    ``len(item)`` produces one distinct value per dataset item.
    """

    def __init__(self, length):
        """Remember how many items the dataset exposes.

        Parameters
        ----------
        length : int
            Number of valid indices (``0`` to ``length - 1``).
        """
        self.length = length

    def __getitem__(self, index):
        """Return a frame of random integers in ``[0, 100)`` with columns ``A``-``D``.

        Parameters
        ----------
        index : int
            Position of the item; the returned frame has ``index + 1`` rows.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[0, length)``.
        """
        # IndexError (rather than a bare ``assert``) keeps the check alive under
        # ``python -O`` and lets plain iteration over the dataset stop cleanly.
        if not 0 <= index < self.length:
            raise IndexError(f"index {index} is out of range for a dataset of length {self.length}")
        values = np.random.randint(0, 100, size=(index + 1, 4))
        return pd.DataFrame(values, columns=list("ABCD"))

    def __len__(self):
        """Return the number of items given at construction time."""
        return self.length
|
|
|
|
|
|
def _worker(dataloader, collector):
    """Consume ``dataloader`` and report the length of every item to ``collector``.

    Parameters
    ----------
    dataloader : iterable
        Source of items; each item must support ``len()``.
    collector : queue-like
        Object with a ``put`` method; it receives one length per item.
    """
    for data in dataloader:
        collector.put(len(data))
|
|
|
|
|
|
def _queue_to_list(queue, timeout=1.0):
    """Drain ``queue`` into a list, preserving retrieval order.

    ``empty()`` is not used to detect the end of the data: for a
    ``multiprocessing.Queue`` it may still report ``True`` shortly after a
    ``put()``, which would silently truncate the result. Instead, items are
    fetched with a timed ``get`` until none arrives within ``timeout``.

    Parameters
    ----------
    queue : queue-like
        Object whose ``get(timeout=...)`` raises ``queue.Empty`` when no item
        arrives in time (``multiprocessing.Queue`` and ``queue.Queue`` both do).
    timeout : float
        Seconds to wait for one more item before the queue is considered drained.

    Returns
    -------
    list
        Every item retrieved from ``queue``.
    """
    # Imported locally because the ``queue`` parameter shadows the stdlib module name.
    from queue import Empty

    result = []
    while True:
        try:
            result.append(queue.get(timeout=timeout))
        except Empty:
            return result
|
|
|
|
|
|
def test_pytorch_dataloader():
    """Feed a plain ``DataLoader`` (one worker, ``batch_size=None``) through ``_worker``.

    The test expects 100 distinct item lengths to be reported for a
    ``DummyDataset`` of length 100.
    """
    lengths = multiprocessing.Queue()
    loader = DataLoader(DummyDataset(100), batch_size=None, num_workers=1)
    _worker(loader, lengths)
    distinct_lengths = set(_queue_to_list(lengths))
    assert len(distinct_lengths) == 100
|
|
|
|
|
|
def test_multiprocess_shared_dataloader():
    """Let three consumer processes iterate over one shared ``DataQueue``.

    Each consumer runs ``_worker`` against the same ``DataQueue`` and the same
    result queue. Once all consumers have been joined, the test expects 100
    distinct item lengths in total.
    """
    with DataQueue(DummyDataset(100), producer_num_workers=1) as data_queue:
        lengths = multiprocessing.Queue()
        consumers = [
            multiprocessing.Process(target=_worker, args=(data_queue, lengths))
            for _ in range(3)
        ]
        for consumer in consumers:
            consumer.start()
        for consumer in consumers:
            consumer.join()
        distinct_lengths = set(_queue_to_list(lengths))
        assert len(distinct_lengths) == 100
|
|
|
|
|
|
def test_exit_on_crash_finite():
    """Crash a child process inside a ``DataQueue`` context and wait for it to exit.

    The child opens a ``DataQueue`` with four producer workers, sleeps for three
    seconds and raises ``ValueError``. No exit code is asserted: the test only
    requires that ``join()`` on the child returns.
    """

    def _exit_finite():
        # Background on atexit handlers in multiprocessing subprocesses:
        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
        with DataQueue(DummyDataset(100), producer_num_workers=4):
            time.sleep(3)
            raise ValueError

    child = multiprocessing.Process(target=_exit_finite)
    child.start()
    child.join()
|
|
|
|
|
|
def test_exit_on_crash_infinite():
    """Crash a child process inside a repeating ``DataQueue`` and wait for it to exit.

    The child opens a ``DataQueue`` with ``repeat=-1`` (presumably endless
    repetition, going by the test name) and ``queue_maxsize=100``, sleeps for
    three seconds and raises ``ValueError``. No exit code is asserted: the test
    only requires that ``join()`` on the child returns.
    """

    def _exit_infinite():
        with DataQueue(DummyDataset(100), repeat=-1, queue_maxsize=100):
            time.sleep(3)
            raise ValueError

    child = multiprocessing.Process(target=_exit_infinite)
    child.start()
    child.join()
|
|
|
|
|
|
# When executed as a script, run only the shared-DataQueue multi-process test.
if __name__ == "__main__":
    test_multiprocess_shared_dataloader()
|