标签: KMP

AC自动机原理及实现

AC自动机原理及实现

概述

AC自动机算法是一种常见的多串匹配算法。理解本算法需要先理解当模式串只有一个的时候的情况,即KMP算法原理与实现 | KSkun’s Blog。下面将默认你学会了KMP算法,介绍AC自动机算法。

结构与fail指针

AC自动机实际上要在一棵trie树上工作。因此我们得先把trie树建出来。
有了trie树,我们应用类似KMP的思路,计算出fail数组的替代品:fail指针。它的定义与fail数组很相似,是找到一个最长的前缀使得它与当前匹配的后缀相同。如果我们要求一个点的fail指针,其祖先的fail都已知,我们要沿着它父亲的fail指针找到一个有当前字母儿子的节点,并跳转至该儿子。这个上跳看上去很烦人,我们考虑把上跳的过程省去,即建虚拟边把跳到最后的目标给加在儿子上。由于我们需要按层计算fail,我们要利用BFS序来计算,写出来看上去是这样的。

inline void calfail() {
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i]; // 加了这一行就不用跳跳跳啦
            }
        }
    }
}

匹配

基本和KMP没啥区别,失配→跳fail指针→继续匹配,注意不能匹配成功一次就不继续,应该沿fail指针继续上跳计算所有单词。

inline void query(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t; t = fail[t]) {
            // 需要统计什么的在这干
        }
    }
}

例题1:【P3808】【模板】AC自动机(简单版) – 洛谷

题意

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

题解

找到一个后要设为已访问过,否则会算重。

代码

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

#include <queue>

const int MAXN = 1000005;

int ch[MAXN][26], val[MAXN], fail[MAXN], cnt;
std::queue<int> que;

inline void insert(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        int t = str[i] - 'a';
        if(!ch[p][t]) ch[p][t] = ++cnt;
        p = ch[p][t];
    }
    val[p]++;
}

inline void calfail() {
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i];
            }
        }
    }
}

inline int query(char *str) {
    int len = strlen(str), p = 0, res = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t && ~val[t]; t = fail[t]) {
            res += val[t];
            val[t] = -1;
        }
    }
    return res;
}

int n;
char str[MAXN];

int main() {
    scanf("%d", &n);
    for(int i = 0; i < n; i++) {
        scanf("%s", str);
        insert(str);
    }
    calfail();
    scanf("%s", str);
    printf("%d", query(str));
    return 0;
}

例题2:【P3796】【模板】AC自动机(加强版) – 洛谷

题意

有N个由小写字母组成的模式串以及一个文本串T。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串T中出现的次数最多。

题解

每找到一次在单词的计数器上加一,最后找最大值即可。

代码

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

#include <queue>

const int MAXN = 20005;

int ch[MAXN][26], val[MAXN], fail[MAXN], ans[MAXN], cnt;
std::queue<int> que;

inline void insert(char *str, int num) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        int t = str[i] - 'a';
        if(!ch[p][t]) ch[p][t] = ++cnt;
        p = ch[p][t];
    }
    val[p] = num;
}

inline void calfail() {
    memset(fail, 0, sizeof(fail));
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i];
            }
        }
    }
}

inline void query(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t; t = fail[t]) {
            if(val[t]) ans[val[t]]++;
        }
    }
}

int n;
char pat[155][75], str[1000005];

int main() {
    for(;;) {
        scanf("%d", &n);
        if(n == 0) break;
        memset(ch, 0, sizeof(ch));
        memset(val, 0, sizeof(val));
        memset(ans, 0, sizeof(ans));
        cnt = 0;
        for(int i = 1; i <= n; i++) {
            scanf("%s", pat[i]);
            insert(pat[i], i);
        }
        calfail();
        scanf("%s", str);
        query(str);
        int mx = 0;
        for(int i = 1; i <= n; i++) {
            mx = std::max(mx, ans[i]);
        }
        printf("%d\n", mx);
        for(int i = 1; i <= n; i++) {
            if(ans[i] == mx) printf("%s\n", pat[i]);
        }
    }
    return 0;
}

参考资料

KMP算法原理与实现

KMP算法原理与实现

概述

KMP是一种字符串匹配算法,其复杂度已经达到了该类算法的下界,即O(|T|+|P|),其中T是文本串,P是模式串。下面介绍它的原理与实现。

MP算法(俗称)

下面我们介绍一种被称作MP算法的东西,这可以说是KMP算法的一个前身。
我们来尝试用P=”ABCD”来匹配T=”ABCABCABCD”,并观察它的过程。
第一步:P_0 \rightarrow T_0

ABCABCABCD
|||X
ABCD

第二步:P_0 \rightarrow T_1

ABCABCABCD
 X
 ABCD

第三步:P_0 \rightarrow T_2

ABCABCABCD
  X
  ABCD

第四步:P_0 \rightarrow T_3

ABCABCABCD
   |||X
   ABCD

……
第七步:P_0 \rightarrow T_6

ABCABCABCD
      ||||
      ABCD

