Prophet 初学笔记[通俗易懂]
生活随笔
收集整理的這篇文章主要介紹了
Prophet 初学笔记[通俗易懂]
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
本文介紹 Prophet 模型的簡單調用。
(一)設置日誌不輸出(在文件描述符層面屏蔽 stdout 與 stderr)
import os
import sys
class SuppressStdoutStderr(object):
    """
    Context manager that silences stdout and stderr at the file-descriptor level.

    File descriptors 1 and 2 are pointed at ``os.devnull`` while the block runs,
    so output is suppressed even when it is written by compiled C/Fortran code
    that bypasses ``sys.stdout``/``sys.stderr``. The Python-level streams are
    flushed on entry and on exit, so buffered text is neither swallowed before
    the block starts nor leaked out after it ends.

    Exceptions are not suppressed: ``__exit__`` restores the real descriptors
    and returns None, so an exception raised inside the block propagates.

    All descriptors are acquired in ``__enter__`` and released in ``__exit__``.
    An instance that is never entered holds no descriptors, and one instance may
    be reused for several sequential (not nested) ``with`` blocks.
    """

    def __init__(self):
        # Descriptors are opened in __enter__; None while the manager is inactive.
        self.null_fds = None
        self.save_fds = None

    @staticmethod
    def _flush_std_streams():
        """Flush sys.stdout and sys.stderr, skipping either one that is None."""
        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    def __enter__(self):
        # Push out text printed before the block so the redirect cannot swallow it.
        self._flush_std_streams()
        # A pair of null files that will receive stdout (1) and stderr (2).
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Duplicates of the real stdout/stderr so they can be restored on exit.
        self.save_fds = (os.dup(1), os.dup(2))
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        try:
            # Anything still buffered was produced inside the block: send it to
            # the null files before the real descriptors come back.
            self._flush_std_streams()
        finally:
            try:
                # Reassign the real stdout/stderr back to (1) and (2).
                os.dup2(self.save_fds[0], 1)
                os.dup2(self.save_fds[1], 2)
            finally:
                # Always release every descriptor, even if restoring failed.
                for fd in (*self.null_fds, *self.save_fds):
                    os.close(fd)
                self.null_fds = None
                self.save_fds = None
(二)Prophet 預測模型類:結構與官方 Prophet 類相似,但並不繼承 Prophet
1. 初始化
from datetime import datetime, timedelta
from typing import Tuple
import numpy as np
import pandas as pd
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation
STD_D_STR = "%Y-%m-%d"  # date format; note '%m' is month while '%M' is minute


def _to_datetime(value):
    """Parse *value* with ``STD_D_STR`` when it is a string; otherwise return it unchanged."""
    return datetime.strptime(value, STD_D_STR) if isinstance(value, str) else value


class ProphetPredictor(object):
    """
    Prophet forecasting wrapper whose time steps are days.

    Structurally similar to ``Prophet``, but it holds a Prophet model instead of
    subclassing it.
    """

    def __init__(self, x_train: pd.DataFrame,
                 trg_st_dt: datetime, tm_step: int, his_st_dt: datetime, his_en_dt: datetime,
                 cv_horizon: str, cv_period: str, cv_initial: str,
                 n_changepoints=None, changepoint_range=0.7,
                 yearly_seasonality=False, weekly_seasonality=True, daily_seasonality=False,
                 holidays=None, seasonality_mode='multiplicative',
                 seasonality_prior_scale=10, holidays_prior_scale=0, changepoint_prior_scale=0.05):
        """
        Store the training data, the date windows and the Prophet / CV settings.

        Every date argument accepts either a ``datetime`` or a ``'YYYY-MM-DD'``
        string (``STD_D_STR``). Strings are parsed; anything else is kept as given.

        :param x_train: data set; only its 'ds' and 'y' columns are kept (as a copy)
        :param trg_st_dt: first date to forecast
        :param tm_step: number of days to forecast, starting at trg_st_dt
        :param his_st_dt: first date of the training window
        :param his_en_dt: last date of the training window; must be at least one
            day before trg_st_dt
        :param cv_horizon: cross-validation ``horizon``, e.g. '3 days'
        :param cv_period: cross-validation ``period``, e.g. '3 days'
        :param cv_initial: cross-validation ``initial``, e.g. '3 days'
        :param n_changepoints: maximum number of changepoints
        :param changepoint_range: portion of the history in which changepoints may appear
        :param yearly_seasonality: enable yearly seasonality
        :param weekly_seasonality: enable weekly seasonality
        :param daily_seasonality: enable daily seasonality
        :param holidays: holidays or other special dates
        :param seasonality_mode: {'additive', 'multiplicative'}
        :param seasonality_prior_scale: strength of the seasonality components
        :param holidays_prior_scale: strength of the holiday components
        :param changepoint_prior_scale: flexibility of automatic changepoint
            selection; larger values make changepoints appear more readily
        :raises ValueError: if his_en_dt is not at least one day before trg_st_dt
        """
        # Keyword arguments forwarded verbatim to Prophet(**self.params).
        self.params = {
            "n_changepoints": n_changepoints,
            "changepoint_range": changepoint_range,
            "yearly_seasonality": yearly_seasonality,
            "weekly_seasonality": weekly_seasonality,
            "daily_seasonality": daily_seasonality,
            "holidays": holidays,
            "seasonality_mode": seasonality_mode,
            "seasonality_prior_scale": seasonality_prior_scale,
            "holidays_prior_scale": holidays_prior_scale,
            "changepoint_prior_scale": changepoint_prior_scale
        }
        self.trg_st_dt = _to_datetime(trg_st_dt)
        self.tm_step = tm_step
        # Last forecast date (inclusive).
        self.trg_en_dt = self.trg_st_dt + timedelta(days=tm_step - 1)
        self.his_st_dt = _to_datetime(his_st_dt)
        self.his_en_dt = _to_datetime(his_en_dt)
        # Lead time = first target date - last history date - 1 day
        # (forecasting the very next day has a lead time of 0).
        self.ahead = (self.trg_st_dt - self.his_en_dt).days - 1
        if self.ahead < 0:
            # A negative lead time would shift the forecast window into the history.
            raise ValueError(
                "his_en_dt ({}) must be at least one day before trg_st_dt ({})".format(
                    self.his_en_dt, self.trg_st_dt))
        self.x_train = x_train[['ds', 'y']].copy()
        self.model = None
        self.cv_horizon = cv_horizon
        self.cv_period = cv_period
        self.cv_initial = cv_initial
        # Initial best-MAPE threshold, in percent.
        self.map_err = 100
