土曜日, 6月 7, 2025
- Advertisment -
ホームニューステックニュースDQNでトレードの売買判断を行うAIの試作 〜環境・トレーニング編〜 #Python - Qiita

DQNでトレードの売買判断を行うAIの試作 〜環境・トレーニング編〜 #Python – Qiita



DQNでトレードの売買判断を行うAIの試作 〜環境・トレーニング編〜 #Python - Qiita

本記事で紹介している内容は、DQN(ディープQネットワーク)を用いた日経平均トレードの技術的な解説およびシミュレーション事例であり、特定の投資行動や金融商品の購入・売却を勧誘するものではありません。
また、記載された運用成績や利回りは過去のバックテストまたはシミュレーション結果に基づいており、将来の成果を保証するものではありません。
投資には元本割れや損失が発生するリスクがあり、最終的な投資判断はご自身の責任でお願いいたします。
本記事の内容を参考にして生じたいかなる損失についても、一切の責任を負いかねます。
投資にあたっては、必ずご自身で十分な調査・ご判断のうえ、必要に応じて専門家等にご相談ください。

関連記事一覧

1. DQNでトレードの売買判断を行うAIの試作 〜データ準備編〜
2. (本記事)DQNでトレードの売買判断を行うAIの試作 〜環境・トレーニング編〜

はじめに

前回の記事では、AIトレーダーに「何を見せるか」というデータ準備の部分を詳しく解説しました。株価データから各種テクニカル指標まで、人間のトレーダーが参考にする情報をAIにも提供できるようになりましたね。

今回は、いよいよAIが実際に「取引を学習する環境」と「学習プロセス」について解説します。強化学習では、AIが試行錯誤しながら最適な行動を学ぶための「環境」を作ることが重要です。この環境で、AIは何千回、何万回と取引を繰り返し、徐々に上手になっていきます。

ここでは強化学習や実際のコードの詳細について解説しながら、最後にはコード全文を載せています。

強化学習の基本:環境、行動、報酬

強化学習は「環境」「エージェント(AI)」「行動」「報酬」の4つの要素で構成されます。子供がゲームを覚えるのと似ていて:

  • 環境:ゲームの世界(今回は株式市場)
  • エージェント:プレイヤー(今回はAIトレーダー)
  • 行動:プレイヤーができること(今回は買い・売り・様子見)
  • 報酬:行動の結果得られるポイント(今回は利益・損失)

AIは「どの状況でどの行動を取れば、より多くの報酬を得られるか」を学習していきます。
実際のコードを見ていきましょう。

取引環境の設計:NikkeiEnvクラス

NikkeiEnvクラスは、株式市場をシミュレートする仮想的な取引環境です。この環境の中で、AIは実際のお金を使わずに何度でも取引の練習ができます。

基本設定:AIができること

class NikkeiEnv(gym.Env):
    def __init__(
        self,
        df,
        window_size=30,
        transaction_cost=0.001,
        risk_limit=0.5,
        trade_penalty=0.002,
    ):
        # 行動空間(AIが取れる行動)
        self.action_space = spaces.Discrete(3)
        # 0: ロング(買い)
        # 1: フラット(様子見)
        # 2: ショート(売り)

DQN(Deep Q-Network)は、強化学習アルゴリズムの一種で、特に離散的な行動空間(action space)を持つ問題に使用されます。
AIが取れる行動は3つだけです:

  1. ロング(買い):株価が上がると利益、下がると損失
  2. フラット(様子見):何もしない、利益も損失もなし
  3. ショート(売り):株価が下がると利益、上がると損失

シンプルですが、これだけでも十分に複雑な戦略を学習できます。

観測空間:AIが「見る」情報

self.observation_space = spaces.Box(
    low=-np.inf,
    high=np.inf,
    shape=(window_size, len(self.feature_cols)),
    dtype=np.float32,
)

AIは直近130日間(window_size=130)の29種類の指標(feature_cols)を同時に見ることができます。これは人間にはなかなかできない情報処理能力ですね。

前回の記事で準備した特徴量がここで活用されます:

self.feature_cols = [
    "Open", "SMA_5", "SMA_25", "SMA_75",
    "Upper_3σ", "Upper_2σ", "Upper_1σ",
    "Lower_3σ", "Lower_2σ", "Lower_1σ",
    "偏差値25", "Upper2_3σ", "Upper2_2σ", "Upper2_1σ",
    "Lower2_3σ", "Lower2_2σ", "Lower2_1σ",
    "偏差値75", "RSI_14", "RSI_22",
    "MACD", "MACD_signal", "Japan_10Y_Rate", "US_10Y_Rate",
    "ATR_5", "ATR_25", "RCI_9", "RCI_26", "VIX",
]

株価、移動平均線、ボリンジャーバンド、RSI、MACD、金利、VIXなど、人間のトレーダーが使う指標をすべてAIに提供しています。

データの正規化:公平な比較のために

def _get_observation(self):
    obs = []
    for col in self.feature_cols:
        # 現在のウィンドウの値を取得
        window = self.data[col][
            self.current_step - self.window_size : self.current_step
        ]
        # MinMax法で0〜1の範囲に正規化
        min_val = np.min(window)
        max_val = np.max(window)
        if max_val - min_val == 0:
            norm = np.zeros_like(window)
        else:
            norm = (window - min_val) / (max_val - min_val)
        obs.append(norm.reshape(-1, 1))
    
    observation = np.concatenate(obs, axis=1).astype(np.float32)
    return observation

この正規化処理が重要なポイントです。例えば、日経平均が30,000円でRSIが70の場合、どちらが「高い」のでしょうか?絶対値では日経平均の方が大きいですが、投資判断としてはRSIの70の方が重要かもしれません。

MinMax法を使って各指標を0〜1の範囲に正規化することで、異なるスケールの指標を公平に扱えるようになります。また、直近130日間での相対的な位置を示すため、「最近の相場環境における相対的な高さ・低さ」を捉えられます。

報酬設計:AIのモチベーション

強化学習で最も重要なのが「報酬設計」です。AIは報酬を最大化するように学習するため、報酬の設計が学習結果を大きく左右します。

基本的な報酬:対数リターン

def step(self, action):
    old_balance = float(self.balance)
    
    # 当日と翌日の株価を取得
    price_today = self.data["Open"][self.current_step]
    price_tomorrow = self.data["Open"][self.current_step + 1]
    ret = (price_tomorrow - price_today) / price_today
    
    # ポジションに応じて資産を更新
    if action == 0:  # ロング(買い)
        self.balance *= 1 + ret
    elif action == 2:  # ショート(売り)
        self.balance *= 1 - ret
    elif action == 1:  # フラット(様子見)
        pass
    
    # 1日分の報酬:対数リターン
    reward = np.log(self.balance / old_balance)

基本的な報酬は「対数リターン」を使用しています。これは金融工学でよく使われる指標で、以下の特徴があります。

  • 複利効果を考慮:連続的な利益の積み重ねを適切に評価
  • リスク調整:大きな利益も大きな損失も、対数変換により適度に抑制
  • 数学的性質:加法性があり、期間をまたいだ計算が容易

中期的視点の報酬:先を見据えた判断

# 中期的視点の報酬追加
discount_factor = 0.3
days = 3
if self.current_step + days  len(self.data["Open"]):
    price_future = self.data["Open"][self.current_step + days]
else:
    price_future = self.data["Open"][-1]

# 行動に応じた中期リターンの計算
if action == 0:  # ロング
    mid_return = np.log(price_future / price_today)
elif action == 2:  # ショート
    mid_return = -np.log(price_future / price_today)
else:  # フラット
    mid_return = 0.0

# 割引率を考慮して中期的な報酬を追加
reward += discount_factor * mid_return

この実装の面白いところは、「3日後の株価」も考慮している点です。短期的な利益だけでなく、中期的な視点も持たせることで、より安定した取引戦略を学習できるようになります。

割引率(discount_factor=0.3)により、中期的な報酬の影響を調整しています。これにより、「今日の利益」と「3日後の利益」のバランスを取っています。

現実的な制約:取引コストとリスク管理

