見出し画像

Pythonのデータ分析を学んだ初心者女性(50代)が金の価格を予測してみました

Pythonを使ったデータ分析を6カ月学び「金価格の予測」をしてみました。


実行環境

Python
Windows
Chrome
Google Collaboratory

機械学習の手法

SARIMAモデルを使った時系列解析

自分のレベル

文系上がり、女性、50代、プログラミング言語未経験の初心者

作成開始

情報をダウンロード

Nasdaqによって提供されている金の価格変動データ

リンク先の右上ダウンロードボタンをクリックしarchive.zipを取得

ダウンロードしたzipフォルダを解凍しアップロード

Google Collaboratoryを開きarchive.zipをアップロードし解凍

!unzip /content/archive.zip

実装するとデータのcsvファイルが取り込まれる

各取引日の主要な財務指標が含まれている
データセットは以下の列で構成

  • Date: 記録された各取引日の固有の日付

  • Close: 該当日の金の終値

  • Volume: 該当日の金取引量

  • Open: 始値: 当該日の金の始値

  • High: 取引日に記録された金の最高値

  • Low:取引日に記録された金の最安値

データの前処理

indexをDateにセットしindexのnp型を文字列から日付に変更

import warnings
import itertools
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
import datetime # datetime
%matplotlib inline

# 1.データの読み込み
gold_df = pd.read_csv("/content/goldstock.csv", index_col=0)

gold_df = gold_df.set_index('Date')
# indexの型を文字列から日付に変更する
gold_df.index = pd.to_datetime(gold_df.index)
gold_df


重複がないか確認し、重複を削除

if gold_df.index.is_unique==False:
  print("重複あり")

gold_df = gold_df.loc[~gold_df.index.duplicated(), :]
gold_df

日付のリサンプリング

gold_df = gold_df.asfreq("D", method="ffill")
gold_df

不要なColumの削除とデータ整理

gold_df = gold_df.drop(["Volume", "Open", "High",	"Low"], axis=1)

train = gold_df.iloc
train.isna().sum()
train


パラメーターの決定とモデル作成

SARIMAモデルを用いて時系列解析
※ 周期は日ごとのデータであることも考慮してs=7

def selectparameter(DATA, s):
    p = d = q = range(0, 2)
    pdq = list(itertools.product(p, d, q))
    seasonal_pdq = [(x[0], x[1], x[2], s) for x in list(itertools.product(p, d, q))]
    parameters = []
    BICs = np.array([])
    for param in pdq:
        for param_seasonal in seasonal_pdq:
            try:
                mod = sm.tsa.statespace.SARIMAX(DATA,
                                                order=param,
                                                seasonal_order=param_seasonal)
                results = mod.fit()
                parameters.append([param, param_seasonal, results.bic])
                BICs = np.append(BICs, results.bic)
            except:
                continue
    return parameters[np.argmin(BICs)]

# orderはselectparameter関数の0インデックス, seasonal_orderは1インデックスに格納されています
best_params = selectparameter(train, 7)
SARIMA_gold = sm.tsa.statespace.SARIMAX(train["Close"],
    order=best_params[0],
    seasonal_order=best_params[1]).fit()

予測

pred = SARIMA_gold.predict("2024-12-01", "2024-12-01")

グラフの可視化

plt.plot(gold_df)

plt.plot(pred, "r")
plt.show()


作成を終えて

一連の学習を終えて初めて作成しましたがそれぞれ学習したすべてが
頭の中でつながり、より深く理解することができました。
アウトプットをこれからも継続していき自分のスキルにしていこうと思います。



この記事が気に入ったらサポートをしてみませんか?