2. 模型訓練
def fit(self):
    """
    Fit a fresh Prophet model on the history window.

    Side effects: ``self.x_train`` is replaced by the rows whose 'ds' lies in
    ``[his_st_dt, his_en_dt]`` (both formatted with ``STD_D_STR``), with the
    index reset. ``self.model`` becomes a new ``Prophet(**self.params)`` that is
    fitted on that frame while stdout/stderr are silenced.

    :return: None
    """
    lower = datetime.strftime(self.his_st_dt, STD_D_STR)
    upper = datetime.strftime(self.his_en_dt, STD_D_STR)
    in_window = (lower <= self.x_train['ds']) & (self.x_train['ds'] <= upper)
    self.x_train = self.x_train[in_window].reset_index(drop=True)

    model = Prophet(**self.params)
    self.model = model
    with SuppressStdoutStderr():
        model.fit(df=self.x_train)
3. 交叉驗證
def cv(self, params=None) -> float:
    """
    Cross-validate a Prophet model on the history window and return its MAPE.

    Only rows with ``his_st_dt <= ds <= his_en_dt`` are used. This is the same
    window that ``fit`` trains on, so tuning never sees data after the training
    end date, even when this method runs before ``fit``. ``self.x_train``
    itself is not modified; ``self.model`` is replaced by the model fitted here.

    :param params: Prophet keyword arguments; None (the default) means
        ``self.params``. Grid search passes candidate dicts here.
    :return: map_err: mean absolute percentage error (in percent) over all
        cross-validation forecasts. NOTE(review): rows with y == 0 make this
        inf/nan.
    """
    params_ = self.params if params is None else params

    # Restrict to the training window to avoid leaking later observations.
    his_st = datetime.strftime(self.his_st_dt, STD_D_STR)
    his_en = datetime.strftime(self.his_en_dt, STD_D_STR)
    ds = self.x_train['ds']
    x_hist = self.x_train[(his_st <= ds) & (ds <= his_en)].reset_index(drop=True)

    self.model = Prophet(**params_)
    with SuppressStdoutStderr():
        self.model.fit(x_hist)
    cv_result = cross_validation(self.model,
                                 horizon=self.cv_horizon, period=self.cv_period, initial=self.cv_initial)
    # MAPE: the denominator is |y| so negative actuals cannot cancel errors out.
    abs_pct_err = np.abs(cv_result['yhat'] - cv_result['y']) / np.abs(cv_result['y'])
    map_err = np.mean(abs_pct_err) * 100
    return map_err
4. 網格尋參
def grid_search(self) -> pd.DataFrame:
    """
    Exhaustively search a fixed hyper-parameter grid by cross-validated MAPE.

    Every combination of the five searched parameters is scored with
    ``self.cv``. Whenever a score beats ``self.map_err``, that score and the
    candidate parameter dict become ``self.map_err`` / ``self.params``.

    Parameters that are not searched (the seasonality switches, ``holidays``,
    ``holidays_prior_scale``) are copied from ``self.params`` as it was when the
    search started. Settings chosen at construction time are therefore kept
    rather than overwritten with hard-coded values.

    :return: df_search: one row per combination, with the searched values and the
        score. The score column is named 'mse' for backward compatibility, but it
        holds the MAPE (percent) returned by ``self.cv``.
    """
    from itertools import product

    # Searched parameters; dict order fixes both the iteration order and the
    # column order of the returned frame.
    grid = {
        "n_changepoints": list(range(2, 7)),
        "changepoint_range": [i / 10 for i in range(5, 10)],
        "seasonality_mode": ["additive", "multiplicative"],
        "seasonality_prior_scale": [0.1, 0.5, 1, 5, 10],
        "changepoint_prior_scale": [0.1, 0.5, 1, 5, 10],
    }
    base_params = dict(self.params)

    list_search = []
    for values in product(*grid.values()):
        params = {**base_params, **dict(zip(grid, values))}
        score = self.cv(params=params)
        list_search.append([*values, score])
        if score < self.map_err:
            self.map_err, self.params = score, params
            print("current best mape: {0}; current params: {1}".format(round(self.map_err, 4),
                                                                       params))
    df_search = pd.DataFrame(data=list_search, columns=[*grid, 'mse'])
    return df_search
