From 95132f18ef5c909f077153bc863d03cc5bacbffd Mon Sep 17 00:00:00 2001 From: amilabluetelecoms <amila@bluetelecoms.com> Date: Tue, 15 Nov 2022 02:58:56 +0000 Subject: [PATCH] lstm --- lstm_model.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/lstm_model.py b/lstm_model.py index e69de29..3037d0f 100644 --- a/lstm_model.py +++ b/lstm_model.py @@ -0,0 +1,29 @@ +from pandas import DataFrame +from pandas import concat + +def series_to_supervised(data, n_in=1, n_out=1, dropnan=True): + n_vars = 1 if type(data) is list else data.shape[1] + df = DataFrame(data) + cols, names = list(), list() + # input sequence (t-n, ... t-1) + for i in range(n_in, 0, -1): + cols.append(df.shift(i)) + names += [('var%d(t-%d)' % (j+1, i)) for j in range(n_vars)] + # forecast sequence (t, t+1, ... t+n) + for i in range(0, n_out): + cols.append(df.shift(-i)) + if i == 0: + names += [('var%d(t)' % (j+1)) for j in range(n_vars)] + else: + names += [('var%d(t+%d)' % (j+1, i)) for j in range(n_vars)] + # put it all together + agg = concat(cols, axis=1) + agg.columns = names + # drop rows with NaN values + if dropnan: + agg.dropna(inplace=True) + return agg + + + + -- GitLab