最小2乗法の考え方を用いるとパーセプトロンの欠点をある程度解決した学習
法を導くことができます。これは、1960年にWidrowとHoff が提案した
ADALINE(Adaptive Linear Neuron)というモデルで、閾値論理素子の線形の部
分
(35) |
まずは、前回のボール投げの記録を予測する最小2乗法のプログラムを参考に 2種類のアヤメのデータを識別するADALINEのプログラムを作ってみましょう。
アヤメ科のイリス・ベルシコロールとイリス・ベルジニカという2種類の花か ら、がくの長さ()、がくの幅()、花弁の長さ()、花弁の幅() の4種類の特徴を計測したデータがあります。データ数は、各花とも50個づつ あります。これは、Fisherという有名な統計学者が1936年に線形判別関数を適 用した有名なデータで、それ以来パターン認識の手法を確認する例として頻繁 に用いられています。
ここでは、教師の答えとしてイリス・ベルシコロールには()を与え、イ
リス・ベルジニカには()を与えるものとします。がくの長さ()、が
くの幅()、花弁の長さ()、花弁の幅()のデータから教師の答えを
予測するためのADALINEモデルは、
(36) |
(37) |
学習したADALINEモデルを用いてアヤメの花を識別するには、学習したADALINE モデルに計測した特徴量を代入し、教師の答えの予測値を求め、それがに 近ければイリス・ベルシコロールと判定し、に近ければイリス・ベルジニ カと判定すれば良いことになります。
具体的なプログラムは、以下のようになります。
#include <stdio.h> #include <stdlib.h> #define frand() rand()/((double)RAND_MAX) #define NSAMPLE 100 #define XDIM 4 main() { FILE *fp; double t[NSAMPLE]; double x[NSAMPLE][XDIM]; double a[XDIM+1]; int i, j, l; double y, err, mse; double derivatives[XDIM+1]; double alpha = 0.1; /* Learning Rate */ /* Open Data File */ if ((fp = fopen("niris.dat","r")) == NULL) { fprintf(stderr,"File Open Fail\n"); exit(1); } /* Read Data */ for (l = 0; l < NSAMPLE; l++) { /* Input input vectors */ for (j = 0; j < XDIM; j++) { fscanf(fp,"%lf",&(x[l][j])); } /* Set teacher signal */ if (l < 50) t[l] = 1.0; else t[l] = 0.0; } /* Close Data File */ fclose(fp); /* Print the data */ for (l = 0; l < NSAMPLE; l++) { printf("%3d : %8.2f ", l, t[l]); for (j = 0; j < XDIM; j++) { printf("%8.2f ", x[l][j]); } printf("\n"); } /* Initialize the parameters by random number */ for (j = 0; j < XDIM+1; j++) { a[j] = (frand() - 0.5); } /* Open output file */ fp = fopen("mse.out","w"); /* Learning the parameters */ for (i = 1; i < 1000; i++) { /* Learning Loop */ /* Compute derivatives */ /* Initialize derivatives */ for (j = 0; j < XDIM+1; j++) { derivatives[j] = 0.0; } /* update derivatives */ for (l = 0; l < NSAMPLE; l++) { /* prediction */ y = a[0]; for (j = 1; j < XDIM+1; j++) { y += a[j] * x[l][j-1]; } /* error */ err = t[l] - y; /* printf("err[%d] = %f\n", l, err);*/ /* update derivatives */ derivatives[0] += err; for (j = 1; j < XDIM+1; j++) { derivatives[j] += err * x[l][j-1]; } } for (j = 0; j < XDIM+1; j++) { derivatives[j] = -2.0 * derivatives[j] / (double)NSAMPLE; } /* update parameters */ for (j = 0; j < XDIM+1; j++) { a[j] = a[j] - alpha * derivatives[j]; } /* Compute Mean Squared Error */ mse = 0.0; for (l = 0; l < NSAMPLE; l++) { /* prediction */ y = a[0]; for (j = 1; j < XDIM+1; j++) { y += a[j] * x[l][j-1]; } /* error */ err = t[l] - y; mse += err * err; } mse = mse / (double)NSAMPLE; printf("%d : Mean Squared Error is %f\n", i, mse); fprintf(fp, "%f\n", mse); } fclose(fp); /* Print Estmated Parameters */ printf("\nEstimated Parameters\n"); for (j = 0; j < XDIM+1; j++) { printf("a[%d]=%f, ",j, a[j]); } printf("\n\n"); /* Prediction and Errors */ for (l = 0; l < NSAMPLE; l++) { /* prediction */ y = a[0]; for (j = 1; j < XDIM+1; j++) { y += a[j] * x[l][j-1]; } /* error */ err = t[l] - y; if ((1.0 - y)*(1.0 - y) <= (0.0 - y)*(0.0 - y)) { if (l < 50) { printf("%3d [Class1 : correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err); } else { printf("%3d [Class1 : not correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err); } } else { if (l >= 50) { printf("%3d [Class2 : correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err); } else { printf("%3d [Class2 : not correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err); } } } }
このプログラムは、先の最小2乗法を最急降下法で解くプログラムとほとんど 同じです。アヤメのデータファイルniris.datを読み込んで、そのデータに対 して最急降下法でパラメータを求めています。教師信号が実数値ではなく、 との2値で与えられるとことだけが最小2乗法との違いです。プログラムの 最後の部分では、得られたニューラルネット(識別器)の良さを確認するため に学習に用いたアヤメのデータを識別させています。ニューラルネットの出力 がとのどちらに近いかで、どちらのアヤメかを決定しています。プログ ラムの実行結果は、以下のようになります。
Estimated Parameters a[0]=1.239302, a[1]=0.145552, a[2]=0.139415, a[3]=-0.638937, a[4]=-0.537277, 0 [Class1 : correct] : t = 1.000000, y = 1.003690 (err = -0.003690) 1 [Class1 : correct] : t = 1.000000, y = 0.899888 (err = 0.100112) 2 [Class1 : correct] : t = 1.000000, y = 0.810625 (err = 0.189375) 3 [Class1 : correct] : t = 1.000000, y = 0.775510 (err = 0.224490) 4 [Class1 : correct] : t = 1.000000, y = 0.752820 (err = 0.247180) 5 [Class1 : correct] : t = 1.000000, y = 0.789633 (err = 0.210367) 6 [Class1 : correct] : t = 1.000000, y = 0.771009 (err = 0.228991) 7 [Class1 : correct] : t = 1.000000, y = 1.168271 (err = -0.168271) 8 [Class1 : correct] : t = 1.000000, y = 0.943976 (err = 0.056024) 9 [Class1 : correct] : t = 1.000000, y = 0.816619 (err = 0.183381) 10 [Class1 : correct] : t = 1.000000, y = 0.984887 (err = 0.015113) 11 [Class1 : correct] : t = 1.000000, y = 0.856557 (err = 0.143443) 12 [Class1 : correct] : t = 1.000000, y = 1.043678 (err = -0.043678) 13 [Class1 : correct] : t = 1.000000, y = 0.748846 (err = 0.251154) 14 [Class1 : correct] : t = 1.000000, y = 1.130948 (err = -0.130948) 15 [Class1 : correct] : t = 1.000000, y = 1.027689 (err = -0.027689) 16 [Class1 : correct] : t = 1.000000, y = 0.694755 (err = 0.305245) 17 [Class1 : correct] : t = 1.000000, y = 1.132590 (err = -0.132590) 18 [Class1 : correct] : t = 1.000000, y = 0.543723 (err = 0.456277) 19 [Class1 : correct] : t = 1.000000, y = 1.035076 (err = -0.035076) 20 [Class2 : not correct] : t = 1.000000, y = 0.490680 (err = 0.509320) 21 [Class1 : correct] : t = 1.000000, y = 1.041685 (err = -0.041685) 22 [Class1 : correct] : t = 1.000000, y = 0.512358 (err = 0.487642) 23 [Class1 : correct] : t = 1.000000, y = 0.858199 (err = 0.141801) 24 [Class1 : correct] : t = 1.000000, y = 1.017686 (err = -0.017686) 25 [Class1 : correct] : t = 1.000000, y = 0.977978 (err = 0.022022) 26 [Class1 : correct] : t = 1.000000, y = 0.803766 (err = 0.196234) 27 [Class1 : correct] : t = 1.000000, y = 0.565534 (err = 0.434466) 28 [Class1 : correct] : t = 1.000000, y = 0.733136 (err = 0.266864) 29 [Class1 : correct] : t = 1.000000, y = 1.300773 (err = -0.300773) 30 [Class1 : correct] : t = 1.000000, y = 1.021680 (err = -0.021680) 31 [Class1 : correct] : t = 1.000000, y = 1.128719 (err = -0.128719) 32 [Class1 : correct] : t = 1.000000, y = 1.063775 (err = -0.063775) 33 [Class2 : not correct] : t = 1.000000, y = 0.380334 (err = 0.619666) 34 [Class1 : correct] : t = 1.000000, y = 0.659518 (err = 0.340482) 35 [Class1 : correct] : t = 1.000000, y = 0.822877 (err = 0.177123) 36 [Class1 : correct] : t = 1.000000, y = 0.848019 (err = 0.151981) 37 [Class1 : correct] : t = 1.000000, y = 0.771196 (err = 0.228804) 38 [Class1 : correct] : t = 1.000000, y = 0.981463 (err = 0.018537) 39 [Class1 : correct] : t = 1.000000, y = 0.839696 (err = 0.160304) 40 [Class1 : correct] : t = 1.000000, y = 0.797250 (err = 0.202750) 41 [Class1 : correct] : t = 1.000000, y = 0.817255 (err = 0.182745) 42 [Class1 : correct] : t = 1.000000, y = 0.995367 (err = 0.004633) 43 [Class1 : correct] : t = 1.000000, y = 1.153796 (err = -0.153796) 44 [Class1 : correct] : t = 1.000000, y = 0.848869 (err = 0.151131) 45 [Class1 : correct] : t = 1.000000, y = 1.033489 (err = -0.033489) 46 [Class1 : correct] : t = 1.000000, y = 0.930673 (err = 0.069327) 47 [Class1 : correct] : t = 1.000000, y = 0.982449 (err = 0.017551) 48 [Class1 : correct] : t = 1.000000, y = 1.273824 (err = -0.273824) 49 [Class1 : correct] : t = 1.000000, y = 0.934896 (err = 0.065104) 50 [Class2 : correct] : t = 0.000000, y = -0.337600 (err = 0.337600) 51 [Class2 : correct] : t = 0.000000, y = 0.132929 (err = -0.132929) 52 [Class2 : correct] : t = 0.000000, y = 0.026276 (err = -0.026276) 53 [Class2 : correct] : t = 0.000000, y = 0.174352 (err = -0.174352) 54 [Class2 : correct] : t = 0.000000, y = -0.113841 (err = 0.113841) 55 [Class2 : correct] : t = 0.000000, y = -0.139841 (err = 0.139841) 56 [Class2 : correct] : t = 0.000000, y = 0.269516 (err = -0.269516) 57 [Class2 : correct] : t = 0.000000, y = 0.096326 (err = -0.096326) 58 [Class2 : correct] : t = 0.000000, y = 0.043823 (err = -0.043823) 59 [Class2 : correct] : t = 0.000000, y = -0.119072 (err = 0.119072) 60 [Class2 : correct] : t = 0.000000, y = 0.345999 (err = -0.345999) 61 [Class2 : correct] : t = 0.000000, y = 0.166008 (err = -0.166008) 62 [Class2 : correct] : t = 0.000000, y = 0.118683 (err = -0.118683) 63 [Class2 : correct] : t = 0.000000, y = 0.016717 (err = -0.016717) 64 [Class2 : correct] : t = 0.000000, y = -0.188593 (err = 0.188593) 65 [Class2 : correct] : t = 0.000000, y = 0.043580 (err = -0.043580) 66 [Class2 : correct] : t = 0.000000, y = 0.277996 (err = -0.277996) 67 [Class2 : correct] : t = 0.000000, y = 0.027482 (err = -0.027482) 68 [Class2 : correct] : t = 0.000000, y = -0.500986 (err = 0.500986) 69 [Class2 : correct] : t = 0.000000, y = 0.326909 (err = -0.326909) 70 [Class2 : correct] : t = 0.000000, y = -0.013590 (err = 0.013590) 71 [Class2 : correct] : t = 0.000000, y = 0.290259 (err = -0.290259) 72 [Class2 : correct] : t = 0.000000, y = -0.152001 (err = 0.152001) 73 [Class2 : correct] : t = 0.000000, y = 0.364375 (err = -0.364375) 74 [Class2 : correct] : t = 0.000000, y = 0.124712 (err = -0.124712) 75 [Class2 : correct] : t = 0.000000, y = 0.283933 (err = -0.283933) 76 [Class2 : correct] : t = 0.000000, y = 0.415164 (err = -0.415164) 77 [Class2 : correct] : t = 0.000000, y = 0.425416 (err = -0.425416) 78 [Class2 : correct] : t = 0.000000, y = -0.052291 (err = 0.052291) 79 [Class2 : correct] : t = 0.000000, y = 0.433825 (err = -0.433825) 80 [Class2 : correct] : t = 0.000000, y = 0.083760 (err = -0.083760) 81 [Class2 : correct] : t = 0.000000, y = 0.313110 (err = -0.313110) 82 [Class2 : correct] : t = 0.000000, y = -0.123014 (err = 0.123014) 83 [Class1 : not correct] : t = 0.000000, y = 0.536005 (err = -0.536005) 84 [Class2 : correct] : t = 0.000000, y = 0.325728 (err = -0.325728) 85 [Class2 : correct] : t = 0.000000, y = -0.082091 (err = 0.082091) 86 [Class2 : correct] : t = 0.000000, y = -0.089522 (err = 0.089522) 87 [Class2 : correct] : t = 0.000000, y = 0.292471 (err = -0.292471) 88 [Class2 : correct] : t = 0.000000, y = 0.444113 (err = -0.444113) 89 [Class2 : correct] : t = 0.000000, y = 0.204709 (err = -0.204709) 90 [Class2 : correct] : t = 0.000000, y = -0.115327 (err = 0.115327) 91 [Class2 : correct] : t = 0.000000, y = 0.172210 (err = -0.172210) 92 [Class2 : correct] : t = 0.000000, y = 0.132929 (err = -0.132929) 93 [Class2 : correct] : t = 0.000000, y = -0.103840 (err = 0.103840) 94 [Class2 : correct] : t = 0.000000, y = -0.158180 (err = 0.158180) 95 [Class2 : correct] : t = 0.000000, y = 0.068565 (err = -0.068565) 96 [Class2 : correct] : t = 0.000000, y = 0.193150 (err = -0.193150) 97 [Class2 : correct] : t = 0.000000, y = 0.245497 (err = -0.245497) 98 [Class2 : correct] : t = 0.000000, y = 0.036213 (err = -0.036213) 99 [Class2 : correct] : t = 0.000000, y = 0.317548 (err = -0.317548)
この例では、3個の間違いましたが、4個の特徴からほぼアヤメの種類を識別で きていることがわかります。