# 前回のポジションと異なる場合は手数料を引く
if action != self.prev_action:
    cost = self.balance * self.transaction_cost
    self.balance -= cost
    self.trade_count += 1

# エピソード終了条件
if (
    (self.current_step >= len(self.data["Open"]) - 1)
    or (self.balance  self.initial_balance * self.risk_limit)
    or (self.balance  0)
):
    done = True

現実の取引では手数料がかかります。ポジションを変更するたびにtransaction_cost(例:0.1%)の手数料を差し引くことで、頻繁な取引を抑制し、より効率的な戦略を学習させます。

また、資産が初期資産の60%(risk_limit=0.6)を下回ると取引終了となります。これにより、破産リスクを避ける保守的な戦略を学習できます。

ResNetによる特徴抽出:時系列データの深い理解

今回のDQNの心臓部となるのが、ResNet(Residual Network)を使った特徴抽出器です。ResNetは画像認識で有名ですが、時系列データにも応用できます。

ResNetの基本アイデア:残差学習

class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2
        
        self.conv1 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(channels)
    
    def forward(self, x):
        residual = x  # 入力を保存
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += residual  # 残差接続:入力を直接加算
        out = self.relu(out)
        
        return out

ResNetの核心は「残差接続」です。通常のニューラルネットワークは入力を変換して出力しますが、ResNetは「入力からの変化分(残差)」を学習します。

これにより

  • 勾配消失問題の解決:深いネットワークでも学習が安定
  • 恒等写像の学習:必要なければ入力をそのまま出力
  • 細かいパターンの検出:微細な変化も捉えやすい

株価データのような時系列では、「前日からの変化」「トレンドの変化」など、変化分の情報が重要です。ResNetはこれらを効果的に捉えられます。

時系列用ResNetの実装

class ResNetFeatures(BaseFeaturesExtractor):
    def __init__(
        self, observation_space: gym.spaces.Box, features_dim=128, num_blocks=3
    ):
        super(ResNetFeatures, self).__init__(
            observation_space, features_dim=features_dim
        )
        
        self.window_size = observation_space.shape[0]  # 130日
        self.input_dim = observation_space.shape[1]    # 29指標
        
        # 入力層:29指標 → 128次元に変換
        self.input_projection = nn.Sequential(
            nn.Linear(self.input_dim, features_dim), nn.ReLU()
        )
        
        # 3つの残差ブロック
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(features_dim) for _ in range(num_blocks)]
        )
        
        # グローバル平均プーリング:時系列全体を要約
        self.pool = nn.AdaptiveAvgPool1d(1)

この実装では

  1. 入力射影:29種類の指標を128次元の特徴空間に変換
  2. 残差ブロック:3つのブロックで段階的に特徴を抽出
  3. グローバル平均プーリング:130日分の情報を1つのベクトルに要約

人間が「チャート全体の印象」を掴むように、AIも時系列全体から重要な特徴を抽出します。

なぜResNetを選んだのか?

株価データには以下の特徴があります

  • 長期依存性:数ヶ月前の出来事が今の株価に影響
  • 複雑なパターン:単純な線形関係では捉えられない
  • ノイズの多さ:重要な信号とノイズの区別が困難

ResNetは

  • 深い学習:複雑なパターンを段階的に学習
  • 安定した学習:残差接続により勾配が安定
  • 柔軟性:必要に応じて複雑さを調整可能

これらの特徴により、株価の複雑なパターンを効果的に学習できます。
色々試行錯誤した結果、Resnetが成績がよかったので今はこのアーキテクチャに落ち着きました。

DQNモデルの設定:学習パラメータの調整

model = DQN(
    "MlpPolicy",
    train_env,
    policy_kwargs=policy_kwargs,
    exploration_final_eps=0.03,      # 最終的な探索率
    exploration_fraction=0.3,        # 探索期間の割合
    learning_rate=1e-5,              # 学習率
    verbose=1,
    device=device,
)

探索と活用のバランス

  • exploration_final_eps=0.03:学習後期でも3%の確率でランダム行動
  • exploration_fraction=0.3:全学習期間の30%を探索期間とする

