技術指標用 RSI 及 MACD

1 找股票代號

def getListedCode(typ="twse",codeLength=4):
    """Return the stock codes of one listing type from ``list.db``.

    Reads the ``StockList`` table of the SQLite file ``list.db`` in the
    current working directory and prints how many codes were found.

    Args:
        typ: Value matched (with SQL ``LIKE``) against the ``type`` column.
            The default ``"twse"`` selects the rows the original hard-coded
            query selected.
        codeLength: Required length of the ``code`` column value.

    Returns:
        A list with the ``code`` value of every matching row.
    """
    import sqlite3
    from contextlib import closing

    # Database file holding the list of listed / OTC / delisted stock codes.
    sqlite_file = "list.db"
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(sqlite_file)) as conn:
        # Previously the literal "twse" was bound here, so `typ` had no effect.
        out = conn.execute(
            "SELECT code, type FROM StockList WHERE type LIKE ? AND length(code) == ? ",
            (typ, codeLength),
        )
        codeList = [row[0] for row in out]
    print("取得股票代號總數:", len(codeList))
    return codeList

2 找各股技術指標資料並將資料分為訓練及回測

技術分析程式碼

def _writeLibsvmSamples(path, samples):
    """Write ``[close, MACD, RSI, label]`` samples to *path*, one per line."""
    # "w" rather than "a": re-running the pipeline must not append a second
    # copy of every sample to the files left by the previous run.
    with open(path, 'w') as the_file:
        the_file.writelines(
            str(s[3])+' 1:'+str(s[0])+' 2:'+str(s[1])+' 3:'+str(s[2])+'\n'
            for s in samples
        )


def taCalc(codeList,dbName="TechnicalAnalysis.db",tableName="twse"):
    """Build per-stock training and back-test sample files.

    For every code, rows ``(code, close, date, MACD, RSI)`` are read in date
    order. Each row that has a row five positions later becomes one sample
    with features close / MACD / RSI and label 1 when the later close is
    higher, 0 when it is lower; unchanged closes produce no sample. The first
    70% of the samples are written to ``data5day<code>.train`` and the rest
    to ``data5day<code>.test`` in the current directory, one
    ``label 1:close 2:MACD 3:RSI`` line per sample.

    Args:
        codeList: Iterable of stock codes to process.
        dbName: Path of the SQLite database to read.
        tableName: Table holding the indicator rows.
    """
    import sqlite3
    from contextlib import closing

    horizon = 5        # rows to look ahead when labelling a sample
    trainRatio = 0.7   # leading share of the samples used for training

    # Table names cannot be bound as parameters, so quote it as an identifier.
    quotedTable = '"' + str(tableName).replace('"', '""') + '"'
    query = (
        "SELECT code, close, date, MACD, RSI FROM {} "
        "WHERE code LIKE ? ORDER BY date"
    ).format(quotedTable)

    with closing(sqlite3.connect(dbName)) as conn:
        for code in codeList:
            print(code)
            # The code is bound as a parameter instead of being pasted into
            # the SQL text, which was unquoted and open to injection.
            rows = conn.execute(query, (str(code),)).fetchall()

            data5day = []
            # Pair each row with the one `horizon` rows later; zip() stops at
            # the last row that still has such a partner.
            for today, future in zip(rows, rows[horizon:]):
                close, macd, rsi = today[1], today[3], today[4]
                # A NULL would make the subtraction fail or put the text
                # "None" into the sample file, so skip incomplete rows.
                if close is None or macd is None or rsi is None or future[1] is None:
                    continue
                change = future[1] - close
                if change > 0:
                    data5day.append([close, macd, rsi, 1])
                elif change < 0:
                    data5day.append([close, macd, rsi, 0])

            trainlines = int(len(data5day) * trainRatio)
            _writeLibsvmSamples('data5day'+str(code)+'.train', data5day[:trainlines])
            # Start exactly at `trainlines` so no sample falls between the sets.
            _writeLibsvmSamples('data5day'+str(code)+'.test', data5day[trainlines:])

3 訓練模型

def train(code, num_round=2):
    """Train an XGBoost binary classifier for one stock and save it.

    Loads ``data5day<code>.train`` from the current directory, trains with
    the ``binary:logistic`` objective and saves the booster to
    ``models/data5day<code>.model``.

    Args:
        code: Stock code used to build the sample and model file names.
        num_round: Number of boosting rounds; defaults to the value that was
            previously hard-coded.
    """
    import os
    import xgboost as xgb

    # Read in data.
    # NOTE(review): recent XGBoost releases may expect "?format=libsvm" on
    # text-file URIs - confirm against the installed version.
    nameofTRAINsample = 'data5day'+str(code)+'.train'
    dtrain = xgb.DMatrix(nameofTRAINsample)
    # Specify parameters via map.
    param = {'objective':'binary:logistic' }
    bst = xgb.train(param, dtrain, num_round)
    nameofMODEL = 'data5day'+str(code)+'.model'
    # The output directory is not created anywhere else in this file, so make
    # sure it exists before saving.
    os.makedirs('models', exist_ok=True)
    bst.save_model('models/'+nameofMODEL)

