記得古時候有寫過一篇「用 Machine Learning 辨識鳶尾花」搞搞 Machine Learning 建立貝氏分類器來辨識鳶尾花,最近想在 JavaScript 上面跑 Neural Network (類神經網路) 的演算法,刷刷存在感。很久以前有用過 Java 手刻類神經網路演算法,用來辨識昆蟲聲音,太久 Code 不知道丟到哪了,現在連 Bug 都寫不太出來。前陣子剛好 Google 到國外有個神人的 NodeJS Project (node-neural-network) 剛好符合我的需求,就先拿來玩看看囉。
用 JavaScript Deep Learning 辨識鳶尾花
由於我手邊沒有什麼數據可以跑,影像處理就是跑 PlayBoy Lenna 圖,AI 資料處理就是跑鳶尾花這些老梗,沒什麼創意,如果有人知道比較新潮的數據庫,可否麻煩留言告訴我一下,我也想趕流行!
先說明一下類神經網路的概念,基本上就是複合的迴歸分析,透過神經元的訓練,理論上可以學習任何模型,當然你給的 Feature 要合理,訓練樣本要夠充足 (類神經依賴大量的訓練資料)。可以把類神經網路想像為一個黑盒子 f(x),不停的告訴它什麼樣的 Input 會得到什麼樣的 Output,慢慢的累績訓練,有一天這個 f(x) 就會跟你說:「媽,我好像懂了什麼?」,表示收斂了學習完成,可以開始進行預測囉。典型的類神經網路長得像下面這樣:
上圖左邊的是輸入神經元 Input Layer,表示我們資料的 Feature 維度特徵值;中間的是隱藏層,可以多層多節點,一般來說數量是 Input 神經元的 1.5 倍;最後接上的是輸出層 Output Layer,就是我們預期的預測結果。單單這些複雜的神經元連接方法,就有上百篇論文再討論,我們今天測試的連接方法是最典型的模型。
實現 JavaScript 類神經網路機器學習
先說明一下程式執行的流程:
- 建構類神經網路 4-6-3 結構
- 載入 iris.csv 資料檔並且進行正規劃,將數值投影到 0~1 的範圍
- 每個種類隨機抽出三組後,剩餘的進入神經網路進行訓練 (因為資料樣本不多,所以隨機重複訓練 10,000 次)
- 最後從前一個步驟隨機抽出的數據進行預測,並計算正確率
- 將訓練好的網路存為 Json File,未來可以讀取後直接進行預測
測試的鳶尾花資料有四組 Feature,花分成三個種類,因此 Input 4 個節點,Output 3 個節點,中間的隱藏層使用 Input 節點的 1.5 倍,所以隱藏層使用了 6 個節點。連接方式與上圖相同,完整的程式碼已經放到 GitHub,有興趣的可以輸入以下命令進行測試:
git clone https://github.com/samejack/blog-content
cd sj-nn
npm install
nom run example
執行畫面如下:
執行了幾次,正確率表現還不錯。這樣的規模說真的只會是玩具,聽說谷歌的網路都上億個節點,真正的深度學習。類神經屬於機器學習 (Machine Learning ) 中的監督式學習,只在有足夠的資料下,可以自動學習找出的特徵與結果的關係。由於這次使用的是 JavaScript 語言,因此也可以在瀏覽器中執行,作者的網站上還有許多有趣而且很威的 Example,不管是訓練 XOR Gate 或者辨識貓等等,有興趣的可以進去看看囉,很廢的 JavaScript 深度學習結束了,下次見。