mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Add ipynb format check (#1439)
* Update test_qlib_from_source.yml * add ipynb format check to workflow * test ipynb CI * modify nbqa check path * add pylint flake8 mypy check to ipynb * check ipynb with black and pylint * reformat .ipynb files * format line length nbqa black . -l 120 * update nbqa .ipynb format CI * format old ipynb files * add nbconvert check to CI * adjust CI order to avoid repeating download data
This commit is contained in:
@@ -41,6 +41,7 @@
|
||||
"\n",
|
||||
"State = namedtuple(\"State\", [\"value\", \"last_action\"])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleSimulator(Simulator[float, State, float]):\n",
|
||||
" def __init__(self, initial: float, nsteps: int, **kwargs: Any) -> None:\n",
|
||||
" super().__init__(initial)\n",
|
||||
@@ -92,6 +93,7 @@
|
||||
"from gym import spaces\n",
|
||||
"from qlib.rl.interpreter import StateInterpreter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleStateInterpreter(StateInterpreter[Tuple[float, float], np.ndarray]):\n",
|
||||
" def interpret(self, state: State) -> np.ndarray:\n",
|
||||
" # Convert state.value to a 1D Numpy array\n",
|
||||
@@ -101,7 +103,8 @@
|
||||
" @property\n",
|
||||
" def observation_space(self) -> spaces.Box:\n",
|
||||
" return spaces.Box(0, np.inf, shape=(1,), dtype=np.float32)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"state_interpreter = SimpleStateInterpreter()"
|
||||
]
|
||||
},
|
||||
@@ -120,6 +123,7 @@
|
||||
"source": [
|
||||
"from qlib.rl.interpreter import ActionInterpreter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleActionInterpreter(ActionInterpreter[State, int, float]):\n",
|
||||
" def __init__(self, n_value: int) -> None:\n",
|
||||
" self.n_value = n_value\n",
|
||||
@@ -132,7 +136,8 @@
|
||||
" assert 0 <= action <= self.n_value\n",
|
||||
" # simulator_state.value is used as the denominator\n",
|
||||
" return simulator_state.value * (action / self.n_value)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"action_interpreter = SimpleActionInterpreter(n_value=10)"
|
||||
]
|
||||
},
|
||||
@@ -151,12 +156,14 @@
|
||||
"source": [
|
||||
"from qlib.rl.reward import Reward\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleReward(Reward[State]):\n",
|
||||
" def reward(self, simulator_state: State) -> float:\n",
|
||||
" # Use last_action to calculate reward. This is why it should be in the state.\n",
|
||||
" rew = simulator_state.last_action / simulator_state.value\n",
|
||||
" return rew\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reward = SimpleReward()"
|
||||
]
|
||||
},
|
||||
@@ -180,6 +187,7 @@
|
||||
"from torch import nn\n",
|
||||
"from qlib.rl.order_execution import PPO\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleFullyConnect(nn.Module):\n",
|
||||
" def __init__(self, dims: List[int]) -> None:\n",
|
||||
" super().__init__()\n",
|
||||
@@ -195,7 +203,8 @@
|
||||
"\n",
|
||||
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||||
" return self.fc(x)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"policy = PPO(\n",
|
||||
" network=SimpleFullyConnect(dims=[16, 8]),\n",
|
||||
" obs_space=state_interpreter.observation_space,\n",
|
||||
@@ -221,6 +230,7 @@
|
||||
"source": [
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SimpleDataset(Dataset):\n",
|
||||
" def __init__(self, positions: List[float]) -> None:\n",
|
||||
" self.positions = positions\n",
|
||||
@@ -230,7 +240,8 @@
|
||||
"\n",
|
||||
" def __getitem__(self, index: int) -> float:\n",
|
||||
" return self.positions[index]\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"dataset = SimpleDataset(positions=[10.0, 50.0, 100.0])"
|
||||
]
|
||||
},
|
||||
@@ -265,11 +276,13 @@
|
||||
"trainer_kwargs = {\n",
|
||||
" \"max_iters\": 10,\n",
|
||||
" \"finite_env_type\": \"dummy\",\n",
|
||||
" \"callbacks\": [Checkpoint(\n",
|
||||
" dirpath=Path(\"./checkpoints\"),\n",
|
||||
" every_n_iters=1,\n",
|
||||
" save_latest=\"copy\",\n",
|
||||
" )],\n",
|
||||
" \"callbacks\": [\n",
|
||||
" Checkpoint(\n",
|
||||
" dirpath=Path(\"./checkpoints\"),\n",
|
||||
" every_n_iters=1,\n",
|
||||
" save_latest=\"copy\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
"}\n",
|
||||
"vessel_kwargs = {\n",
|
||||
" \"update_kwargs\": {\"batch_size\": 16, \"repeat\": 5},\n",
|
||||
|
||||
Reference in New Issue
Block a user