import numpy as np
from flask import Flask, request, jsonify, render_template
import pickle
from sklearn.preprocessing import normalize
from werkzeug.utils import secure_filename
import scipy.signal as signal
import pandas as pd
import os

##########

# Load the trained classifier once, at import time.
# Resolve the model file: the working-directory copy wins (the original
# behaviour); if it is absent, fall back to the copy next to this module
# so the app can also be launched from another directory.
# NOTE: pickle.load can execute arbitrary code -- model_a1.bin must be a
# trusted artifact and must never be replaced by user-supplied data.
_MODEL_PATH = 'model_a1.bin'
if not os.path.exists(_MODEL_PATH):
    _MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_PATH)

with open(_MODEL_PATH, 'rb') as f_in:
    model_n = pickle.load(f_in)

# Disabled min-max scaling helper, kept for reference only (its call in
# predict_endpoint is commented out). As a bare string it is never executed.
# If re-enabled: a constant signal makes max - min zero, so every value
# becomes NaN (0/0).
"""
def intensityNormalisationFeatureScaling(da, dtype):
    max = da.max()
    min = da.min()

    return ((da - min) / (max - min)).astype(dtype)
"""

def filter_data(da, fs):
    """Remove baseline drift and mains hum from an ECG signal.

    Two 4th-order Butterworth filters are applied causally (``lfilter``)
    along the last axis: a 0.1 Hz high-pass, then a 48-52 Hz band-stop
    (50 Hz +/- 2 Hz).

    Args:
        da: array of samples; filtering runs along the last axis.
        fs: sampling rate in Hz (must exceed 104 Hz so 52 Hz is below Nyquist).

    Returns:
        The filtered array, same shape as ``da``.
    """
    highpass_hz = 0.1
    stop_band_hz = [48, 52]

    # Passing fs lets scipy normalise the cutoffs to the Nyquist frequency.
    hp_num, hp_den = signal.butter(4, highpass_hz, btype='highpass', fs=fs)
    bs_num, bs_den = signal.butter(4, stop_band_hz, btype='stop', fs=fs)

    without_drift = signal.lfilter(hp_num, hp_den, da)
    return signal.lfilter(bs_num, bs_den, without_drift)



# Return difference array
def return_diff_array_table(array, dur):
    """Return the lag-``dur`` column difference of a 2-D array, rescaled and left-padded.

    For every column ``i >= dur`` the output is
    ``(array[:, i] - array[:, i - dur] + 1) / 2``; the first ``dur``
    columns are zero, so the result has the same shape as ``array``.
    (If inputs lie in [0, 1], differences lie in [-1, 1] and ``(x + 1) / 2``
    maps them back to [0, 1] -- presumably the intent; confirm against
    how the model was trained.)

    Args:
        array: 2-D numeric array of shape (rows, cols).
        dur: non-negative lag in columns. A lag of ``cols`` or more
            yields an all-zero result.

    Returns:
        Float array of shape (rows, cols).

    Raises:
        ValueError: if ``dur`` is negative.
    """
    if dur < 0:
        raise ValueError(f"dur must be non-negative, got {dur}")
    n_cols = array.shape[1]
    # Clamp so lags >= width give all-padding instead of an unbound-variable crash.
    lag = min(dur, n_cols)
    # One vectorised slice subtraction replaces the per-column loop and
    # its repeated np.concatenate (quadratic in the number of columns).
    diff = (array[:, lag:] - array[:, :n_cols - lag] + 1) / 2
    # Zero-pad the first `lag` columns so the width matches the input.
    padding = np.zeros((array.shape[0], lag))
    return np.concatenate((padding, diff), axis=1)


