mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Refactor to Python3 style
This commit is contained in:
@@ -25,7 +25,7 @@ import os
|
||||
import data_formatters.qlib_Alpha158
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
class ExperimentConfig:
|
||||
"""Defines experiment configs and paths to outputs.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -320,7 +320,7 @@ class InterpretableMultiHeadAttention:
|
||||
return outputs, attn
|
||||
|
||||
|
||||
class TFTDataCache(object):
|
||||
class TFTDataCache:
|
||||
"""Caches data for the TFT."""
|
||||
|
||||
_data_cache = {}
|
||||
@@ -348,7 +348,7 @@ class TFTDataCache(object):
|
||||
|
||||
|
||||
# TFT model definitions.
|
||||
class TemporalFusionTransformer(object):
|
||||
class TemporalFusionTransformer:
|
||||
"""Defines Temporal Fusion Transformer.
|
||||
|
||||
Attributes:
|
||||
@@ -972,7 +972,7 @@ class TemporalFusionTransformer(object):
|
||||
valid_quantiles = self.quantiles
|
||||
output_size = self.output_size
|
||||
|
||||
class QuantileLossCalculator(object):
|
||||
class QuantileLossCalculator:
|
||||
"""Computes the combined quantile loss for prespecified quantiles.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .position import Position
|
||||
|
||||
@@ -296,7 +296,7 @@ class DNNModelPytorch(Model):
|
||||
self._fitted = True
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -464,7 +464,7 @@ class SFM(Model):
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -21,7 +21,7 @@ from .executor import SimulatorExecutor
|
||||
from .executor import save_score_series, load_score_series
|
||||
|
||||
|
||||
class Operator(object):
|
||||
class Operator:
|
||||
def __init__(self, client: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -17,7 +17,7 @@ from plotly.figure_factory import create_distplot
|
||||
from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph(object):
|
||||
class BaseGraph:
|
||||
""""""
|
||||
|
||||
_name = None
|
||||
@@ -204,7 +204,7 @@ class HistogramGraph(BaseGraph):
|
||||
return _data
|
||||
|
||||
|
||||
class SubplotsGraph(object):
|
||||
class SubplotsGraph:
|
||||
"""Create subplots same as df.plot(subplots=True)
|
||||
|
||||
Simple package for `plotly.tools.subplots`
|
||||
|
||||
@@ -6,7 +6,7 @@ import copy
|
||||
import os
|
||||
|
||||
|
||||
class TunerConfigManager(object):
|
||||
class TunerConfigManager:
|
||||
def __init__(self, config_path):
|
||||
|
||||
if not config_path:
|
||||
@@ -27,7 +27,7 @@ class TunerConfigManager(object):
|
||||
self.qlib_client_config = config.get("qlib_client", dict())
|
||||
|
||||
|
||||
class PipelineExperimentConfig(object):
|
||||
class PipelineExperimentConfig:
|
||||
def __init__(self, config, TUNER_CONFIG_MANAGER):
|
||||
"""
|
||||
:param config: The config dict for tuner experiment
|
||||
@@ -53,7 +53,7 @@ class PipelineExperimentConfig(object):
|
||||
yaml.dump(TUNER_CONFIG_MANAGER.config, fp)
|
||||
|
||||
|
||||
class OptimizationConfig(object):
|
||||
class OptimizationConfig:
|
||||
def __init__(self, config, TUNER_CONFIG_MANAGER):
|
||||
|
||||
self.report_type = config.get("report_type", "pred_long")
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
class Pipeline:
|
||||
|
||||
GLOBAL_BEST_PARAMS_NAME = "global_best_params.json"
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from hyperopt import fmin, tpe
|
||||
from hyperopt import STATUS_OK, STATUS_FAIL
|
||||
|
||||
|
||||
class Tuner(object):
|
||||
class Tuner:
|
||||
def __init__(self, tuner_config, optim_config):
|
||||
|
||||
self.logger = get_module_logger("Tuner", sh_level=logging.INFO)
|
||||
|
||||
@@ -8,7 +8,7 @@ from libc.math cimport sqrt, isnan, NAN
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
|
||||
cdef class Expanding(object):
|
||||
cdef class Expanding:
|
||||
"""1-D array expanding"""
|
||||
cdef vector[double] barv
|
||||
cdef int na_count
|
||||
|
||||
@@ -8,7 +8,7 @@ from libc.math cimport sqrt, isnan, NAN
|
||||
from libcpp.deque cimport deque
|
||||
|
||||
|
||||
cdef class Rolling(object):
|
||||
cdef class Rolling:
|
||||
"""1-D array rolling"""
|
||||
cdef int window
|
||||
cdef deque[double] barv
|
||||
|
||||
@@ -68,7 +68,7 @@ class MemCacheUnit(OrderedDict):
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
class MemCache(object):
|
||||
class MemCache:
|
||||
"""Memory cache."""
|
||||
|
||||
def __init__(self, mem_cache_size_limit=None, limit_type="length"):
|
||||
@@ -140,7 +140,7 @@ class MemCacheExpire:
|
||||
return value, expire
|
||||
|
||||
|
||||
class CacheUtils(object):
|
||||
class CacheUtils:
|
||||
LOCK_ID = "QLIB"
|
||||
|
||||
@staticmethod
|
||||
@@ -224,7 +224,7 @@ class CacheUtils(object):
|
||||
current_cache_wlock.release()
|
||||
|
||||
|
||||
class BaseProviderCache(object):
|
||||
class BaseProviderCache:
|
||||
"""Provider cache base class"""
|
||||
|
||||
def __init__(self, provider):
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client(object):
|
||||
class Client:
|
||||
"""A client class
|
||||
|
||||
Provide the connection tool functions for ClientProvider.
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_module_logger(module_name, level=None):
|
||||
return module_logger
|
||||
|
||||
|
||||
class TimeInspector(object):
|
||||
class TimeInspector:
|
||||
|
||||
timer_logger = get_module_logger("timer", level=logging.WARNING)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import scipy.optimize as so
|
||||
from typing import Optional, Union, Callable, List
|
||||
|
||||
|
||||
class PortfolioOptimizer(object):
|
||||
class PortfolioOptimizer:
|
||||
"""Portfolio Optimizer
|
||||
|
||||
The following optimization algorithms are supported:
|
||||
|
||||
@@ -686,7 +686,7 @@ def flatten_dict(d, parent_key="", sep="."):
|
||||
|
||||
|
||||
#################### Wrapper #####################
|
||||
class Wrapper(object):
|
||||
class Wrapper:
|
||||
"""Wrapper class for anything that needs to set up during qlib.init"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user