fbpx

JavaScript 實現類神經網路 (瀏覽器 Deep Learning 好棒棒)


記得古時候有寫過一篇「用 Machine Learning 辨識鳶尾花」搞搞 Machine Learning 建立貝氏分類器來辨識鳶尾花,最近想在 JavaScript 上面跑 Neural Network (類神經網路) 的演算法,刷刷存在感。很久以前有用過 Java 手刻類神經網路演算法,用來辨識昆蟲聲音,太久 Code 不知道丟到哪了,現在連 Bug 都寫不太出來。前陣子剛好 Google 到國外有個神人的 NodeJS Project (node-neural-network) 剛好符合我的需求,就先拿來玩看看囉。

用 JavaScript Deep Learning 辨識鳶尾花

由於我手邊沒有什麼數據可以跑,影像處理就是跑 PlayBoy Lenna 圖,AI 資料處理就是跑鳶尾花這些老梗,沒什麼創意,如果有人知道比較新潮的數據庫,可否麻煩留言告訴我一下,我也想趕流行!

先說明一下類神經網路的概念,基本上就是複合的迴歸分析,透過神經元的訓練,理論上可以學習任何模型,當然你給的 Feature 要合理,訓練樣本要夠充足 (類神經依賴大量的訓練資料)。可以把類神經網路想像為一個黑盒子 f(x),不停的告訴它什麼樣的 Input 會得到什麼樣的 Output,慢慢的累績訓練,有一天這個 f(x) 就會跟你說:「媽,我好像懂了什麼?」,表示收斂了學習完成,可以開始進行預測囉。典型的類神經網路長得像下面這樣:

neural network

上圖左邊的是輸入神經元 Input Layer,表示我們資料的 Feature 維度特徵值;中間的是隱藏層,可以多層多節點,一般來說數量是 Input 神經元的 1.5 倍;最後接上的是輸出層 Output Layer,就是我們預期的預測結果。單單這些複雜的神經元連接方法,就有上百篇論文再討論,我們今天測試的連接方法是最典型的模型。

實現 JavaScript 類神經網路機器學習

先說明一下程式執行的流程:

  1. 建構類神經網路 4-6-3 結構
  2. 載入 iris.csv 資料檔並且進行正規劃,將數值投影到 0~1 的範圍
  3. 每個種類隨機抽出三組後,剩餘的進入神經網路進行訓練 (因為資料樣本不多,所以隨機重複訓練 10,000 次)
  4. 最後從前一個步驟隨機抽出的數據進行預測,並計算正確率
  5. 將訓練好的網路存為 Json File,未來可以讀取後直接進行預測

測試的鳶尾花資料有四組 Feature,花分成三個種類,因此 Input 4 個節點,Output 3 個節點,中間的隱藏層使用 Input 節點的 1.5 倍,所以隱藏層使用了 6 個節點。連接方式與上圖相同,完整的程式碼已經放到 GitHub,有興趣的可以輸入以下命令進行測試:

git clone https://github.com/samejack/blog-content

cd sj-nn

npm install

nom run example

執行畫面如下:

執行了幾次,正確率表現還不錯。這樣的規模說真的只會是玩具,聽說谷歌的網路都上億個節點,真正的深度學習。類神經屬於機器學習 (Machine Learning ) 中的監督式學習,只在有足夠的資料下,可以自動學習找出的特徵與結果的關係。由於這次使用的是 JavaScript 語言,因此也可以在瀏覽器中執行,作者的網站上還有許多有趣而且很威的 Example,不管是訓練 XOR Gate 或者辨識貓等等,有興趣的可以進去看看囉,很廢的 JavaScript 深度學習結束了,下次見。

這是真的廣告

白金贊助

平價童鞋首選