強化学習では「探索」と「活用」のバランスが重要です。探索が少ないと局所最適に陥り、多すぎると学習が進みません。

学習率の設定

  • learning_rate=1e-5:非常に小さな学習率

株価データは非常にノイズが多いため、小さな学習率でゆっくりと学習させます。急激な変化よりも、安定した学習を重視しています。

学習プロセス:35万ステップ

print("エージェントの学習開始...")
model.learn(total_timesteps=350000, callback=checkpoint_callback, progress_bar=True)
print("学習完了!")

35万ステップの学習は、約27年分のデータ(1997-2024年)を何度も繰り返し学習することを意味します。

  1. 初期段階:ランダムに行動し、基本的なパターンを学習
  2. 中期段階:徐々に有効な戦略を発見し、探索を減らす
  3. 後期段階:学習した戦略を洗練し、安定した性能を目指す

チェックポイント機能

checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./",
    name_prefix=f"nikkei_cp_{start}_{end}",
)

1万ステップごとにモデルを保存することで、異なるステップの比較ができます。また推論モードの時にこれらを読み込んで使います。

学習中の出力:AIの成長を見守る

print(
    f"アクション[0:買,1:待,2:売]:{action}, ステップ:{self.num_step}, "
    f"累積リワード:{self.sum_reward:.4f}, 資産:{int(self.balance)}, リターン:{ret}, "
    f"トレード回数:{self.trade_count}, 明日: {int(price_tomorrow)},株価:{int(price_today)}"
)

学習中は各エピソード終了時に詳細な情報が出力されます。

  • アクション:最後に取った行動(0:買い、1:様子見、2:売り)
  • 累積リワード:そのエピソードでの総報酬
  • 資産:最終的な資産額
  • トレード回数:取引回数(頻繁すぎないかチェック)

これらの情報から、AIの学習進捗を把握できます。

実際の学習設定

train_env = DummyVecEnv([
    lambda: NikkeiEnv(
        train_data,
        window_size=130,           # 130日間のデータを観測
        transaction_cost=0.001,   # 取引手数料0.1%
        risk_limit=0.6,            # 資産が60%を下回ると終了
        trade_penalty=0.000000,    # 取引ペナルティなし
    )
])

この設定では

  • 長期的視点:130日間という長期間のデータを参考にする
  • 現実的コスト:0.01%の取引手数料を考慮
  • リスク管理:40%の損失で取引停止
  • 取引頻度:ペナルティなしで自然な取引頻度を学習

取引ペナルティを設けた場合、学習があまりうまく進まなかったので、ここでは0にしています。

実行の様子

トレーニングを実行すると、5万ステップぐらいから資金が増えるようになり、その後シミュレーター上でだんだんと学習が進み100万円からスタートする資金が数百兆円になるのが確認できます。

...

アクション[0:買,1:待,2:売]:2, ステップ:6196, 累積リワード:27.4626, 資産:667814885183, リターン:0.0, 
トレード回数:2335, 明日: 33458,株価:33458

アクション[0:買,1:待,2:売]:2, ステップ:6196, 累積リワード:29.0031, 資産:3770660314288, リターン:0.0, 
トレード回数:2063, 明日: 33458,株価:33458

アクション[0:買,1:待,2:売]:2, ステップ:6196, 累積リワード:32.4817, 資産:17134537386443, リターン:0.0, 
トレード回数:1732, 明日: 33458,株価:33458

アクション[0:買,1:待,2:売]:0, ステップ:6196, 累積リワード:35.4314, 資産:115568295256266, リターン:0.0, 
トレード回数:1335, 明日: 33458,株価:33458

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.0755   |
| time/               |          |
|    episodes         | 32       |
|    fps              | 222      |
|    time_elapsed     | 450      |
|    total_timesteps  | 100071   |
| train/              |          |
|    learning_rate    | 1e-05    |
|    loss             | 0.000118 |
|    n_updates        | 24992    |
----------------------------------

アクション[0:買,1:待,2:売]:0, ステップ:6196, 累積リワード:38.4772, 資産:822027980355331, リターン:0.0, 
トレード回数:1053, 明日: 33458,株価:33458

