I am training an Elman network with the neurolab Python library, and my network doesn't work correctly. How do I fix overfitting in an Elman neural network?
- Training input vectors: http://pastebin.com/urQX2eEA
- Training target vector: http://pastebin.com/1JQh1xZv
- Sample vectors for testing the network: http://pastebin.com/jprZhBHa
During training, the network also reports large errors:
Epoch: 100; Error: 23752443150.672318;
Epoch: 200; Error: 284037904.0305649;
Epoch: 300; Error: 174736152.57367808;
Epoch: 400; Error: 3318952.136089243;
Epoch: 500; Error: 299017.4471083774;
Epoch: 600; Error: 176600.0906688521;
Epoch: 700; Error: 176599.32080188877;
Epoch: 800; Error: 185178.21132511366;
Epoch: 900; Error: 177224.2950528976;
Epoch: 1000; Error: 176632.86797784362;
The maximum number of train epochs is reached
As a result, the network fails on the test sample. Original MICEX:
1758.97
1626.18
1688.34
1609.19
1654.55
1669
1733.17
1642.97
1711.53
1771.05
Predicted MICEX:
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
[ 1237.59155306]
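To put a number on the failure, the sketch below computes the mean absolute error (MAE) of the predictions above against the original MICEX values. For comparison it also scores a naive baseline that simply repeats the last training target (1647.69). That baseline is an addition for illustration, not part of the original setup. Note that all ten predictions are identical, so the network does not respond to the test inputs at all.

```python
# Original MICEX values for the 10 test periods (from the question).
actual = [1758.97, 1626.18, 1688.34, 1609.19, 1654.55,
          1669.0, 1733.17, 1642.97, 1711.53, 1771.05]

# The network returned the same value for every test period.
predicted = [1237.59155306] * len(actual)

# Naive baseline (an assumption for comparison): repeat the last training target.
last_training_value = 1647.69
baseline = [last_training_value] * len(actual)


def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


network_mae = mean_absolute_error(actual, predicted)
baseline_mae = mean_absolute_error(actual, baseline)
print(f"network MAE:  {network_mae:.2f}")   # network MAE:  448.90
print(f"baseline MAE: {baseline_mae:.2f}")  # baseline MAE: 51.75
```

A trained model that cannot beat such a trivial baseline on held-out data is a strong sign that something is wrong with the training setup.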
Here is my code:
import neurolab as nl
import numpy as np
# Training samples (inputs: Brent, DJIA, CAC_40, SSEC; target: MICEX)
MICEX = [421.08,455.44,430.3,484,515.17,468.85,484.73,514.71,551.72,591.09,644.64,561.78,535.4,534.84,502.81,549.28,611.03,632.97,570.76,552.22,575.74,635.38,598.04,593.88,603.89,639.98,700.65,784.28,892.5,842.52,944.55,1011,1171.44,1320.83,1299.19,1486.85,1281.5,1331.39,1380.24,1448.72,1367.24,1426.83,1550.71,1693.47,1656.97,1655.19,1698.08,1697.28,1570.34,1665.96,1734.42,1677.02,1759.44,1874.73,1850.64,1888.86,1574.33,1660.42,1628.43,1667.35,1925.24,1753.67,1495.33,1348.92,1027.66,731.96,611.32,619.53,624.9,666.05,772.93,920.35,1123.38,971.55,1053.3,1091.98,1197.2,1237.18,1284.95,1370.01,1419.42,1332.64,1450.15,1436.04,1332.62,1309.31,1397.12,1368.9,1440.3,1523.39,1565.52,1687.99,1723.42,1777.84,1813.59,1741.84,1666.3,1666.59,1705.18,1546.05,1366.54,1498.6,1499.62,1402.02,1510.91,1594.32,1518.29,1474.14,1312.24,1386.89,1406.36,1422.38,1459.01,1423.46,1405.19,1477.87,1547.18,1487.46,1440.02,1386.69,1343.99,1331.24,1377.6,1364.54,1463.13,1509.62,1479.35,1503.39,1454.05,1444.71,1369.29,1306.01,1432.03,1476.38,1379.61,1400.71,1411.07,1488.47,1533.68,1396.61,1647.69]
Brent = [26.8,28.16,28.59,30.05,28.34,27.94,28.76,30.48,29.51,33.01,32.36,35.12,36.98,33.51,41.6,39.33,47.08,48.78,44.03,40.24,45.87,50.14,53.05,49.33,49.83,54.85,59.7,66.68,62.56,58.35,53.41,58.87,65.43,60.05,64.94,72,69,73.28,75.16,69.64,61.37,56.97,64.42,60.13,57.21,60.66,68.42,67.28,68.82,73.26,78.05,73.53,81.75,91.14,88,93.85,91.98,100.04,100.51,112.71,128.27,140.3,123.96,115.17,98.96,65.6,53.49,45.59,45.93,45.84,48.68,50.64,65.8,69.42,71.52,69.32,68.92,75.09,78.36,77.93,71.18,78.03,82.17,87.35,74.6,74.66,78.26,74.42,82.11,83.26,85.45,94.59,100.56,112.1,117.17,126.03,116.68,111.8,117.54,114.49,102.15,109.19,110.37,107.22,111.16,123.04,122.8,119.47,101.62,97.57,104.62,114.92,112.14,108.4,111.17,111.11,114.56,111,109.89,101.74,100.15,101.5,107.7,114.45,108.2,108.9,110.11,110.9,105.79,108.65,107.7,108.14,109.49,112.4,105.52,103.11,94.8,85.96,68.34,57.54,52.95]
DJIA = [8850.26,8985.44,9233.8,9415.82,9275.06,9801.12,9782.46,10453.92,10488.07,10583.92,10357.7,10225.57,10188.45,10435.48,10139.71,10173.92,10080.27,10027.47,10428.02,10783.01,10489.94,10766.23,10503.76,10192.51,10467.48,10274.97,10640.91,10481.6,10568.7,10440.07,10805.87,10717.5,10864.86,10993.41,11109.32,11367.14,11168.31,11150.22,11185.68,11381.15,11679.07,12080.73,12221.93,12463.15,12621.69,12268.63,12354.35,13062.91,13627.64,13408.62,13211.99,13357.74,13895.63,13930.01,13371.72,13264.82,12650.36,12266.39,12262.89,12820.13,12638.32,11350.01,11378.02,11543.96,10850.66,9325.01,8829.04,8776.39,8000.86,7062.93,7608.92,8168.12,8500.33,8447,9171.61,9496.28,9712.28,9712.73,10344.84,10428.05,10067.33,10325.26,10856.63,11008.61,10136.63,9774.02,10465.94,10014.72,10788.05,11118.49,11006.02,11577.51,11891.93,12226.34,12319.73,12810.54,12569.79,12414.34,12143.24,11613.53,10913.38,11955.01,12045.68,12217.56,12632.91,12952.07,13212.04,13213.63,12393.45,12880.09,13008.68,13090.84,13437.13,13096.46,13025.58,13104.14,13860.58,14054.49,14578.54,14839.8,15115.57,14909.6,15499.54,14810.31,15129.67,15545.75,16086.41,16576.66,15698.85,16321.71,16457.66,16580.84,16717.17,16826.6,16563.3,17098.45,17042.9,17390.52,17828.24,17823.07,17164.95]
CAC_40 = [2991.75,3084.1,3210.27,3311.42,3134.99,3373.2,3424.79,3557.9,3638.44,3725.44,3625.23,3674.28,3669.63,3732.99,3647.1,3594.28,3640.61,3706.82,3753.75,3821.16,3913.69,4027.16,4067.78,3908.93,4120.73,4229.35,4451.74,4399.36,4600.02,4436.45,4567.41,4715.23,4947.99,5000.45,5220.85,5188.4,4930.18,4965.96,5009.42,5165.04,5250.01,5348.73,5327.64,5541.76,5608.31,5516.32,5634.16,5930.77,6104,6054.93,5751.08,5662.7,5715.69,5841.08,5667.5,5614.08,4871.8,4790.66,4707.07,4996.54,5014.28,4425.61,4392.36,4485.64,4027.15,3487.07,3262.68,3217.97,2962.37,2693.96,2803.94,3159.85,3273.55,3138.93,3426.27,3657.72,3794.96,3601.43,3684.75,3936.33,3737.19,3708.8,3974.01,3816.99,3507.56,3442.89,3643.14,3476.18,3715.18,3833.5,3610.44,3804.78,4005.5,4110.35,3989.18,4106.92,4006.94,3980.78,3672.77,3256.76,2981.96,3242.84,3154.62,3159.81,3298.55,3447.94,3423.81,3212.8,3005.48,3196.65,3291.66,3413.07,3354.82,3429.27,3557.28,3641.07,3732.6,3723,3731.42,3856.75,3948.59,3738.91,3992.69,3933.78,4143.44,4299.89,4295.21,4295.95,4165.72,4408.08,4391.5,4487.39,4519.57,4422.84,4246.14,4381.04,4426.76,4233.09,4390.18,4263.55,4604.25]
SSEC = [1576.26,1486.02,1476.74,1421.98,1367.16,1348.3,1397.22,1497.04,1590.73,1675.07,1741.62,1595.59,1555.91,1399.16,1386.2,1342.06,1396.7,1320.54,1340.77,1266.5,1191.82,1306,1181.24,1159.15,1060.74,1080.94,1083.03,1162.8,1155.61,1092.82,1099.26,1161.06,1258.05,1299.03,1298.3,1440.22,1641.3,1672.21,1612.73,1658.64,1752.42,1837.99,2099.29,2675.47,2786.34,2881.07,3183.98,3841.27,4109.65,3820.7,4471.03,5218.82,5552.3,5954.77,4871.78,5261.56,4383.39,4348.54,3472.71,3693.11,3433.35,2736.1,2775.72,2397.37,2293.78,1728.79,1871.16,1820.81,1990.66,2082.85,2373.21,2477.57,2632.93,2959.36,3412.06,2667.74,2779.43,2995.85,3195.3,3277.14,2989.29,3051.94,3109.11,2870.61,2592.15,2398.37,2637.5,2638.8,2655.66,2978.83,2820.18,2808.08,2790.69,2905.05,2928.11,2911.51,2743.47,2762.08,2701.73,2567.34,2359.22,2468.25,2333.41,2199.42,2292.61,2428.49,2262.79,2396.32,2372.23,2225.43,2103.63,2047.52,2086.17,2068.88,1980.12,2269.13,2385.42,2365.59,2236.62,2177.91,2300.59,1979.21,1993.8,2098.38,2174.66,2141.61,2220.5,2115.98,2033.08,2056.3,2033.31,2026.36,2039.21,2048.33,2201.56,2217.2,2363.87,2420.18,2682.83,3234.68,3210.36]
Brent_sample = [62.48, 55.1, 66.8, 65.19, 63.14, 51.85, 53.12, 48.44, 49.5, 44.5]
DJIA_sample = [18132.7, 17776.12, 17840.52, 18010.68, 17619.51, 17689.86, 16528.03, 16284.7, 17663.54, 17719.92]
CAC_40_sample = [4922.99, 5031.47, 5042.84, 5084.08, 4812.24, 5081.73, 4652.34, 4453.91, 4880.18, 4951.83]
SSEC_sample = [3310.3, 3747.9, 4441.66, 4611.74, 4277.22, 3663.73, 3205.99, 3052.78, 3382.56, 3445.4]
MICEX = np.asarray(MICEX)
Brent = np.asarray(Brent)
DJIA = np.asarray(DJIA)
CAC_40 = np.asarray(CAC_40)
SSEC = np.asarray(SSEC)
Brent_sample = np.asarray(Brent_sample)
DJIA_sample = np.asarray(DJIA_sample)
CAC_40_sample = np.asarray(CAC_40_sample)
SSEC_sample = np.asarray(SSEC_sample)
size = len(MICEX)
inp = np.vstack((Brent, DJIA, CAC_40, SSEC)).T
tar = MICEX.reshape(size, 1)
smp = np.vstack((Brent_sample, DJIA_sample, CAC_40_sample, SSEC_sample)).T
# Create an Elman network with 2 layers: 46 TanSig hidden neurons and 1 linear output neuron
net = nl.net.newelm(
[[min(inp[:, 0]), max(inp[:, 0])],
[min(inp[:, 1]), max(inp[:, 1])],
[min(inp[:, 2]), max(inp[:, 2])],
[min(inp[:, 3]), max(inp[:, 3])]
],
[46, 1],
[nl.trans.TanSig(), nl.trans.PureLin()] # SatLinPrm(0.00000001, 421.08, 1925.24)
)
# Set the weight/bias initialization functions, then initialize the network
net.layers[0].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
net.layers[1].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
net.init()
# Optionally change the training method (e.g. to conjugate gradient)
# net.trainf = nl.train.train_cg
# Train network
error = net.train(inp, tar, epochs=1000, show=100, goal=0.02)
# Simulate network
out = net.sim(smp)
print(smp)
print('MICEX predictions for the next 10 periods:\n', out)
Does anyone know a solution to this problem?
I don't see any errors. What isn't working? The error gets smaller, so it seems to work. Another thing to notice is that the error stops decreasing; it looks like you are overfitting your training set (https://en.wikipedia.org/wiki/Overfitting), and the network will not generalize well.
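One way to act on the observation that the error stops decreasing is to scan the error history for a plateau. The function below is a hypothetical helper, not part of neurolab. It returns the first 1-based position at which the best error seen so far improved by less than `rel_tol` (relative) over the previous `window` entries. As an illustration, it is applied to the ten values from the training log in the question, which were printed every 100 epochs.

```python
def find_plateau(errors, window=10, rel_tol=1e-3):
    """Return the 1-based position at which the best error so far improved by
    less than rel_tol (relative) over the previous `window` entries, or None."""
    best_so_far = []
    best = float("inf")
    for e in errors:
        best = min(best, e)
        best_so_far.append(best)
    for i in range(window, len(errors)):
        old, new = best_so_far[i - window], best_so_far[i]
        if old - new <= rel_tol * old:
            return i + 1
    return None


# Errors printed every 100 epochs in the question's training log.
logged_errors = [23752443150.672318, 284037904.0305649, 174736152.57367808,
                 3318952.136089243, 299017.4471083774, 176600.0906688521,
                 176599.32080188877, 185178.21132511366, 177224.2950528976,
                 176632.86797784362]

pos = find_plateau(logged_errors, window=1, rel_tol=1e-3)
print(pos, "-> epoch", pos * 100)  # 7 -> epoch 700
```

On the full per-epoch history (assuming `error` is the per-epoch list returned by `net.train(...)`), a larger `window`, for example 50 epochs, would be a more sensible setting than 1.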
@john-carpenter Thanks! And how can I fix this overfitting in my case?
There are a few methods. One is to detect that your error is no longer going down on your training set and stop training at that point. Another is to use a validation set that you don't train on but still measure the error on; when the error on the validation set stops decreasing, you stop. [Here is an SO link with more details](https://stackoverflow.com/questions/2976452/whats-is-the-difference-between-train-validation-and-test-set-in-neural-networ)
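The validation-set method (early stopping) can be sketched as follows. This is a library-agnostic sketch under stated assumptions. The network is driven through four callables you supply: run one training epoch, compute the validation error, save the weights, and restore the weights. Because the data is a time series, the validation set is taken from the most recent periods rather than sampled at random. `chronological_split` and `train_with_early_stopping` are hypothetical helpers, not neurolab functions.

```python
import copy


def chronological_split(inputs, targets, val_fraction=0.2):
    """Hold out the most recent val_fraction of a time series for validation."""
    n_val = max(1, int(len(inputs) * val_fraction))
    return (inputs[:-n_val], targets[:-n_val]), (inputs[-n_val:], targets[-n_val:])


def train_with_early_stopping(train_one_epoch, validation_error, get_state, set_state,
                              max_epochs=1000, patience=20, min_delta=0.0):
    """Train until the validation error has not improved for `patience` epochs,
    then restore the best weights. Returns (best_epoch, best_validation_error)."""
    best_err = float("inf")
    best_epoch = 0
    best_state = copy.deepcopy(get_state())
    epochs_without_improvement = 0
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()
        err = validation_error()
        if err < best_err - min_delta:
            best_err, best_epoch = err, epoch
            best_state = copy.deepcopy(get_state())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    set_state(best_state)  # roll back to the weights with the lowest validation error
    return best_epoch, best_err


# Toy stand-in for a network: its only "weight" is the epoch counter, and its
# validation error is U-shaped with the minimum at epoch 10 (overfitting after that).
model = {"epoch": 0}


def train_one_epoch():
    model["epoch"] += 1


def validation_error():
    return (model["epoch"] - 10) ** 2


def get_state():
    return dict(model)


def set_state(state):
    model.clear()
    model.update(state)


(train_x, train_y), (val_x, val_y) = chronological_split(list(range(10)), list(range(10)))
print(train_x, val_x)  # [0, 1, 2, 3, 4, 5, 6, 7] [8, 9]

best_epoch, best_err = train_with_early_stopping(
    train_one_epoch, validation_error, get_state, set_state, max_epochs=100, patience=5)
print(best_epoch, best_err, model["epoch"])  # 10 0 10
```

With neurolab, `train_one_epoch` could call `net.train(inp_train, tar_train, epochs=1)`, and `validation_error` could compute the squared error of `net.sim(inp_val)` against `tar_val`. This wiring assumes that repeated `net.train` calls continue from the current weights rather than re-initializing, which is worth verifying. `get_state`/`set_state` would then save and restore a deep copy of the layer weights.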