式(39)のロジスティック回帰モデルでは、出力はから
の間の値で、イリス・ベルシコロールの場合にはに近い値を出力し、そ
うでない場合(イリス・ベルジニカの場合)にはに近い値を出力することが
期待されます。そこで、ロジスティック回帰モデルの出力をイリス・ベル
シコロールである確率と解釈します。また、今考えている問題ではアヤメの種
類は2種類のみですので、イリス・ベルジニカである確率はと解釈でき
ます。従って、100個のアヤメの計測データが得られる尤もらしさ(尤度)は、
(40) |
(41) |
これまでと同じように、最急降下法を適用するためには、評価関数(対数尤度)
の各パラメータでの微分が必要となります。対数尤度をパラメータで微
分すると
(42) |
(43) |
(44) |
ロジスティック回帰の場合には、出力値はからの間の値を取り、イリス・ ベルシコロールである確率の推定値であると解釈できますので、アヤメの花を 識別するには、出力値が以上ならイリス・ベルシコロールであり、 以下ならイリス・ベルジニカであると判断すれば良いことになります。
先のADALINEのプログラムを修正して、ロジスティック回帰モデルを用いてア ヤメの識別のためのパラメータを学習するプログラムを作ってみると、以下の ようになります。
#include <stdio.h> #include <stdlib.h> #include <math.h> #define frand() rand()/((double)RAND_MAX) #define NSAMPLE 100 #define XDIM 4 double logit(double eta) { return(exp(eta)/(1.0+exp(eta))); } main() { FILE *fp; double t[NSAMPLE]; double x[NSAMPLE][XDIM]; double a[XDIM+1]; int i, j, l; double eta; double y, err, likelihood; double derivatives[XDIM+1]; double alpha = 0.1; /* Learning Rate */ /* Open Data File */ if ((fp = fopen("niris.dat","r")) == NULL) { fprintf(stderr,"File Open Fail\n"); exit(1); } /* Read Data */ for (l = 0; l < NSAMPLE; l++) { /* Input input vectors */ for (j = 0; j < XDIM; j++) { fscanf(fp,"%lf",&(x[l][j])); } /* Set teacher signal */ if (l < 50) t[l] = 1.0; else t[l] = 0.0; } /* Close Data File */ fclose(fp); /* Print the data */ for (l = 0; l < NSAMPLE; l++) { printf("%3d : %8.2f ", l, t[l]); for (j = 0; j < XDIM; j++) { printf("%8.2f ", x[l][j]); } printf("\n"); } /* Initialize the parameters by random number */ for (j = 0; j < XDIM+1; j++) { a[j] = (frand() - 0.5); } /* Open output file */ fp = fopen("likelihood.out","w"); /* Learning the parameters */ for (i = 1; i < 100; i++) { /* Learning Loop */ /* Compute derivatives */ /* Initialize derivatives */ for (j = 0; j < XDIM+1; j++) { derivatives[j] = 0.0; } /* update derivatives */ for (l = 0; l < NSAMPLE; l++) { /* prediction */ eta = a[0]; for (j = 1; j < XDIM+1; j++) { eta += a[j] * x[l][j-1]; } y = logit(eta); /* error */ err = t[l] - y; /* update derivatives */ derivatives[0] += err; for (j = 1; j < XDIM+1; j++) { derivatives[j] += err * x[l][j-1]; } } /* update parameters */ for (j = 0; j < XDIM+1; j++) { a[j] = a[j] + alpha * derivatives[j]; } /* Compute Log Likelihood */ likelihood = 0.0; for (l = 0; l < NSAMPLE; l++) { /* prediction */ eta = a[0]; for (j = 1; j < XDIM+1; j++) { eta += a[j] * x[l][j-1]; } y = logit(eta); likelihood += t[l] * log(y) + (1.0 - t[l]) * log(1.0 - y); } printf("%d : Log Likeihood is %f\n", i, likelihood); fprintf(fp, "%f\n", likelihood); } fclose(fp); /* Print Estmated Parameters */ printf("\nEstimated Parameters\n"); for (j = 0; j < XDIM+1; j++) { printf("a[%d]=%f, ",j, a[j]); } printf("\n\n"); /* Prediction and Log Likelihood */ for (l = 0; l < NSAMPLE; l++) { /* prediction */ eta = a[0]; for (j = 1; j < XDIM+1; j++) { eta += a[j] * x[l][j-1]; } y = logit(eta); if ( y > 0.5) { if (l < 50) { printf("%3d [Class1 : correct] : t = %f, y = %f\n", l, t[l], y); } else { printf("%3d [Class1 : not correct] : t = %f, y = %f\n", l, t[l], y); } } else { if (l >= 50) { printf("%3d [Class2 : correct] : t = %f, y = %f\n", l, t[l], y); } else { printf("%3d [Class2 : not correct] : t = %f, y = %f\n", l, t[l], y); } } } }
このプログラムの出力結果は、以下のようになります。
Estimated Parameters a[0]=8.946368, a[1]=0.882509, a[2]=1.338263, a[3]=-6.766164, a[4]=-7.298297, 0 [Class1 : correct] : t = 1.000000, y = 0.993719 1 [Class1 : correct] : t = 1.000000, y = 0.985676 2 [Class1 : correct] : t = 1.000000, y = 0.948786 3 [Class1 : correct] : t = 1.000000, y = 0.987152 4 [Class1 : correct] : t = 1.000000, y = 0.938278 5 [Class1 : correct] : t = 1.000000, y = 0.984824 6 [Class1 : correct] : t = 1.000000, y = 0.937193 7 [Class1 : correct] : t = 1.000000, y = 0.999931 8 [Class1 : correct] : t = 1.000000, y = 0.993680 9 [Class1 : correct] : t = 1.000000, y = 0.990782 10 [Class1 : correct] : t = 1.000000, y = 0.999542 11 [Class1 : correct] : t = 1.000000, y = 0.985725 12 [Class1 : correct] : t = 1.000000, y = 0.999419 13 [Class1 : correct] : t = 1.000000, y = 0.960009 14 [Class1 : correct] : t = 1.000000, y = 0.999605 15 [Class1 : correct] : t = 1.000000, y = 0.996275 16 [Class1 : correct] : t = 1.000000, y = 0.940514 17 [Class1 : correct] : t = 1.000000, y = 0.999773 18 [Class1 : correct] : t = 1.000000, y = 0.718517 19 [Class1 : correct] : t = 1.000000, y = 0.999371 20 [Class2 : not correct] : t = 1.000000, y = 0.416174 21 [Class1 : correct] : t = 1.000000, y = 0.998533 22 [Class1 : correct] : t = 1.000000, y = 0.605839 23 [Class1 : correct] : t = 1.000000, y = 0.991769 24 [Class1 : correct] : t = 1.000000, y = 0.997522 25 [Class1 : correct] : t = 1.000000, y = 0.994371 26 [Class1 : correct] : t = 1.000000, y = 0.962073 27 [Class1 : correct] : t = 1.000000, y = 0.522861 28 [Class1 : correct] : t = 1.000000, y = 0.946845 29 [Class1 : correct] : t = 1.000000, y = 0.999966 30 [Class1 : correct] : t = 1.000000, y = 0.999352 31 [Class1 : correct] : t = 1.000000, y = 0.999831 32 [Class1 : correct] : t = 1.000000, y = 0.999283 33 [Class2 : not correct] : t = 1.000000, y = 0.268092 34 [Class1 : correct] : t = 1.000000, y = 0.927374 35 [Class1 : correct] : t = 1.000000, y = 0.969515 36 [Class1 : correct] : t = 1.000000, y = 0.969959 37 [Class1 : correct] : t = 1.000000, y = 0.974863 38 [Class1 : correct] : t = 1.000000, y = 0.998015 39 [Class1 : correct] : t = 1.000000, y = 0.993021 40 [Class1 : correct] : t = 1.000000, y = 0.990881 41 [Class1 : correct] : t = 1.000000, y = 0.979586 42 [Class1 : correct] : t = 1.000000, y = 0.998568 43 [Class1 : correct] : t = 1.000000, y = 0.999916 44 [Class1 : correct] : t = 1.000000, y = 0.992693 45 [Class1 : correct] : t = 1.000000, y = 0.998997 46 [Class1 : correct] : t = 1.000000, y = 0.996440 47 [Class1 : correct] : t = 1.000000, y = 0.996933 48 [Class1 : correct] : t = 1.000000, y = 0.999966 49 [Class1 : correct] : t = 1.000000, y = 0.996702 50 [Class2 : correct] : t = 0.000000, y = 0.000018 51 [Class2 : correct] : t = 0.000000, y = 0.016302 52 [Class2 : correct] : t = 0.000000, y = 0.001129 53 [Class2 : correct] : t = 0.000000, y = 0.019609 54 [Class2 : correct] : t = 0.000000, y = 0.000335 55 [Class2 : correct] : t = 0.000000, y = 0.000131 56 [Class2 : correct] : t = 0.000000, y = 0.190189 57 [Class2 : correct] : t = 0.000000, y = 0.003928 58 [Class2 : correct] : t = 0.000000, y = 0.004127 59 [Class2 : correct] : t = 0.000000, y = 0.000079 60 [Class2 : correct] : t = 0.000000, y = 0.058820 61 [Class2 : correct] : t = 0.000000, y = 0.014368 62 [Class2 : correct] : t = 0.000000, y = 0.003806 63 [Class2 : correct] : t = 0.000000, y = 0.004500 64 [Class2 : correct] : t = 0.000000, y = 0.000185 65 [Class2 : correct] : t = 0.000000, y = 0.001456 66 [Class2 : correct] : t = 0.000000, y = 0.047170 67 [Class2 : correct] : t = 0.000000, y = 0.000445 68 [Class2 : correct] : t = 0.000000, y = 0.000002 69 [Class2 : correct] : t = 0.000000, y = 0.231585 70 [Class2 : correct] : t = 0.000000, y = 0.000534 71 [Class2 : correct] : t = 0.000000, y = 0.037842 72 [Class2 : correct] : t = 0.000000, y = 0.000140 73 [Class2 : correct] : t = 0.000000, y = 0.137514 74 [Class2 : correct] : t = 0.000000, y = 0.003994 75 [Class2 : correct] : t = 0.000000, y = 0.027528 76 [Class2 : correct] : t = 0.000000, y = 0.222651 77 [Class2 : correct] : t = 0.000000, y = 0.244983 78 [Class2 : correct] : t = 0.000000, y = 0.000915 79 [Class2 : correct] : t = 0.000000, y = 0.183885 80 [Class2 : correct] : t = 0.000000, y = 0.002655 81 [Class2 : correct] : t = 0.000000, y = 0.011796 82 [Class2 : correct] : t = 0.000000, y = 0.000350 83 [Class1 : not correct] : t = 0.000000, y = 0.642195 84 [Class2 : correct] : t = 0.000000, y = 0.230225 85 [Class2 : correct] : t = 0.000000, y = 0.000146 86 [Class2 : correct] : t = 0.000000, y = 0.000293 87 [Class2 : correct] : t = 0.000000, y = 0.057084 88 [Class2 : correct] : t = 0.000000, y = 0.299894 89 [Class2 : correct] : t = 0.000000, y = 0.008427 90 [Class2 : correct] : t = 0.000000, y = 0.000178 91 [Class2 : correct] : t = 0.000000, y = 0.003929 92 [Class2 : correct] : t = 0.000000, y = 0.016302 93 [Class2 : correct] : t = 0.000000, y = 0.000222 94 [Class2 : correct] : t = 0.000000, y = 0.000086 95 [Class2 : correct] : t = 0.000000, y = 0.001591 96 [Class2 : correct] : t = 0.000000, y = 0.021935 97 [Class2 : correct] : t = 0.000000, y = 0.022459 98 [Class2 : correct] : t = 0.000000, y = 0.001482 99 [Class2 : correct] : t = 0.000000, y = 0.108289
今回も、やはり3個の識別に失敗していますが、4個の特徴からほぼアヤメの種 類を識別できるようになっていることがわかります。