開源技術教學文件網 撰寫高階函式 (Higher-Order Function)

最後修改日期為 JUN 18, 2019

前言

許多高階函式的思維,在於對串列 (list) 的操作,在 Go 語言中,可以用切片代替串列。在高階函式中,有幾個常見的模式 (pattern),在不同的程式語言或框架會有不同的名字,但精神是相通的。筆者在這裡介紹一些常見的高階函式模式,做為各位讀者撰寫高階函式的參考。

filter

fliter 函式接收一個串列及條件函式,回傳符合條件函式的新串列。有些程式語言會用 grepselect 等同概念但異名的函式。

package main
 
import (
    "log"
)
 
func filter(arr []int, predicate func(int) bool) []int {
    out := make([]int, 0)
 
    for _, e := range arr {
        if predicate(e) {
            out = append(out, e)
        }
    }
 
    return out
}
 
func eq(m []int, n []int) bool {
    if len(m) != len(n) {
        return false
    }
 
    for i := 0; i < len(m); i++ {
        if m[i] != n[i] {
            return false
        }
    }
 
    return true
}
 
func main() {
    arr := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
 
    out := filter(arr, func(n int) bool { return n%2 == 0 })
 
    if !eq(out, []int{2, 4, 6, 8, 10}) {
        log.Fatal("Wrong value")
    }
}

map

map 函式接收一個串列和轉換函式,每個元素經轉換函式轉換,最後回傳一個新的串列。有些程式語言用 apply 等名稱。

package main
 
import (
    "log"
)
 
// map is predefined word in Go. Use apply instead.
func apply(arr []int, mapper func(int) int) []int {
    out := make([]int, len(arr))
 
    for i, e := range arr {
        out[i] = mapper(e)
    }
 
    return out
}
 
// eq declared as before.
 
func main() {
    arr := []int{1, 2, 3, 4, 5}
 
    out := apply(arr, func(n int) int { return n * n })
 
    if !eq(out, []int{1, 4, 9, 16, 25}) {
        log.Fatal("Wrong value")
    }
}

reduce

reduce 函式接受一個串列和一個縮減函式,回傳一個單一值。有時候又稱為 fold 或其他名稱。

package main
 
import (
    "log"
)
 
func reduce(arr []int, reducer func(int, int) int) int {
    if len(arr) == 0 {
        return 0
    } else if len(arr) == 1 {
        return arr[0]
    }
 
    n := arr[0]
 
    for i := 1; i < len(arr); i++ {
        n = reducer(n, arr[i])
    }
 
    return n
}
 
func main() {
    arr := []int{1, 2, 3, 4, 5}
 
    n := reduce(arr, func(a int, b int) int { return a + b })
 
    if !(n == 15) {
        log.Fatal("Wrong value")
    }
}

partition

partition 函式接受一個串列和判斷函式,該函式會依據判斷函式,將原串列分為兩個新的串列。見下例:

package main
 
import (
    "log"
)
 
func partition(arr []int, predicate func(int) bool) ([]int, []int) {
    fit := make([]int, 0)
    nonfit := make([]int, 0)
 
    for _, e := range arr {
        if predicate(e) {
            fit = append(fit, e)
        } else {
            nonfit = append(nonfit, e)
        }
    }
 
    return fit, nonfit
}
 
// eq declared as before.
 
func main() {
    arr := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
 
    even, odd := partition(arr, func(n int) bool { return n%2 == 0 })
 
    if !eq(even, []int{2, 4, 6, 8, 10}) {
        log.Fatal("Wrong value")
    }
 
    if !eq(odd, []int{1, 3, 5, 7, 9}) {
        log.Fatal("Wrong value")
    }
}

fold

fold 函式接受兩個等長的串列,將其合併為一個新的串列,每個串列的元素是以兩個串列的元素組合而成的元組 (tuple)。在 Go 語言中,沒有內建的元組,以結構取代。見下例:

package main
 
import (
    "fmt"
)
 
type Tuple struct {
    First  int
    Second int
}
 
func zip(m []int, n []int) func() (Tuple, bool) {
    if len(m) != len(n) {
        panic("Unequal list length")
    }
 
    i := -1
    return func() (Tuple, bool) {
        i += 1
        if i < len(m) {
            return Tuple{First: m[i], Second: n[i]}, true
        } else {
            return Tuple{First: 0, Second: 0}, false
        }
    }
}
 
func main() {
    m := []int{1, 2, 3}
    n := []int{4, 5, 6}
 
    // Get the iterator
    iter := zip(m, n)
 
    // Iterate through the list.
    for {
        out, ok := iter()
        if !ok {
            break
        }
 
        fmt.Println(out)
    }
}

要注意的是,我們為了要保存 zip 函式內部的計數器狀態,我們將其寫成迭代器 (iterator),這裡應用到先前提過的閉包。當迭代器結束時,ok 會回傳 false,這時就離開迴圈。

enumerate

enumerate 函式接收一個串列,回傳以原串列元素和其索引值為元素的新串列。見下例:

package main
 
import (
    "fmt"
)
 
type Tuple struct {
    Index   int
    Element int
}
 
func enumerate(m []int) func() (Tuple, bool) {
    i := -1
    return func() (Tuple, bool) {
        i += 1
        if i < len(m) {
            return Tuple{Index: i, Element: m[i]}, true
        } else {
            return Tuple{Index: i, Element: 0}, false
        }
    }
}
 
func main() {
    m := []int{10, 20, 30, 40, 50}
 
    iter := enumerate(m)
 
    for {
        out, ok := iter()
        if !ok {
            break
        }
 
        fmt.Println(fmt.Sprintf("%d: %d", out.Index, out.Element))
    }
}

同樣地,我們在這裡使用迭代器。

組合數個高階函式

高階函式間,可以相互組合,達到複合的效果,如下例:

package main
 
// filter is declared as above.
 
// apply is declared as above.
 
// reduce is declared as above.
 
func main() {
    in := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
 
    temp := filter(in, func(n int) bool { return n%2 != 0 })
    temp = apply(temp, func(n int) int { return n * n })
    out := reduce(temp, func(a int, b int) int { return a + b })
 
    if !(out == (1*1)+(3*3)+(5*5)+(7*7)+(9*9)) {
        log.Fatal("Wrong value")
    }
}
分享本文
Facebook Twitter LinkedIn LINE Skype EverNote GMail Yahoo Yahoo
追蹤本站
Facebook Facebook Twitter