# Stack the signal and its difference tables into one multi-channel model input
def return_merge_diff_table(df, diff_dur):
    """Stack a signal table with its lagged-difference tables as channels.

    Args:
        df: numeric NumPy array of shape (rows, cols). Despite the name it
            is an array, not a DataFrame (``.reshape`` is called on it).
        diff_dur: iterable of lags; one difference table (from
            ``return_diff_array_table``) is appended per lag, in order.

    Returns:
        Array of shape (rows, cols, 1 + number_of_lags, 1): channel 0 is
        ``df`` itself, channel k is the difference table for the k-th lag.
    """
    # Use the array's real width rather than a hard-coded 187:
    # reshape(-1, 187) fails for other widths, or silently re-chunks rows
    # whose total size happens to be divisible by 187.
    n_cols = df.shape[-1]
    channels = [df.reshape(-1, n_cols, 1, 1)]
    for dur in diff_dur:
        channels.append(return_diff_array_table(df, dur).reshape(-1, n_cols, 1, 1))
    # Single concatenate along the channel axis instead of growing in a loop.
    return np.concatenate(channels, axis=2)


def predict_endpoint(features):
    """Run the inference pipeline on raw signal rows.

    Steps: filter with ``fs=300``, build the signal plus lag-1 difference
    channels, then call ``model_n.predict`` on the stacked input.

    Args:
        features: 2-D array of raw samples (the /predict route passes a
            single-row slice of at most 187 columns).

    Returns:
        Whatever ``model_n.predict`` returns for that input.
    """
    # Min-max scaling (intensityNormalisationFeatureScaling) is intentionally disabled.
    filtered = filter_data(features, 300)
    model_input = return_merge_diff_table(filtered, diff_dur=[1])
    return model_n.predict(model_input)




# Flask uses the import name to locate the application's root folder
# (templates/, static/). Pass the module's __name__ variable, not the
# literal string '__name__', which names no importable module and made
# Flask fall back to the current working directory.
app = Flask(__name__)



@app.route('/')
def home():
    """Render the ``index.html`` template for the site root."""
    landing_template = 'index.html'
    return render_template(landing_template)

@app.route('/predict',methods=['POST'])
def predict():
    """Classify an uploaded signal CSV and render the result page.

    Form fields:
        Age, Gender: converted with ``float`` (a non-numeric value raises
            ValueError); they are not currently passed to the model.
        signal: uploaded CSV file, parsed with ``header=None``.

    Only row index 3 and its first 187 columns are fed to the model. If the
    model's first output for that row exceeds the second, ``Good.html`` is
    rendered; otherwise ``Bad.html``. Returns HTTP 400 if the uploaded file
    name is empty or sanitizes to nothing.
    """
    # Converted only for validation; the model below does not use them.
    _age = float(request.form['Age'])
    _gender = float(request.form['Gender'])

    f = request.files['signal']
    # Sanitize ONCE and use the same path for save, read and delete. The
    # raw client-supplied f.filename may contain directory parts
    # (e.g. "../x.csv") and must never be used as a server-side path.
    filename = secure_filename(f.filename or '')
    if not filename:
        return 'Invalid or missing signal file name', 400
    # NOTE(review): concurrent uploads with the same name overwrite each other.
    f.save(filename)

    df = pd.read_csv(filename, header=None)

    # NOTE(review): hard-coded row index 3 and 187 samples -- presumably the
    # layout/beat length the model was trained on; confirm.
    df_2 = df.iloc[3:4, 0:187]
    features_t = np.array(df_2)

    pred = predict_endpoint(features_t)
    print(pred[0])

    if float(pred[0][0]) > float(pred[0][1]):
        os.remove(filename)
        return render_template('Good.html', prediction=pred[0][0])
    # NOTE(review): the upload is deleted only on the Good path and stays on
    # disk here -- confirm whether keeping it is intentional.
    return render_template('Bad.html', prediction2=pred[0][1])

    

if __name__ == "__main__":
    # Development server in Flask debug mode -- intended for local use only,
    # do not expose it publicly.
    # Network-facing variant previously used:
    #   app.run(debug=True, host='0.0.0.0', port=9696)
    run_options = {'debug': True}
    app.run(**run_options)
