next up previous
次へ: ロジスティック回帰モデル 上へ: ニューラルネット入門 戻る: パーセプトロン

ADALINE

最小2乗法の考え方を用いるとパーセプトロンの欠点をある程度解決した学習 法を導くことができます。これは、1960年にWidrowとHoff が提案した ADALINE(Adaptive Linear Neuron)というモデルで、閾値論理素子の線形の部 分

\begin{displaymath}
y = \sum_{i=1}^M a_i x_i + a_0
\end{displaymath} (35)

のみを取り出して利用するものです。このモデルでの学習は、教師の答えとネッ トワークの出力との平均2乗誤差を最小とするような結合重み $(a_0,a_1,\ldots,a_M)$を最急降下法によって求めるものです。従って、この モデルの出力関数は、McCullochとPittsの閾値論理素子やRosenblattのパーセ プトロンのように閾値関数ではなく、図3(b)のような線 形関数であるとみなすことができます。

まずは、前回のボール投げの記録を予測する最小2乗法のプログラムを参考に 2種類のアヤメのデータを識別するADALINEのプログラムを作ってみましょう。

アヤメ科のイリス・ベルシコロールとイリス・ベルジニカという2種類の花か ら、がくの長さ($x1$)、がくの幅($x2$)、花弁の長さ($x3$)、花弁の幅($x4$) の4種類の特徴を計測したデータがあります。データ数は、各花とも50個づつ あります。これは、Fisherという有名な統計学者が1936年に線形判別関数を適 用した有名なデータで、それ以来パターン認識の手法を確認する例として頻繁 に用いられています。

ここでは、教師の答えとしてイリス・ベルシコロールには($t=1$)を与え、イ リス・ベルジニカには($t=0$)を与えるものとします。がくの長さ($x1$)、が くの幅($x2$)、花弁の長さ($x3$)、花弁の幅($x4$)のデータから教師の答えを 予測するためのADALINEモデルは、

\begin{displaymath}
y(x1,x2,x3,x4) = a_0 + a_1 x1 + a_2 x2 + a_3 x3 + a_4 x4
\end{displaymath} (36)

となります。最小2乗法の場合と全く同じように、ADALINEでも予測値と教師 の答えとの平均2乗誤差を最小にするようなパラメータ $(a_0,a_1,a_2,a_3,a_4)$を最急降下法で求めます。平均2乗誤差の各パラメー タでの微分を計算して、パラメータの更新式を具体的に求めると
$\displaystyle a_0^{(k+1)}$ $\textstyle =$ $\displaystyle a_0^{(k)} + 2 \alpha \frac{1}{100} \sum_{l=1}^{100} (t_l - y_l)$  
$\displaystyle a_1^{(k+1)}$ $\textstyle =$ $\displaystyle a_1^{(k)} + 2 \alpha \frac{1}{100} \sum_{l=1}^{100} (t_l - y_l) x1_l$  
$\displaystyle a_2^{(k+1)}$ $\textstyle =$ $\displaystyle a_2^{(k)} + 2 \alpha \frac{1}{100} \sum_{l=1}^{100} (t_l - y_l) x2_l$  
$\displaystyle a_3^{(k+1)}$ $\textstyle =$ $\displaystyle a_3^{(k)} + 2 \alpha \frac{1}{100} \sum_{l=1}^{100} (t_l - y_l) x3_l$  
$\displaystyle a_4^{(k+1)}$ $\textstyle =$ $\displaystyle a_4^{(k)} + 2 \alpha \frac{1}{100} \sum_{l=1}^{100} (t_l - y_l) x4_l$ (37)

のようになります。ここ$t_l$および$y_l$は、それぞれ、$l$番目の計測デー タに対する教師の答えおよびADALINEモデルでの予測値です。また、$x1_l$$x2_l$$x3_l$および$x4_l$は、$l$番目の花を計測した特徴量の計測値です。

学習したADALINEモデルを用いてアヤメの花を識別するには、学習したADALINE モデルに計測した特徴量を代入し、教師の答えの予測値を求め、それが$1$に 近ければイリス・ベルシコロールと判定し、$0$に近ければイリス・ベルジニ カと判定すれば良いことになります。

具体的なプログラムは、以下のようになります。

#include <stdio.h>
#include <stdlib.h>

#define frand()   rand()/((double)RAND_MAX)

#define NSAMPLE 100
#define XDIM 4