这便是朴素地匹配字符串的过程,我们发现要匹配完成一共尝试了7次。这个算法的复杂度是O(|T||P|)的。
下面我们想,如果在第一步P_3失配后能够直接转移到第四步的位置,因为我们会发现用T_1, T_2去匹配P是完全没有用的。我们需要知道,如果一个字符失配了,应该将P后移到哪里才能避免中间若干失配的情况。我们把这个信息表示为fail[i]失配数组,具体而言,失配数组表示如果当前位置失配了,应该将P的哪一位移到这里来。
我们会发现,失配需要移动的时候实际上是从当前位置之前的子串中找到一个最长的后缀,使得其与某一前缀相等。这个过程可以用自我匹配来完成。

inline void calfail() {
    int i = 0, j = -1;
    fail[0] = -1;
    for(; pattern[i]; i++, j++) {
        while(j >= 0 && pattern[j] != pattern[i]) {
            j = fail[j];
        }
        fail[i + 1] = j + 1;
    }
}

从当前位置开始往前的某个后缀是原串的某个前缀,那么后一位如果失配应该移至这个前缀的后一位。
举个例子,如果P=”ABABC”,fail数组应该看起来像这样

P=    A B A B C
fail=-1 0 0 1 2

当我们发现没有这样的后缀时,我们会到达P的头,得到fail[0]=-1这个值,这意味着我们需要将P整体后移了。现在我们匹配的思路也就明确了,即失配→用失配数组中的上一个位置对齐继续匹配,直到匹配完成。
下面就是一个匹配的实现。

inline int match() {
    calfail();
    int i = 0, j = 0, m = strlen(pattern);
    for(; text[i]; i++, j++) {
        while(j >= 0 && pattern[j] != text[i]) {
            j = fail[j];
        }
        if(j >= m - 1) {
            return i - m + 1;
        }
    }
}

现在的算法就是O(|P|+|T|)的了。

KMP算法

我们来观察一组MP的过程。现在T=”ABCABCABABC”,P=”ABABC”。
第一步:P_0 \rightarrow T_0

ABCABCABABC
||X
ABABC

第二步:P_0 \rightarrow T_2

ABCABCABABC
  X
  ABABC

第三步:P_0 \rightarrow T_3

ABCABCABABC
   ||X
   ABABC

第四步:P_0 \rightarrow T_5

ABCABCABABC
     X
     ABABC

第四步:P_0 \rightarrow T_6

ABCABCABABC
      |||||
      ABABC

我们发现第二步、第四步的时候,我们老是在C这个位置失配,每次失配还要尝试2次,既然我们都知道了C和现在与fail指定的位置的A配不上,为什么不想办法把这段给跳过去呢?
接下来我们会更改fail的求法,使得中间重复的A被跳过去,实际上更改的方法也很简单。

inline void calfail() {
    int i = 0, j = -1;
    fail[0] = -1;
    for(; pattern[i]; i++, j++) {
        while(j >= 0 && pattern[j] != pattern[i]) {
            j = fail[j];
        }
        fail[i + 1] = pattern[j + 1] == pattern[i + 1] ? fail[j + 1] : j + 1;
        // 如果遇到了相同的字符,再往前跳一次
    }
}

匹配的方法与上面MP的一样,只是fail有一点小区别而已。这个优化并没有影响整体复杂度,只是一个常数优化。
注意我们求出来的fail数组实际上比P串长一位,最后一位的值并没有用上。

代码

MP算法

本代码可以通过【P3375】【模板】KMP字符串匹配 – 洛谷,一题。

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

const int MAXN = 1000005;

int fail[MAXN];
char str1[MAXN], str2[MAXN];

inline void calfail() {
    int i = 0, j = -1;
    fail[0] = -1;
    for(; str2[i]; i++, j++) {
        while(j >= 0 && str2[j] != str2[i]) {
            j = fail[j];
        }
        fail[i + 1] = j + 1;
    }
}

inline void match() {
    calfail();
    int i = 0, j = 0, m = strlen(str2);
    for(; str1[i]; i++, j++) {
        while(j >= 0 && str2[j] != str1[i]) {
            j = fail[j];
        }
        if(j >= m - 1) {
            printf("%d\n", i - m + 2);
        }
    }
}

int main() {
    scanf("%s%s", str1, str2);
    match();
    for(int i = 1; str2[i - 1]; i++) {
        printf("%d ", fail[i]);
    }
    return 0;
}

KMP算法

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

const int MAXN = 1000005;

int fail[MAXN];
char str1[MAXN], str2[MAXN];

inline void calfail() {
    int i = 0, j = -1;
    fail[0] = -1;
    for(; str2[i]; i++, j++) {
        while(j >= 0 && str2[j] != str2[i]) {
            j = fail[j];
        }
        fail[i + 1] = str2[j + 1] == str2[i + 1] ? fail[j + 1] : j + 1;
    }
}

inline void match() {
    calfail();
    int i = 0, j = 0, m = strlen(str2);
    for(; str1[i]; i++, j++) {
        while(j >= 0 && str2[j] != str1[i]) {
            j = fail[j];
        }
        if(j >= m - 1) {
            printf("%d\n", i - m + 2);
        }
    }
}

int main() {
    scanf("%s%s", str1, str2);
    match();
    for(int i = 0; str2[i]; i++) {
        printf("%d ", fail[i]);
    }
    return 0;
}

参考资料