1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Add ipynb format check (#1439)

* Update test_qlib_from_source.yml

* add ipynb format check to workflow

* test ipynb CI

* modify nbqa check path

* add pylint, flake8, and mypy checks to ipynb

* check ipynb with black and pylint

* reformat .ipynb files

* format line length

nbqa black . -l 120

* update nbqa .ipynb format CI

* format old ipynb files

* add nbconvert check to CI

* adjust CI order to avoid downloading data repeatedly
This commit is contained in:
Cadenza-Li
2023-02-21 09:23:22 +08:00
committed by GitHub
parent 5eb5ac1f1f
commit 76f2fb1a1a
6 changed files with 275 additions and 173 deletions

View File

@@ -41,6 +41,7 @@
"\n",
"State = namedtuple(\"State\", [\"value\", \"last_action\"])\n",
"\n",
"\n",
"class SimpleSimulator(Simulator[float, State, float]):\n",
" def __init__(self, initial: float, nsteps: int, **kwargs: Any) -> None:\n",
" super().__init__(initial)\n",
@@ -92,6 +93,7 @@
"from gym import spaces\n",
"from qlib.rl.interpreter import StateInterpreter\n",
"\n",
"\n",
"class SimpleStateInterpreter(StateInterpreter[Tuple[float, float], np.ndarray]):\n",
" def interpret(self, state: State) -> np.ndarray:\n",
" # Convert state.value to a 1D Numpy array\n",
@@ -101,7 +103,8 @@
" @property\n",
" def observation_space(self) -> spaces.Box:\n",
" return spaces.Box(0, np.inf, shape=(1,), dtype=np.float32)\n",
" \n",
"\n",
"\n",
"state_interpreter = SimpleStateInterpreter()"
]
},
@@ -120,6 +123,7 @@
"source": [
"from qlib.rl.interpreter import ActionInterpreter\n",
"\n",
"\n",
"class SimpleActionInterpreter(ActionInterpreter[State, int, float]):\n",
" def __init__(self, n_value: int) -> None:\n",
" self.n_value = n_value\n",
@@ -132,7 +136,8 @@
" assert 0 <= action <= self.n_value\n",
" # simulator_state.value is used as the denominator\n",
" return simulator_state.value * (action / self.n_value)\n",
" \n",
"\n",
"\n",
"action_interpreter = SimpleActionInterpreter(n_value=10)"
]
},
@@ -151,12 +156,14 @@
"source": [
"from qlib.rl.reward import Reward\n",
"\n",
"\n",
"class SimpleReward(Reward[State]):\n",
" def reward(self, simulator_state: State) -> float:\n",
" # Use last_action to calculate reward. This is why it should be in the state.\n",
" rew = simulator_state.last_action / simulator_state.value\n",
" return rew\n",
" \n",
"\n",
"\n",
"reward = SimpleReward()"
]
},
@@ -180,6 +187,7 @@
"from torch import nn\n",
"from qlib.rl.order_execution import PPO\n",
"\n",
"\n",
"class SimpleFullyConnect(nn.Module):\n",
" def __init__(self, dims: List[int]) -> None:\n",
" super().__init__()\n",
@@ -195,7 +203,8 @@
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return self.fc(x)\n",
" \n",
"\n",
"\n",
"policy = PPO(\n",
" network=SimpleFullyConnect(dims=[16, 8]),\n",
" obs_space=state_interpreter.observation_space,\n",
@@ -221,6 +230,7 @@
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class SimpleDataset(Dataset):\n",
" def __init__(self, positions: List[float]) -> None:\n",
" self.positions = positions\n",
@@ -230,7 +240,8 @@
"\n",
" def __getitem__(self, index: int) -> float:\n",
" return self.positions[index]\n",
" \n",
"\n",
"\n",
"dataset = SimpleDataset(positions=[10.0, 50.0, 100.0])"
]
},
@@ -265,11 +276,13 @@
"trainer_kwargs = {\n",
" \"max_iters\": 10,\n",
" \"finite_env_type\": \"dummy\",\n",
" \"callbacks\": [Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )],\n",
" \"callbacks\": [\n",
" Checkpoint(\n",
" dirpath=Path(\"./checkpoints\"),\n",
" every_n_iters=1,\n",
" save_latest=\"copy\",\n",
" )\n",
" ],\n",
"}\n",
"vessel_kwargs = {\n",
" \"update_kwargs\": {\"batch_size\": 16, \"repeat\": 5},\n",