209 lines
No EOL
6.4 KiB
Nim
209 lines
No EOL
6.4 KiB
Nim
# Standard library (note: `std.jsonutils` corrected to `std/jsonutils`;
# Nim module paths use `/`).  `std/sets` added for getUnique's HashSet.
import std/[json, jsonutils, math, os, random, sequtils, sets, stats,
            strformat, streams, sugar, tables]

# Third-party
import arraymancer  # tensors, autograd, the `network` model DSL, npy I/O
import nimpy        # Python interop (used for matplotlib plotting)

# Project-local
import tensorCeral  # deSerializeTensors
|
|
|
|
# This program handles the training and serialization of a machine learning
# algorithm.
# Due to current technical limitations around serialization, it cannot be
# saved unfortunately.
# This code is currently not used.

# Global loss log: maps "number of unique values in a training sample" to the
# (loss, iteration) pairs observed for it; dumped to JSON by writey().
var outdata = initTable[int, seq[(float32, int)]]()
for bucket in 0 .. 301:
  outdata[bucket] = @[]

# Seed the default RNG from the system clock.
randomize()
|
|
proc getNumberOfFiles(path: string): int =
  ## Count the entries (files and subdirectories) directly inside `path`.
  ## Returns 0 for an empty directory.
  ##
  ## Fix: the previous version built the full listing twice with `collect`
  ## and echoed it as leftover debug output; this walks the directory once,
  ## silently, without materialising a seq.
  for _ in os.walkDir(path):
    inc result
|
|
|
|
|
|
proc getUnique(a: Tensor[float32]): int =
  ## Count the distinct values occurring in tensor `a`.
  ##
  ## Fix: the original appended to a `seq[float]` guarded by `contains`,
  ## an O(n^2) linear scan per element; a HashSet gives O(n) with the
  ## same resulting count.
  var seen = initHashSet[float32]()
  for v in a:
    seen.incl v
  seen.len
|
|
# Arraymancer `network` DSL: declares the `TwoLayersNet` type plus generated
# init/forward machinery.  Shape: 300 inputs -> 42 hidden units -> 300
# outputs, i.e. a bottleneck (autoencoder-like) fully-connected net.
network TwoLayersNet:
  layers:
    fc1: Linear(300, 42)
    fc2: Linear(42, 300)
  forward x:
    # forward pass chained via UFCS: fc1 -> ReLU -> fc2
    x.fc1.relu.fc2
|
|
|
|
proc save(network: TwoLayersNet[float32], outy: int) =
  ## Serialise the network's learned parameters (both layers' weights and
  ## biases) as .npy files under model/, tagged with `outy` so multiple
  ## snapshots can coexist.
  # Robustness fix: write_npy fails if model/ does not exist yet;
  # createDir behaves like `mkdir -p` and is a no-op when it already exists.
  createDir("model")
  network.fc1.weight.value.write_npy(&"model/hiddenweight{$outy}.npy")
  network.fc1.bias.value.write_npy(&"model/hiddenbias{$outy}.npy")
  network.fc2.weight.value.write_npy(&"model/outputweight{$outy}.npy")
  network.fc2.bias.value.write_npy(&"model/outputbias{$outy}.npy")
|
|
|
|
proc load*(ctx: Context[Tensor[float32]], inny: int): TwoLayersNet[float32] =
  ## Rebuild a TwoLayersNet from the .npy snapshots written by `save`,
  ## selecting the snapshot tagged `inny`.
  ## (`ctx` is accepted for call-site symmetry but not read here.)
  let
    hiddenWeights = read_npy[float32](&"model/hiddenweight{inny}.npy")
    hiddenBiases = read_npy[float32](&"model/hiddenbias{inny}.npy")
    outputWeights = read_npy[float32](&"model/outputweight{inny}.npy")
    outputBiases = read_npy[float32](&"model/outputbias{inny}.npy")
  result.fc1.weight.value = hiddenWeights
  result.fc1.bias.value = hiddenBiases
  result.fc2.weight.value = outputWeights
  result.fc2.bias.value = outputBiases
|
|
|
|
proc echoUsage() =
  ## Print command-line usage to stdout.
  ## Fix: corrected the user-facing typo "anaylize" -> "analyze".
  echo "This program requires stdinputs"
  echo " To train a model and save it:"
  echo " -t [tensor1.bin] [tensor2.bin] ..."
  echo " To analyze its outputs:"
  echo " -s [stats1.json] [stats2.json] ..."
  echo "each program can take between 1 and an infinite number of inputs"