アクション[0:買,1:待,2:売]:0, ステップ:6196, 累積リワード:41.7966, 資産:4038394362343677, リターン:0.0, 
トレード回数:959, 明日: 33458,株価:33458

...

まとめ:AIトレーダーの学習環境

今回は、DQNを用いた株価テクニカル分析AIの学習環境とトレーニングプロセスについて解説しました。重要なポイントをまとめると:

環境設計のポイント

  1. シンプルな行動空間:買い・様子見・売りの3択
  2. 豊富な観測情報:130日×29指標の多次元データ
  3. 現実的な制約:取引コスト、リスク制限の考慮
  4. 適切な報酬設計:対数リターン+中期的視点

技術的な工夫

  1. ResNet特徴抽出:時系列データの深い理解
  2. データ正規化:異なるスケールの指標を公平に扱う
  3. 残差学習:微細な変化パターンの検出
  4. 探索と活用:学習効率と性能のバランス

学習プロセス

  1. 大量のデータ:27年分の株価データで学習
  2. 長期間の学習:35万ステップの反復学習
  3. 段階的な改善:探索から活用への移行
  4. 継続的な監視:チェックポイントによる進捗管理

AIトレーダーは、人間には不可能な大量の情報処理と、何万回もの取引経験を通じて、独自の投資戦略を学習していきます。ただし、過去のデータで学習したAIが未来の相場でも通用するかは別問題です。

次回は、学習したモデルの評価方法と、実際の性能について詳しく解説する予定です。

コード全文

main.py

import numpy as np
import gym
from gym import spaces
import torch
import torch.nn as nn
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.callbacks import CheckpointCallback
from data import generate_env_data


