mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix_issue_715 (#1070)
* fix_issue_715 * fix_issue_1065 Co-authored-by: Linlang Lv (iSoftStone) <v-linlanglv@microsoft.com>
This commit is contained in:
@@ -144,7 +144,7 @@ class ADARNN(Model):
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.cuda()
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
@@ -153,7 +153,7 @@ class ADARNN(Model):
|
||||
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
@@ -165,7 +165,7 @@ class ADARNN(Model):
|
||||
list_label = []
|
||||
for data in data_all:
|
||||
# feature :[36, 24, 6]
|
||||
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
|
||||
feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float()
|
||||
list_feat.append(feature)
|
||||
list_label.append(label_reg)
|
||||
flag = False
|
||||
@@ -179,7 +179,7 @@ class ADARNN(Model):
|
||||
if flag:
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
total_loss = torch.zeros(1).to(self.device)
|
||||
for i, n in enumerate(index):
|
||||
feature_s = list_feat[n[0]]
|
||||
feature_t = list_feat[n[1]]
|
||||
@@ -325,7 +325,7 @@ class ADARNN(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model.predict(x_batch).detach().cpu().numpy()
|
||||
@@ -335,7 +335,7 @@ class ADARNN(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def transform_type(self, init_weight):
|
||||
weight = torch.ones(self.num_layers, self.len_seq).cuda()
|
||||
weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
|
||||
for i in range(self.num_layers):
|
||||
for j in range(self.len_seq):
|
||||
weight[i, j] = init_weight[i][j].item()
|
||||
@@ -389,6 +389,7 @@ class AdaRNN(nn.Module):
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
GPU=0,
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
@@ -399,6 +400,7 @@ class AdaRNN(nn.Module):
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
@@ -455,7 +457,7 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
h_start = 0
|
||||
@@ -516,12 +518,12 @@ class AdaRNN(nn.Module):
|
||||
|
||||
out_list_all = out[1]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
loss_transfer = torch.zeros((1,)).to(self.device)
|
||||
if weight_mat is None:
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device)
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
for j in range(self.len_seq):
|
||||
@@ -553,12 +555,13 @@ class AdaRNN(nn.Module):
|
||||
|
||||
|
||||
class TransferLoss:
|
||||
def __init__(self, loss_type="cosine", input_dim=512):
|
||||
def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
|
||||
"""
|
||||
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
@@ -574,7 +577,7 @@ class TransferLoss:
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
loss = CORAL(X, Y, self.device)
|
||||
elif self.loss_type in ("cosine", "cos"):
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
@@ -582,10 +585,10 @@ class TransferLoss:
|
||||
elif self.loss_type == "js":
|
||||
loss = js(X, Y)
|
||||
elif self.loss_type == "mine":
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device)
|
||||
loss = mine_model(X, Y)
|
||||
elif self.loss_type == "adv":
|
||||
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
|
||||
loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32)
|
||||
elif self.loss_type == "mmd_rbf":
|
||||
mmdloss = MMD_loss(kernel_type="rbf")
|
||||
loss = mmdloss(X, Y)
|
||||
@@ -630,12 +633,12 @@ class Discriminator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
def adv(source, target, device, input_dim=256, hidden_dim=512):
|
||||
domain_loss = nn.BCELoss()
|
||||
# !!! Pay attention to .cuda !!!
|
||||
adv_net = Discriminator(input_dim, hidden_dim).cuda()
|
||||
domain_src = torch.ones(len(source)).cuda()
|
||||
domain_tar = torch.zeros(len(target)).cuda()
|
||||
adv_net = Discriminator(input_dim, hidden_dim).to(device)
|
||||
domain_src = torch.ones(len(source)).to(device)
|
||||
domain_tar = torch.zeros(len(target)).to(device)
|
||||
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
|
||||
reverse_src = ReverseLayerF.apply(source, 1)
|
||||
reverse_tar = ReverseLayerF.apply(target, 1)
|
||||
@@ -646,16 +649,16 @@ def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
return loss
|
||||
|
||||
|
||||
def CORAL(source, target):
|
||||
def CORAL(source, target, device):
|
||||
d = source.size(1)
|
||||
ns, nt = source.size(0), target.size(0)
|
||||
|
||||
# source covariance
|
||||
tmp_s = torch.ones((1, ns)).cuda() @ source
|
||||
tmp_s = torch.ones((1, ns)).to(device) @ source
|
||||
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
|
||||
|
||||
# target covariance
|
||||
tmp_t = torch.ones((1, nt)).cuda() @ target
|
||||
tmp_t = torch.ones((1, nt)).to(device) @ target
|
||||
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
|
||||
|
||||
# frobenius norm
|
||||
|
||||
@@ -90,7 +90,6 @@ class CSIIndex(IndexBase):
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
@@ -98,7 +97,7 @@ class CSIIndex(IndexBase):
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("rewrite html_table_index")
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
@@ -184,12 +183,7 @@ class CSIIndex(IndexBase):
|
||||
df = pd.DataFrame()
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(content):
|
||||
if (
|
||||
_df.shape[-1] != 4
|
||||
or _df.iloc[2:,][0].str.contains(
|
||||
"."
|
||||
)[2]
|
||||
):
|
||||
if _df.shape[-1] != 4 or _df.isnull().loc(0)[0][0]:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
@@ -341,8 +335,8 @@ class CSI300Index(CSIIndex):
|
||||
return pd.Timestamp("2005-01-01")
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 1
|
||||
def html_table_index(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class CSI100Index(CSIIndex):
|
||||
@@ -355,8 +349,8 @@ class CSI100Index(CSIIndex):
|
||||
return pd.Timestamp("2006-05-29")
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 2
|
||||
def html_table_index(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class CSI500Index(CSIIndex):
|
||||
@@ -368,10 +362,6 @@ class CSI500Index(CSIIndex):
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2007-01-15")
|
||||
|
||||
@property
|
||||
def html_table_index(self) -> int:
|
||||
return 0
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
@@ -475,5 +465,4 @@ class CSI500Index(CSIIndex):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
get_instruments(index_name="CSI300", qlib_dir="~/.qlib/qlib_data/cn_data", method="parse_instruments")
|
||||
# fire.Fire(get_instruments)
|
||||
fire.Fire(get_instruments)
|
||||
|
||||
Reference in New Issue
Block a user