diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 0d637d9db..88926b553 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -1,30 +1,35 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import pandas as pd +from dataclasses import dataclass, field +from typing import ClassVar - +@dataclass class Order: - - SELL = 0 - BUY = 1 - - def __init__(self, stock_id, amount, start_time, end_time, direction, factor): - """Parameter - direction : Order.SELL for sell; Order.BUY for buy - stock_id : str - amount : float - trade_date : pd.Timestamp - factor : float + """ + stock_id : str + amount : float + start_time : pd.Timestamp + closed start time for order generation + end_time : pd.Timestamp + closed end time for order generation + direction : Order.SELL for sell; Order.BUY for buy + factor : float presents the weight factor assigned in Exchange() - """ - # check direction - if direction not in {Order.SELL, Order.BUY}: + """ + stock_id : str + amount : float + start_time : pd.Timestamp + end_time : pd.Timestamp + direction : int + factor : float + deal_amount : float = field(init=False) + SELL : ClassVar[int] = 0 + BUY : ClassVar[int] = 1 + + + def __post_init__(self): + if self.direction not in {Order.SELL, Order.BUY}: raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") - self.stock_id = stock_id - # amount of generated orders - self.amount = amount - # amount of successfully completed orders self.deal_amount = 0 - self.start_time = start_time - self.end_time = end_time - self.direction = direction - self.factor = factor + diff --git a/setup.py b/setup.py index 92c9ccc0c..0205ab087 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ REQUIRED = [ "pymongo==3.7.2", # For task management "scikit-learn>=0.22", "dill", + "dataclasses;python_version<'3.7'", ] # Numpy include