5. 預測
def predict(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Forecast the target window with the fitted model.

    The future frame covers the training history plus ``ahead + tm_step``
    further periods, i.e. the lead-time gap followed by the target window.
    The 'ds' column of both outputs is converted to 'YYYY-MM-DD' strings.

    :return: df_predict_y: the last ``tm_step`` rows, columns
        ['ds', 'yhat', 'yhat_lower', 'yhat_upper'] (the forecast)
    :return: df_history_y: ['ds', 'yhat'] for every row before the lead-time
        gap and target window (the in-sample fit)
    :raises RuntimeError: if no model has been fitted yet (``self.model`` is None)
    """
    if self.model is None:
        raise RuntimeError("model is not fitted; call fit() before predict()")
    n_future = self.ahead + self.tm_step
    df_future = self.model.make_future_dataframe(periods=n_future,
                                                 include_history=True).dropna().reset_index(drop=True)
    df_predict = self.model.predict(df=df_future)
    df_predict['ds'] = df_predict['ds'].apply(lambda x: str(x.date()))
    df_predict_y = df_predict[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(self.tm_step)
    # Explicit positional end: a `[:-n]` slice selects nothing when n == 0.
    n_history = max(len(df_predict) - n_future, 0)
    df_history_y = df_predict[['ds', 'yhat']].iloc[:n_history]
    return df_predict_y, df_history_y
(三)其他嘗試
添加季節性組件:
# Optional extra seasonality component on the Prophet model held in self.model.
# NOTE(review): presumably this has to run on a freshly constructed, not yet
# fitted model (i.e. between `Prophet(...)` and `.fit(...)` in fit/cv) —
# confirm against the Prophet `add_seasonality` documentation.
# NOTE(review): assuming `period` is measured in days (check the Prophet docs),
# 2 * np.pi / 7 ≈ 0.898 days is a sub-daily cycle; if a 7-day cycle was
# intended, period=7 is the likely value — confirm the intent.
self.model.add_seasonality(name='sin', period=2 * np.pi / 7, fourier_order=1)
(四)模型調用示例
1. 網格尋參
# Initialise the predictor; df_input, trg_st_dt, his_st_dt and his_en_dt are
# assumed to have been prepared earlier.
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days')
# Grid search over the hyper-parameters
dts_search = datetime.now()
df_search = prophet_predictor.grid_search()
print("df_search:\n", df_search, '\n')
dte_search = datetime.now()
# total_seconds() also counts whole days; `.seconds + .microseconds / 10 ** 6`
# silently drops them, and a full grid search can run for more than a day.
tm_search = round((dte_search - dts_search).total_seconds(), 3)
print("grid search time: {} s".format(tm_search), '\n')
# Train (grid_search stored the best parameters found in prophet_predictor.params)
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time: {} s".format(tm_train), '\n')
# Predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
2. 給定參數
# Parameter settings (forwarded to ProphetPredictor as keyword arguments)
params = {
    "n_changepoints": 2,
    "changepoint_range": 0.7,
    "seasonality_mode": 'additive',
    "seasonality_prior_scale": 0.5,
    "changepoint_prior_scale": 10,
    "yearly_seasonality": False,
    "weekly_seasonality": True,
    "daily_seasonality": False
}
# Initialise the predictor; df_input, trg_st_dt, his_st_dt and his_en_dt are
# assumed to have been prepared earlier.
prophet_predictor = ProphetPredictor(x_train=df_input,
                                     trg_st_dt=trg_st_dt, tm_step=3, his_st_dt=his_st_dt, his_en_dt=his_en_dt,
                                     cv_horizon='3 days', cv_period='3 days', cv_initial='135 days',
                                     **params)
# Train
dts_train = datetime.now()
prophet_predictor.fit()
dte_train = datetime.now()
# total_seconds() also counts whole days, which `.seconds + .microseconds` drops.
tm_train = round((dte_train - dts_train).total_seconds(), 3)
print("train time: {} s".format(tm_train), '\n')
# Predict
df_predict_y, df_history_y = prophet_predictor.predict()
print("df_predict_y:\n", df_predict_y, '\n')
print("df_history_y:\n", df_history_y, '\n')
參考資料:
https://www.cnblogs.com/fulu/p/13329656.html
https://www.cnblogs.com/zhazhaacmer/p/13786940.html
總結
以上是生活随笔為你收集整理的Prophet 初学笔记[通俗易懂]的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: SAP UI5 XML view lif
- 下一篇: SAP UI5 Manifest fil