main() {
  FILE *fp;
  double t[NSAMPLE];
  double x[NSAMPLE][XDIM];
  double a[XDIM+1];
  int   i, j, l;
  double y, err, mse;
  double derivatives[XDIM+1];
  double alpha = 0.1; /* Learning Rate */

  /* Open Data File */
  if ((fp = fopen("niris.dat","r")) == NULL) {
    fprintf(stderr,"File Open Fail\n");
    exit(1);
  }

  /* Read Data */
  for (l = 0; l < NSAMPLE; l++) {
    /* Input input vectors */
    for (j = 0; j < XDIM; j++) {
      fscanf(fp,"%lf",&(x[l][j]));
    }
    /* Set teacher signal */
    if (l < 50) t[l] = 1.0; else t[l] = 0.0;
  }
  /* Close Data File */
  fclose(fp);

  /* Print the data */
  for (l = 0; l < NSAMPLE; l++) {
    printf("%3d : %8.2f ", l, t[l]);
    for (j = 0; j < XDIM; j++) {
      printf("%8.2f ", x[l][j]);
    }
    printf("\n");
  }

  /* Initialize the parameters by random number */
  for (j = 0; j < XDIM+1; j++) {
    a[j] = (frand() - 0.5);
  }

  /* Open output file */
  fp = fopen("mse.out","w");

  /* Learning the parameters */
  for (i = 1; i < 1000; i++) { /* Learning Loop */

    /* Compute derivatives */
    /* Initialize derivatives */
    for (j = 0; j < XDIM+1; j++) {
      derivatives[j] = 0.0;
    }
    /* update derivatives */
    for (l = 0; l < NSAMPLE; l++) {
      /* prediction */
      y = a[0];
      for (j = 1; j < XDIM+1; j++) {
	y += a[j] * x[l][j-1];
      }

      /* error */
      err = t[l] - y;
      /*      printf("err[%d] = %f\n", l, err);*/

      /* update derivatives */
      derivatives[0] += err;
      for (j = 1; j < XDIM+1; j++) {
	derivatives[j] += err * x[l][j-1];
      }

    }
    for (j = 0; j < XDIM+1; j++) {
      derivatives[j] = -2.0 * derivatives[j] / (double)NSAMPLE;
    }

    /* update parameters */
    for (j = 0; j < XDIM+1; j++) {
      a[j] = a[j] - alpha * derivatives[j];
    }

    /* Compute Mean Squared Error */
    mse = 0.0;
    for (l = 0; l < NSAMPLE; l++) {
      /* prediction */
      y = a[0];
      for (j = 1; j < XDIM+1; j++) {
	y += a[j] * x[l][j-1];
      }

      /* error */
      err = t[l] - y;

      mse += err * err;
    }
    mse = mse / (double)NSAMPLE;

    printf("%d : Mean Squared Error is %f\n", i, mse);
    fprintf(fp, "%f\n", mse);

  }

  fclose(fp);

  /* Print Estmated Parameters */
  printf("\nEstimated Parameters\n");
  for (j = 0; j < XDIM+1; j++) {
    printf("a[%d]=%f, ",j, a[j]);
  }
  printf("\n\n");

  /* Prediction and Errors */
  for (l = 0; l < NSAMPLE; l++) {
    /* prediction */
    y = a[0];
    for (j = 1; j < XDIM+1; j++) {
      y += a[j] * x[l][j-1];
    }
    
    /* error */
    err = t[l] - y;

    if ((1.0 - y)*(1.0 - y) <= (0.0 - y)*(0.0 - y)) {
      if (l < 50) { 
	printf("%3d [Class1 : correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err);
      } else {
	printf("%3d [Class1 : not correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err);
      }
    } else {
      if (l >= 50) {
	printf("%3d [Class2 : correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err);
      } else {
	printf("%3d [Class2 : not correct] : t = %f, y = %f (err = %f)\n", l, t[l], y, err);
      }
    }
  }

}

このプログラムは、先の最小2乗法を最急降下法で解くプログラムとほとんど 同じです。アヤメのデータファイルniris.datを読み込んで、そのデータに対 して最急降下法でパラメータを求めています。教師信号が実数値ではなく、$0$$1$の2値で与えられるとことだけが最小2乗法との違いです。プログラムの 最後の部分では、得られたニューラルネット(識別器)の良さを確認するため に学習に用いたアヤメのデータを識別させています。ニューラルネットの出力 が$0$$1$のどちらに近いかで、どちらのアヤメかを決定しています。プログ ラムの実行結果は、以下のようになります。

Estimated Parameters
a[0]=1.239302, a[1]=0.145552, a[2]=0.139415, a[3]=-0.638937, a[4]=-0.537277, 

  0 [Class1 : correct] : t = 1.000000, y = 1.003690 (err = -0.003690)
  1 [Class1 : correct] : t = 1.000000, y = 0.899888 (err = 0.100112)
  2 [Class1 : correct] : t = 1.000000, y = 0.810625 (err = 0.189375)
  3 [Class1 : correct] : t = 1.000000, y = 0.775510 (err = 0.224490)
  4 [Class1 : correct] : t = 1.000000, y = 0.752820 (err = 0.247180)
  5 [Class1 : correct] : t = 1.000000, y = 0.789633 (err = 0.210367)
  6 [Class1 : correct] : t = 1.000000, y = 0.771009 (err = 0.228991)
  7 [Class1 : correct] : t = 1.000000, y = 1.168271 (err = -0.168271)
  8 [Class1 : correct] : t = 1.000000, y = 0.943976 (err = 0.056024)
  9 [Class1 : correct] : t = 1.000000, y = 0.816619 (err = 0.183381)
 10 [Class1 : correct] : t = 1.000000, y = 0.984887 (err = 0.015113)
 11 [Class1 : correct] : t = 1.000000, y = 0.856557 (err = 0.143443)
 12 [Class1 : correct] : t = 1.000000, y = 1.043678 (err = -0.043678)
 13 [Class1 : correct] : t = 1.000000, y = 0.748846 (err = 0.251154)
 14 [Class1 : correct] : t = 1.000000, y = 1.130948 (err = -0.130948)
 15 [Class1 : correct] : t = 1.000000, y = 1.027689 (err = -0.027689)
 16 [Class1 : correct] : t = 1.000000, y = 0.694755 (err = 0.305245)
 17 [Class1 : correct] : t = 1.000000, y = 1.132590 (err = -0.132590)
 18 [Class1 : correct] : t = 1.000000, y = 0.543723 (err = 0.456277)
 19 [Class1 : correct] : t = 1.000000, y = 1.035076 (err = -0.035076)
 20 [Class2 : not correct] : t = 1.000000, y = 0.490680 (err = 0.509320)
 21 [Class1 : correct] : t = 1.000000, y = 1.041685 (err = -0.041685)
 22 [Class1 : correct] : t = 1.000000, y = 0.512358 (err = 0.487642)
 23 [Class1 : correct] : t = 1.000000, y = 0.858199 (err = 0.141801)
 24 [Class1 : correct] : t = 1.000000, y = 1.017686 (err = -0.017686)
 25 [Class1 : correct] : t = 1.000000, y = 0.977978 (err = 0.022022)
 26 [Class1 : correct] : t = 1.000000, y = 0.803766 (err = 0.196234)
 27 [Class1 : correct] : t = 1.000000, y = 0.565534 (err = 0.434466)
 28 [Class1 : correct] : t = 1.000000, y = 0.733136 (err = 0.266864)
 29 [Class1 : correct] : t = 1.000000, y = 1.300773 (err = -0.300773)
 30 [Class1 : correct] : t = 1.000000, y = 1.021680 (err = -0.021680)
 31 [Class1 : correct] : t = 1.000000, y = 1.128719 (err = -0.128719)
 32 [Class1 : correct] : t = 1.000000, y = 1.063775 (err = -0.063775)
 33 [Class2 : not correct] : t = 1.000000, y = 0.380334 (err = 0.619666)
 34 [Class1 : correct] : t = 1.000000, y = 0.659518 (err = 0.340482)
 35 [Class1 : correct] : t = 1.000000, y = 0.822877 (err = 0.177123)
 36 [Class1 : correct] : t = 1.000000, y = 0.848019 (err = 0.151981)
 37 [Class1 : correct] : t = 1.000000, y = 0.771196 (err = 0.228804)
 38 [Class1 : correct] : t = 1.000000, y = 0.981463 (err = 0.018537)
 39 [Class1 : correct] : t = 1.000000, y = 0.839696 (err = 0.160304)
 40 [Class1 : correct] : t = 1.000000, y = 0.797250 (err = 0.202750)
 41 [Class1 : correct] : t = 1.000000, y = 0.817255 (err = 0.182745)
 42 [Class1 : correct] : t = 1.000000, y = 0.995367 (err = 0.004633)
 43 [Class1 : correct] : t = 1.000000, y = 1.153796 (err = -0.153796)
 44 [Class1 : correct] : t = 1.000000, y = 0.848869 (err = 0.151131)
 45 [Class1 : correct] : t = 1.000000, y = 1.033489 (err = -0.033489)
 46 [Class1 : correct] : t = 1.000000, y = 0.930673 (err = 0.069327)
 47 [Class1 : correct] : t = 1.000000, y = 0.982449 (err = 0.017551)
 48 [Class1 : correct] : t = 1.000000, y = 1.273824 (err = -0.273824)
 49 [Class1 : correct] : t = 1.000000, y = 0.934896 (err = 0.065104)
 50 [Class2 : correct] : t = 0.000000, y = -0.337600 (err = 0.337600)
 51 [Class2 : correct] : t = 0.000000, y = 0.132929 (err = -0.132929)
 52 [Class2 : correct] : t = 0.000000, y = 0.026276 (err = -0.026276)
 53 [Class2 : correct] : t = 0.000000, y = 0.174352 (err = -0.174352)
 54 [Class2 : correct] : t = 0.000000, y = -0.113841 (err = 0.113841)
 55 [Class2 : correct] : t = 0.000000, y = -0.139841 (err = 0.139841)
 56 [Class2 : correct] : t = 0.000000, y = 0.269516 (err = -0.269516)
 57 [Class2 : correct] : t = 0.000000, y = 0.096326 (err = -0.096326)
 58 [Class2 : correct] : t = 0.000000, y = 0.043823 (err = -0.043823)
 59 [Class2 : correct] : t = 0.000000, y = -0.119072 (err = 0.119072)
 60 [Class2 : correct] : t = 0.000000, y = 0.345999 (err = -0.345999)
 61 [Class2 : correct] : t = 0.000000, y = 0.166008 (err = -0.166008)
 62 [Class2 : correct] : t = 0.000000, y = 0.118683 (err = -0.118683)
 63 [Class2 : correct] : t = 0.000000, y = 0.016717 (err = -0.016717)
 64 [Class2 : correct] : t = 0.000000, y = -0.188593 (err = 0.188593)
 65 [Class2 : correct] : t = 0.000000, y = 0.043580 (err = -0.043580)
 66 [Class2 : correct] : t = 0.000000, y = 0.277996 (err = -0.277996)
 67 [Class2 : correct] : t = 0.000000, y = 0.027482 (err = -0.027482)
 68 [Class2 : correct] : t = 0.000000, y = -0.500986 (err = 0.500986)
 69 [Class2 : correct] : t = 0.000000, y = 0.326909 (err = -0.326909)
 70 [Class2 : correct] : t = 0.000000, y = -0.013590 (err = 0.013590)
 71 [Class2 : correct] : t = 0.000000, y = 0.290259 (err = -0.290259)
 72 [Class2 : correct] : t = 0.000000, y = -0.152001 (err = 0.152001)
 73 [Class2 : correct] : t = 0.000000, y = 0.364375 (err = -0.364375)
 74 [Class2 : correct] : t = 0.000000, y = 0.124712 (err = -0.124712)
 75 [Class2 : correct] : t = 0.000000, y = 0.283933 (err = -0.283933)
 76 [Class2 : correct] : t = 0.000000, y = 0.415164 (err = -0.415164)
 77 [Class2 : correct] : t = 0.000000, y = 0.425416 (err = -0.425416)
 78 [Class2 : correct] : t = 0.000000, y = -0.052291 (err = 0.052291)
 79 [Class2 : correct] : t = 0.000000, y = 0.433825 (err = -0.433825)
 80 [Class2 : correct] : t = 0.000000, y = 0.083760 (err = -0.083760)
 81 [Class2 : correct] : t = 0.000000, y = 0.313110 (err = -0.313110)
 82 [Class2 : correct] : t = 0.000000, y = -0.123014 (err = 0.123014)
 83 [Class1 : not correct] : t = 0.000000, y = 0.536005 (err = -0.536005)
 84 [Class2 : correct] : t = 0.000000, y = 0.325728 (err = -0.325728)
 85 [Class2 : correct] : t = 0.000000, y = -0.082091 (err = 0.082091)
 86 [Class2 : correct] : t = 0.000000, y = -0.089522 (err = 0.089522)
 87 [Class2 : correct] : t = 0.000000, y = 0.292471 (err = -0.292471)
 88 [Class2 : correct] : t = 0.000000, y = 0.444113 (err = -0.444113)
 89 [Class2 : correct] : t = 0.000000, y = 0.204709 (err = -0.204709)
 90 [Class2 : correct] : t = 0.000000, y = -0.115327 (err = 0.115327)
 91 [Class2 : correct] : t = 0.000000, y = 0.172210 (err = -0.172210)
 92 [Class2 : correct] : t = 0.000000, y = 0.132929 (err = -0.132929)
 93 [Class2 : correct] : t = 0.000000, y = -0.103840 (err = 0.103840)
 94 [Class2 : correct] : t = 0.000000, y = -0.158180 (err = 0.158180)
 95 [Class2 : correct] : t = 0.000000, y = 0.068565 (err = -0.068565)
 96 [Class2 : correct] : t = 0.000000, y = 0.193150 (err = -0.193150)
 97 [Class2 : correct] : t = 0.000000, y = 0.245497 (err = -0.245497)
 98 [Class2 : correct] : t = 0.000000, y = 0.036213 (err = -0.036213)
 99 [Class2 : correct] : t = 0.000000, y = 0.317548 (err = -0.317548)

この例では、3個の間違いましたが、4個の特徴からほぼアヤメの種類を識別で きていることがわかります。



平成14年7月19日