|
|
|
|
when isMainModule:
  var params = commandLineParams()
  # Bug fix: the original called `params.setlen(1)` here, truncating the
  # argument list to just the flag before it was ever inspected.  Every
  # "-t"/"-s" invocation then tripped the `params.len() == 1` guards below
  # and quit with usage; the file arguments must be preserved.
  if params.len() == 0 or not ["-t", "-s"].contains(params[0]):
    echoUsage()
    quit(1)

  if params[0] == "-s":
    # statistics segment :)
    if params.len() == 1:
      echoUsage()
      quit(1)

    var plt = pyImport("matplotlib.pyplot")
    # subplots(...) returns (figure, axes); [1] keeps only the axes.
    let fig = plt.subplots(1, len(params) - 1)[1]
    var figit = 0

    for x in params[1 .. ^1]:
      var newplot: PyObject
      if len(params) == 2:
        # if plt is given 1,1 it is a different type than 1,2+ because python...
        newplot = fig
      else:
        newplot = fig[figit]

      # Stats file layout: unique-value-count -> seq of (loss, iteration).
      var table: Table[int, seq[(float, int)]]
      fromJson(table, parseJson(readFile(x)))

      var decomp: seq[int]      # x axis: unique-value counts that occurred
      var means: seq[float]     # occurrence-normalised mean loss
      var full: RunningStat     # global stats over all kept losses
      var tempith: RunningStat  # stats over occurrence counts
      var rawTemptih: seq[int]  # raw occurrence counts
      var rawstats: seq[float]  # raw per-count mean loss

      # First pass: how often did each unique-count occur?
      for x in 3 .. 301:
        if table[x].len() == 0:
          continue
        rawTemptih.add(table[x].len())
        tempith.push(table[x].len().float)

      # Second pass: per-count means, weighted by occurrence frequency.
      for x in 3 .. 301:
        # 15000 is a magic number, but filters out early training abnormalities
        let temp = table[x].map(x => x[0]).filter(x => x < 15000)
        if temp.len() == 0:
          continue
        else:
          var statistics: RunningStat # must be var
          statistics.push(temp)
          full.push(temp)
          decomp.add(x)
          rawstats.add(statistics.mean)
          # Scale the mean by how frequently this unique-count occurred.
          let percent = 1 - ((((tempith.max - temp.len().float) + tempith.min) * (1 / tempith.max)))
          means.add(statistics.mean * percent)

      # A flat line at the global mean, one point per decomp entry.
      let fullMean = collect(for x in 0 .. decomp.high: full.mean)
      discard newplot.scatter(decomp, rawstats, label = "Raw mean value at each occurance")
      discard newplot.scatter(decomp, means, label = "Occurance normalized mean")
      discard newplot.plot(decomp, fullMean, label = "Global mean")
      discard newplot.set_ylim(0, 15000)
      # NOTE(review): x[1] is the second *character* of the file name; this
      # looks like it was meant to be `x` — kept as-is to preserve behavior.
      discard newplot.set_title(x[1])
      discard newplot.set_ylabel("loss")
      discard newplot.set_xlabel("Amount of different variables")
      discard newplot.legend(loc = "upper left")
      figit += 1
    discard plt.show()
    quit()

  if params[0] == "-t":
    if params.len() == 1:
      echoUsage()
      quit(1)

    var
      ctx = newContext Tensor[float32]
      model = ctx.init(TwoLayersNet)
      optim = model.optimizerSGD(learning_rate = 1e-5'f32)

    # Ring buffer holding the last 10 final losses, echoed on Ctrl-C.
    var circular: seq[float32]
    proc addToCache(input: float32) =
      if circular.len() == 10:
        circular.delete(0)
      circular.add(input)

    # Saves the model; with die=true also dumps the loss log and exits.
    proc writey(die = false) {.noconv.} =
      model.save(0)
      if die:
        echo circular
        let outint = getNumberOfFiles("./trainingstats/")
        echo &"writingoutput to: ./trainingstats/stats{outint}.json"
        writeFile(&"./trainingstats/stats{outint}.json", $(outdata.toJson()))
        quit()

    proc exit() {.noconv.} =
      writey(true)

    # Dump stats and exit cleanly on Ctrl-C.
    setControlCHook(exit)

    var prev = -1.0'f32
    var it = 0

    # NOTE(review): dicty is populated here but never read afterwards;
    # kept to preserve behavior — candidate for removal.
    var dicty = newTable[int, seq[(int, string)]]()
    for x in 0 .. 301:
      dicty[x] = @[]

    proc train() =
      # Bug fix: iterate only the file arguments.  The original iterated
      # all of `params`, which tried to open the "-t" flag itself as a
      # file stream on the first pass.
      for tensors in params[1 .. ^1]:
        var input = newFileStream(tensors)
        let decompressed = deSerializeTensors(input)
        let lastRow = (decompressed[1].shape[0]) - 1
        for county in 0 .. lastRow:
          let
            x1 = ctx.variable(decompressed[2][county .. county+1, _])
            y1 = decompressed[1][county .. county+1, _]
            unique = (x1.value).getUnique()
          # around half
          # if unique >= 110:
          #   model = models[1]
          var strike = 0
          for t in 0 .. 50:
            var
              y_pred = model.forward(x1)
              loss = y_pred.mse_loss(y1)
            if t mod 10 == 0:
              echo loss.value[0]
            # NOTE(review): loss is logged every step here; could equally
            # have been intended only every 10th — confirm against old data.
            outdata[unique].add((loss.value[0], it))
            # Plateau detection: bail out after 8 consecutive identical losses.
            if loss.value[0] == prev:
              if strike == 7:
                break
              strike += 1
            if loss.value[0] != prev:
              if strike != 0:
                strike = 0
            loss.backprop()
            optim.update()
            prev = loss.value[0]
          # encase of crashing it writes saves every time
          echo "hmm"
          addToCache(prev)
          it += 1
          writey()

    train()
    writey()