# ──────────────────────────────
# 1. 改善版 環境 (NikkeiEnv)
class NikkeiEnv(gym.Env):
    """
    日経225の終値・出来高データを用いたシンプルなトレーディング環境
    ・観測:直近 window_size 日間の各種特徴(例:始値、出来高)を、それぞれウィンドウ初日を基準に正規化
    ・行動:2: ショート, 1: フラット, 0: ロング
    ・取引手数料:前回ポジションと異なる場合、現在の残高に対して transaction_cost % の費用がかかる
    ・報酬:1日分の相対的な対数リターン(手数料考慮済み)
    ・エピソード終了:データ終了、あるいは資産残高が初期資産の risk_limit 未満になった場合
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        df,
        window_size=30,
        transaction_cost=0.001,
        risk_limit=0.5,
        trade_penalty=0.002,
    ):
        super(NikkeiEnv, self).__init__()

        # 既存の初期化処理…
        df = df.dropna().reset_index(drop=True)
        self.df = df
        self.feature_cols = [
            "Open",
            "SMA_5",
            "SMA_25",
            "SMA_75",
            "Upper_3σ",
            "Upper_2σ",
            "Upper_1σ",
            "Lower_3σ",
            "Lower_2σ",
            "Lower_1σ",
            "偏差値25",
            "Upper2_3σ",
            "Upper2_2σ",
            "Upper2_1σ",
            "Lower2_3σ",
            "Lower2_2σ",
            "Lower2_1σ",
            "偏差値75",
            "RSI_14",
            "RSI_22",
            "MACD",
            "MACD_signal",
            "Japan_10Y_Rate",
            "US_10Y_Rate",
            "ATR_5",
            "ATR_25",
            "RCI_9",
            "RCI_26",
            "VIX",
        ]

        self.data = {col: self.df[col].values for col in self.feature_cols}
        self.window_size = window_size
        self.current_step = window_size  # 最初の window_size 日は観測用

        # 行動空間(例:0→Long, 1→Flat)
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(window_size, len(self.feature_cols)),
            dtype=np.float32,
        )

        # 資産関係の初期設定
        self.initial_balance = 1_000_000
        self.balance = self.initial_balance
        self.equity_curve = [self.balance]
        self.sum_reward = 0
        self.num_step = 0

        self.transaction_cost = transaction_cost  # 例:0.001 → 0.1%
        self.risk_limit = risk_limit  # 資金が初期の risk_limit 未満なら終了

        # 新たに取引ペナルティと取引数を管理する変数を設定
        self.trade_penalty = trade_penalty  # 1回の取引ごとに与える追加ペナルティ
        self.trade_count = 0  # 累積の取引回数

        # エピソード開始時のポジション
        self.prev_action = 1  # 例:フラット(何もしない)

    def reset(self):
        self.current_step = self.window_size
        self.balance = self.initial_balance
        self.equity_curve = [self.balance]
        self.trade_count = 0  # 取引数もリセット
        self.prev_action = 1
        self.sum_reward = 0
        self.num_step = 0
        return self._get_observation()

    def _get_observation(self):
        # MIXMAX法
        obs = []
        for col in self.feature_cols:
            # 現在のウィンドウの値を取得
            window = self.data[col][
                self.current_step - self.window_size : self.current_step
            ]
            # ウィンドウ内の最小値・最大値を計算: MinMax法
            min_val = np.min(window)
            max_val = np.max(window)
            # ゼロ除算を防ぐため、最大値と最小値が等しい場合の処理
            if max_val - min_val == 0:
                norm = np.zeros_like(window)
            else:
                norm = (window - min_val) / (max_val - min_val)
            # shape (window_size, 1) に整形してリストに追加
            obs.append(norm.reshape(-1, 1))
        # 各特徴量ごとの正規化済みデータを連結 → shape: (window_size, len(feature_cols))
        observation = np.concatenate(obs, axis=1).astype(np.float32)
        return observation

    def step(self, action):
        old_balance = float(self.balance)
        self.num_step += 1

        # 当日と翌日の株価(ここではOpen値)を取得
        price_today = self.data["Open"][self.current_step]
        if self.current_step + 1  len(self.data["Open"]):
            price_tomorrow = self.data["Open"][self.current_step + 1]
        else:
            price_tomorrow = price_today

        ret = (price_tomorrow - price_today) / price_today

        # 保有ポジションごとに資産を更新(0: ロング, 1: フラット, 2: ショート)
        if action == 0:  # ロングの場合
            self.balance *= 1 + ret
        elif action == 2:  # ショートの場合
            self.balance *= 1 - ret
        elif action == 1:  # フラットの場合(インフレペナルティなどがあれば適用)
            pass

        # 前回のポジションと異なる場合は手数料を引く
        if action != self.prev_action:
            cost = self.balance * self.transaction_cost
            self.balance -= cost
            self.trade_count += 1

        # 1日分のリワード:対数リターン
        reward = np.log(self.balance / old_balance)
        # ── ここから中期的視点の報酬追加 ──
        # 3日後の株価を使用(もし存在しない場合は末尾の値を使用)
        discount_factor = 0.3
        days = 3
        if self.current_step + days  len(self.data["Open"]):
            price_future = self.data["Open"][self.current_step + days]
        else:
            price_future = self.data["Open"][-1]

        # 行動に応じた中期リターンの計算
        if action == 0:  # ロングの場合
            mid_return = np.log(price_future / price_today)
        elif action == 2:  # ショートの場合
            mid_return = -np.log(price_future / price_today)
        else:  # フラットの場合:中期的なリターンは0とする
            mid_return = 0.0

        # 割引率を考慮して中期的な報酬を追加
        reward += discount_factor * mid_return
        # ── ここまで中期的視点の報酬追加 ──

        # エピソード終了条件判定
        done = False
        self.prev_action = action
        self.sum_reward += reward

        if (
            (self.current_step >= len(self.data["Open"]) - 1)
            or (self.balance  self.initial_balance * self.risk_limit)
            or (self.balance  0)
        ):
            print(
                f"アクション[0:買,1:待,2:売]:{action}, ステップ:{self.num_step}, "
                f"累積リワード:{self.sum_reward:.4f}, 資産:{int(self.balance)}, リターン:{ret}, "
                f"トレード回数:{self.trade_count}, 明日: {int(price_tomorrow)},株価:{int(price_today)}"
            )
            done = True

        obs = self._get_observation() if not done else None
        info = {"trade_count": self.trade_count}
        self.equity_curve.append(float(self.balance))
        self.current_step += 1
        return obs, reward, done, info

    def render(self, mode="human"):
        # 必要に応じて可視化ロジックを実装可能
        pass

    def get_equity_curve(self):
        return self.equity_curve


class ResNetFeatures(BaseFeaturesExtractor):
    """
    1D ResNet ベースの特徴抽出器:
    入力時系列(window_size × input_dim)に対して1次元畳み込みと残差接続を使用
    """

    def __init__(
        self, observation_space: gym.spaces.Box, features_dim=128, num_blocks=3
    ):
        """
        Args:
            observation_space: 観測空間
            features_dim: 出力特徴量の次元数
            num_blocks: ResNetブロックの数
        """
        super(ResNetFeatures, self).__init__(
            observation_space, features_dim=features_dim
        )

        self.window_size = observation_space.shape[0]  # 時系列長
        self.input_dim = observation_space.shape[1]  # 入力特徴数

        # 入力層: (batch, window_size, input_dim) -> (batch, features_dim, window_size)
        self.input_projection = nn.Sequential(
            nn.Linear(self.input_dim, features_dim), nn.ReLU()
        )

        # 残差ブロック
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(features_dim) for _ in range(num_blocks)]
        )

        # グローバル平均プーリング
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, observations):
        # observations の shape: (batch, window_size, input_dim)
        batch_size = observations.size(0)

        # 特徴量次元に射影
        x = self.input_projection(observations)  # (batch, window_size, features_dim)
        x = x.transpose(1, 2)  # (batch, features_dim, window_size) に変換

        # 残差ブロックを通す
        for block in self.res_blocks:
            x = block(x)

        # グローバル平均プーリング
        x = self.pool(x).view(batch_size, -1)  # (batch, features_dim)
        return x


class ResidualBlock(nn.Module):
    """
    1D ResNet の残差ブロック
    """

    def __init__(self, channels, kernel_size=3):
        super(ResidualBlock, self).__init__()
        padding = kernel_size // 2

        self.conv1 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual  # 残差接続
        out = self.relu(out)

        return out


# ──────────────────────────────
# 3. データのダウンロードと環境の作成
if __name__ == "__main__":
    print("データをダウンロード中...")
    # Yahoo Finance から日経225 (^N225) のヒストリカルデータを取得
    start = "1997-01-01"
    end = "2024-01-01"
    train_data = generate_env_data(start, end, ticker="^N225")

    # 窓長(直近○日分のデータを入力とする)
    window_size = 130
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 学習用環境(stable-baselines3 は vectorized environment を要求するため DummyVecEnv でラップ)
    train_env = DummyVecEnv(
        [
            lambda: NikkeiEnv(
                train_data,
                window_size=window_size,
                transaction_cost=0.001,
                risk_limit=0.6,
                trade_penalty=0.000000,
            )
        ]
    )

    policy_kwargs = dict(
        features_extractor_class=ResNetFeatures,
        features_extractor_kwargs=dict(
            features_dim=128, num_blocks=3  # ResNetの残差ブロック数
        ),
    )

    model = DQN(
        "MlpPolicy",
        train_env,
        policy_kwargs=policy_kwargs,
        exploration_final_eps=0.03,
        exploration_fraction=0.3,
        learning_rate=1e-5,
        verbose=1,
        device=device,
    )

    # チェックポイントコールバックの作成
    checkpoint_callback = CheckpointCallback(
        save_freq=10000,
        save_path="./",  # モデルを保存するディレクトリ(存在するか、事前に作成してください)
        name_prefix=f"nikkei_cp_{start}_{end}",
    )

    print("エージェントの学習開始...")
    model.learn(total_timesteps=350000, callback=checkpoint_callback, progress_bar=True)
    print("学習完了!")
    model.save(f"nikkei_cp_{start}_{end}")

参考文献





Source link

Views: 2

RELATED ARTICLES

返事を書く

あなたのコメントを入力してください。
ここにあなたの名前を入力してください

- Advertisment -