mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
@@ -7,6 +7,7 @@ from __future__ import print_function
|
||||
from collections import defaultdict
|
||||
|
||||
import os
|
||||
import gc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Callable, Optional, Text, Union
|
||||
@@ -32,7 +33,6 @@ from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
from qlib.contrib.meta.data_selection.utils import ICLoss
|
||||
from torch.nn import DataParallel
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
|
||||
|
||||
class DNNModelPytorch(Model):
|
||||
@@ -201,7 +201,7 @@ class DNNModelPytorch(Model):
|
||||
seg, col_set=["feature", "label"], data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L
|
||||
)
|
||||
all_df["x"][seg] = df["feature"]
|
||||
all_df["y"][seg] = df["label"]
|
||||
all_df["y"][seg] = df["label"].copy() # We have to use copy to remove the reference to release mem
|
||||
if reweighter is None:
|
||||
all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index)
|
||||
elif isinstance(reweighter, Reweighter):
|
||||
@@ -216,6 +216,10 @@ class DNNModelPytorch(Model):
|
||||
all_t[v][seg] = all_t[v][seg].to(self.device) # This will consume a lot of memory !!!!
|
||||
|
||||
evals_result[seg] = []
|
||||
# free memory
|
||||
del df
|
||||
del all_df["x"]
|
||||
gc.collect()
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
@@ -266,7 +270,7 @@ class DNNModelPytorch(Model):
|
||||
loss_val = cur_loss_val.item()
|
||||
metric_val = (
|
||||
self.get_metric(
|
||||
preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["x"]["valid"].index
|
||||
preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["y"]["valid"].index
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
@@ -281,7 +285,7 @@ class DNNModelPytorch(Model):
|
||||
self.get_metric(
|
||||
self._nn_predict(all_t["x"]["train"], return_cpu=False),
|
||||
all_t["y"]["train"].reshape(-1),
|
||||
all_df["x"]["train"].index,
|
||||
all_df["y"]["train"].index,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
@@ -351,31 +355,17 @@ class DNNModelPytorch(Model):
|
||||
1) test inference (data may come from CPU and expect the output data is on CPU)
|
||||
2) evaluation on training (data may come from GPU)
|
||||
"""
|
||||
if isinstance(data, torch.Tensor) and data.device.type != "cpu":
|
||||
# GPU data
|
||||
# CUDA data don't support pin_memory and multi-processing workers
|
||||
num_workers = 0
|
||||
pin_memory = False
|
||||
else:
|
||||
# CPU data
|
||||
if not isinstance(data, torch.Tensor):
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = data.values
|
||||
# else: CPU Tensor
|
||||
num_workers = 8
|
||||
pin_memory = True
|
||||
data_loader = DataLoader(
|
||||
data,
|
||||
sampler=SequentialSampler(data),
|
||||
batch_size=self.batch_size,
|
||||
drop_last=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
if not isinstance(data, torch.Tensor):
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = data.values
|
||||
data = torch.Tensor(data)
|
||||
data = data.to(self.device)
|
||||
preds = []
|
||||
self.dnn_model.eval()
|
||||
with torch.no_grad():
|
||||
for x in data_loader:
|
||||
batch_size = 8096
|
||||
for i in range(0, len(data), batch_size):
|
||||
x = data[i : i + batch_size]
|
||||
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
|
||||
if return_cpu:
|
||||
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
|
||||
|
||||
@@ -8,14 +8,14 @@ Assumptions
|
||||
|
||||
"""
|
||||
import pandas as pd
|
||||
from blocks.utils.log import logt
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
|
||||
class FeaAnalyser:
|
||||
def __init__(self, dataset: pd.DataFrame):
|
||||
self._dataset = dataset
|
||||
with logt("calc_stat_values"):
|
||||
with TimeInspector.logt("calc_stat_values"):
|
||||
self.calc_stat_values()
|
||||
|
||||
def calc_stat_values(self):
|
||||
|
||||
Reference in New Issue
Block a user