4 取得回測答案

def getAns(code):
    """Return the labels stored in the back-test file of one stock.

    Args:
        code: Stock code; the file read is ``data5day<code>.test`` in the
            current directory.

    Returns:
        A list of ints, the leading label of every non-blank line.
    """
    nameofTESTsample = 'data5day'+str(code)+'.test'
    # Previously 'data5day1301.test' was opened no matter which code was
    # requested, so every stock was given stock 1301's answers.
    with open(nameofTESTsample, 'r') as f:
        # The label is the first whitespace-separated token of each line.
        return [int(line.split()[0]) for line in f if line.strip()]

5 計算準確度

def probs(code,answer,thresh=0.5):
    """Score the saved model of one stock against its back-test file.

    Loads ``models/data5day<code>.model``, predicts every row of
    ``data5day<code>.test`` and prints the number of correct and incorrect
    predictions, the accuracy and the number of predictions.

    Args:
        code: Stock code used to build the model and sample file names.
        answer: Kept for backward compatibility and not used. As before, the
            labels are re-read from the test file so that they line up with
            the rows being predicted.
        thresh: A prediction above this value means label 1, below it means
            label 0; a prediction equal to it counts as incorrect.

    Returns:
        The accuracy as a float (0.0 when the test file has no samples).
    """
    import xgboost as xgb
    nameofMODEL = 'data5day'+str(code)+'.model'
    bst = xgb.Booster({'nthread': 4})
    bst.load_model('models/'+nameofMODEL)
    nameofTESTsample = 'data5day'+str(code)+'.test'
    # NOTE(review): recent XGBoost releases may expect "?format=libsvm" on
    # text-file URIs - confirm against the installed version.
    dtest = xgb.DMatrix(nameofTESTsample)
    preds = bst.predict(dtest)
    with open(nameofTESTsample, 'r') as f:
        # The label is the first whitespace-separated token of each line.
        labels = [int(line.split()[0]) for line in f if line.strip()]
    correct = 0
    incorrect = 0
    for pred, ans in zip(preds, labels):
        # Derive the predicted label first, then compare. The former
        # `(pred < thresh) & ans == 0` parsed as `((pred < thresh) & ans) == 0`
        # and therefore counted every sample labelled 0 as correct.
        if pred > thresh:
            predicted = 1
        elif pred < thresh:
            predicted = 0
        else:
            predicted = -1  # exactly on the threshold: no decision
        if predicted == ans:
            correct = correct + 1
        else:
            incorrect = incorrect + 1
    # Avoid ZeroDivisionError when the test file holds no samples.
    accuracy = correct/len(labels) if labels else 0.0
    print("Correct:",correct,"Incorrect:",incorrect,"Accuracy:",accuracy,"PredNo:",len(preds))
    return accuracy

6 預測的部份(之後用)

def predictAnswer(code,thresh=0.5):
    """Return the model's up/down call for one stock.

    Loads ``models/data5day<code>.model`` and predicts the rows of
    ``data5day<code>.test``.

    NOTE(review): the original compared an undefined name ``pred`` and so
    always raised NameError. This version uses the prediction for the last
    row of the test file, on the assumption that the intended input is the
    most recent sample - confirm.

    Args:
        code: Stock code used to build the model and sample file names.
        thresh: Decision threshold applied to the predicted probability.

    Returns:
        1 when the prediction is above ``thresh``, 0 when it is below, and
        -1 when it equals ``thresh`` or there is nothing to predict.
    """
    import xgboost as xgb
    nameofMODEL = 'data5day'+str(code)+'.model'
    bst = xgb.Booster({'nthread': 4})
    bst.load_model('models/'+nameofMODEL)
    nameofTESTsample = 'data5day'+str(code)+'.test'
    dtest = xgb.DMatrix(nameofTESTsample)
    preds = bst.predict(dtest)
    if len(preds) == 0:
        return -1
    pred = preds[-1]
    if pred > thresh:
        return 1
    if pred < thresh:
        return 0
    return -1

7 執行的程式碼

import traceback

# Run the pipeline: list the codes, build the sample files, then train and
# score one model per stock.
codeList = getListedCode()
taCalc(codeList)
for code in codeList:
    print(code)
    try:
        answer = getAns(code)
        train(code)
        probs(code,answer)
    except Exception:
        # One failing stock must not stop the batch, but print the traceback
        # so the cause is not lost.
        print("Something went wrong HERE!!!!!!!!")
        traceback.print_exc()

準確度列表