AC自动机原理及实现

AC自动机原理及实现

概述

AC自动机算法是一种常见的多串匹配算法。理解本算法需要先理解当模式串只有一个的时候的情况,即KMP算法原理与实现 | KSkun’s Blog。下面将默认你学会了KMP算法,介绍AC自动机算法。

结构与fail指针

AC自动机实际上要在一棵trie树上工作。因此我们得先把trie树建出来。
有了trie树,我们应用类似KMP的思路,计算出fail数组的替代品:fail指针。它的定义与fail数组很相似,是找到一个最长的前缀使得它与当前匹配的后缀相同。如果我们要求一个点的fail指针,其祖先的fail都已知,我们要沿着它父亲的fail指针找到一个有当前字母儿子的节点,并跳转至该儿子。这个上跳看上去很烦人,我们考虑把上跳的过程省去,即建虚拟边把跳到最后的目标给加在儿子上。由于我们需要按层计算fail,我们要利用BFS序来计算,写出来看上去是这样的。

inline void calfail() {
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i]; // 加了这一行就不用跳跳跳啦
            }
        }
    }
}

匹配

基本和KMP没啥区别,失配→跳fail指针→继续匹配,注意不能匹配成功一次就不继续,应该沿fail指针继续上跳计算所有单词。

inline void query(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t; t = fail[t]) {
            // 需要统计什么的在这干
        }
    }
}

例题1:【P3808】【模板】AC自动机(简单版) – 洛谷

题意

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

题解

找到一个后要设为已访问过,否则会算重。

代码

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

#include <queue>

const int MAXN = 1000005;

int ch[MAXN][26], val[MAXN], fail[MAXN], cnt;
std::queue<int> que;

inline void insert(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        int t = str[i] - 'a';
        if(!ch[p][t]) ch[p][t] = ++cnt;
        p = ch[p][t];
    }
    val[p]++;
}

inline void calfail() {
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i];
            }
        }
    }
}

inline int query(char *str) {
    int len = strlen(str), p = 0, res = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t && ~val[t]; t = fail[t]) {
            res += val[t];
            val[t] = -1;
        }
    }
    return res;
}

int n;
char str[MAXN];

int main() {
    scanf("%d", &n);
    for(int i = 0; i < n; i++) {
        scanf("%s", str);
        insert(str);
    }
    calfail();
    scanf("%s", str);
    printf("%d", query(str));
    return 0;
}

例题2:【P3796】【模板】AC自动机(加强版) – 洛谷

题意

有N个由小写字母组成的模式串以及一个文本串T。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串T中出现的次数最多。

题解

每找到一次在单词的计数器上加一,最后找最大值即可。

代码

// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>

#include <queue>

const int MAXN = 20005;

int ch[MAXN][26], val[MAXN], fail[MAXN], ans[MAXN], cnt;
std::queue<int> que;

inline void insert(char *str, int num) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        int t = str[i] - 'a';
        if(!ch[p][t]) ch[p][t] = ++cnt;
        p = ch[p][t];
    }
    val[p] = num;
}

inline void calfail() {
    memset(fail, 0, sizeof(fail));
    for(int i = 0; i < 26; i++) {
        if(ch[0][i]) {
            fail[ch[0][i]] = 0;
            que.push(ch[0][i]);
        }
    }
    while(!que.empty()) {
        int u = que.front(); que.pop();
        for(int i = 0; i < 26; i++) {
            if(ch[u][i]) {
                fail[ch[u][i]] = ch[fail[u]][i];
                que.push(ch[u][i]);
            } else {
                ch[u][i] = ch[fail[u]][i];
            }
        }
    }
}

inline void query(char *str) {
    int len = strlen(str), p = 0;
    for(int i = 0; i < len; i++) {
        p = ch[p][str[i] - 'a'];
        for(int t = p; t; t = fail[t]) {
            if(val[t]) ans[val[t]]++;
        }
    }
}

int n;
char pat[155][75], str[1000005];

int main() {
    for(;;) {
        scanf("%d", &n);
        if(n == 0) break;
        memset(ch, 0, sizeof(ch));
        memset(val, 0, sizeof(val));
        memset(ans, 0, sizeof(ans));
        cnt = 0;
        for(int i = 1; i <= n; i++) {
            scanf("%s", pat[i]);
            insert(pat[i], i);
        }
        calfail();
        scanf("%s", str);
        query(str);
        int mx = 0;
        for(int i = 1; i <= n; i++) {
            mx = std::max(mx, ans[i]);
        }
        printf("%d\n", mx);
        for(int i = 1; i <= n; i++) {
            if(ans[i] == mx) printf("%s\n", pat[i]);
        }
    }
    return 0;
}

参考资料



发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply.

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据