Sugg 技术的原理

Trie tree 介绍

trie 一词源自 retrieval,中文称为前缀树或字典树。具体介绍见维基百科的 Trie 词条。

C++ 实现

以下 trie 实现按字节(0~255)建立子节点索引,因此支持任何语言(Chinese, English, Japanese...)的多字节编码文本。主要包括以下三个接口

// Initialize the trie with a set of words.
void Init(const std::vector<std::string>& dict);
// Check whether word exists in the trie.
bool Lookup(const std::string& word);
// Return all words in the trie that have word as a prefix.
std::vector<std::string> Suggest(const std::string& word);

具体代码实现如下
trie.hpp

/*************************************************************************
    > File Name: trie.hpp
    > Author: ce39906
    > Mail: ce39906@163.com
    > Created Time: 2018-07-19 11:12:09
 ************************************************************************/
#ifndef TRIE_HPP
#define TRIE_HPP

#include <vector>
#include <string>

namespace trie
{

/// Byte-indexed prefix tree (trie) over std::string keys.
///
/// Every node owns one child slot per possible byte value, so a word is
/// stored as a path of its individual bytes starting below the root.
class Trie
{
    // Number of child slots per node: one for each unsigned char value (0 ~ 255).
    static constexpr size_t kAsciiCount = 256;

    /// One node of the trie, standing for a single byte of a stored word.
    struct TrieNode
    {
        // explicit: a bare char must never convert into a node implicitly.
        explicit TrieNode(const char val)
          : val(val), is_end(false), childrens(kAsciiCount, nullptr)
        {
        }

        // Byte this node stands for.
        char val;
        // True when a stored word ends exactly at this node.
        bool is_end;
        // Child slots indexed by byte value; nullptr marks an absent child.
        std::vector<TrieNode*> childrens;
    };

  public:
    /// Creates an empty trie that consists of the root node only.
    Trie()
      : root(new TrieNode('0'))
    {
    }

    /// Hands the node tree to ReleaseTrie for cleanup.
    ~Trie()
    {
        ReleaseTrie(root);
    }

    // The trie manages its nodes through a raw root pointer, so copying is disabled.
    Trie(const Trie&) = delete;
    Trie& operator = (const Trie&) = delete;

    /// Inserts every word of `dict` into the trie.
    void Init(const std::vector<std::string>& dict);

    /// Returns true when `word` is stored in the trie as a complete word.
    bool Lookup(const std::string& word) const;

    /// Returns all stored words that start with `word` (empty for an empty `word`).
    std::vector<std::string> Suggest(const std::string& word) const;

    /// Prints the suggestions for `word` to std::cout.
    void PrintSuggs(const std::string& word) const;

  private:

    /// Adds a single word to the trie; empty words are ignored.
    void Insert(const std::string& word);

    /// Matches `word` from byte `idx` onwards against the children of `parent`.
    bool Search(const TrieNode* parent, const std::string& word, const size_t idx) const;

    /// Depth-first walk below `cur` that appends every complete word to `suggs`;
    /// `word` holds the bytes on the path above `cur`.
    void Dfs(const TrieNode* cur, std::string& word, std::vector<std::string>& suggs) const;

    /// Frees nodes of the subtree rooted at `root`.
    void ReleaseTrie(const TrieNode* root);

    // Root node; it carries no byte of any word, words start at its children.
    TrieNode* root;
};
} // ns trie
#endif

trie.cpp

/*************************************************************************
    > File Name: trie.cpp
    > Author: ce39906
    > Mail: ce39906@163.com
    > Created Time: 2018-07-19 14:04:15
 ************************************************************************/
#include "trie.hpp"
#include <iostream>

namespace trie
{

/// Populates the trie by inserting each dictionary entry in order.
///
/// @param dict words to store; every entry is passed to Insert.
void Trie::Init(const std::vector<std::string>& dict)
{
    for (const auto& entry : dict)
    {
        Insert(entry);
    }
}

/// Tells whether `word` is stored in the trie as a complete word.
///
/// @param word the word to look for.
/// @return the result of Search started at the root.
bool Trie::Lookup(const std::string& word) const
{
    // Matching always begins at the root with the first byte of the word.
    const size_t first_idx = 0;
    return Search(root, word, first_idx);
}

std::vector<std::string> Trie::Suggest(const std::string& word) const
{
    std::vector<std::string> suggs;
    if (word.empty())
    {
        return suggs;
    }
    // find prefix
    TrieNode* cur = root;
    // unsigned char range : 0 ~ 255
    for (const unsigned char c : word)
    {
        const std::vector<TrieNode*>& childrens = cur->childrens;
        if (!childrens[c])
        {
            return suggs;
        }
        cur = childrens[c];
    }

    if (!cur)
    {
        return suggs;
    }

    std::string prefix(word.begin(), word.end() - 1);
    Dfs(cur, prefix, suggs);
    return suggs;
}

void Trie::Insert(const std::string& word)
{
    if (word.empty())
        return;

    TrieNode* cur = root;
    // unsigned char range : 0 ~ 255
    for (const unsigned char c : word)
    {
        if (!cur->childrens[c])
        {
            cur->childrens[c] = new TrieNode(c);
        }
        cur = cur->childrens[c];
    }
    cur->is_end = true;
}

bool Trie::Search(const TrieNode* parent, const std::string& word, const size_t idx) const
{
    if (word.empty())
    {
        return false;
    }
    const std::vector<TrieNode*>& childrens = parent->childrens;
    // explicitly cast to unsigned char is needed
    const unsigned char c = word[idx];
    TrieNode* cur = childrens[c];
    if (!cur)
    {
        return false;
    }
    if (idx == word.size() - 1)
    {
        return cur->is_end;
    }
    return Search(cur, word, idx + 1);
}

void Trie::Dfs(const TrieNode* cur, std::string& word, std::vector<std::string>& suggs) const
{
    if (cur->is_end)
    {
        suggs.emplace_back(word + cur->val);
    }
    word.push_back(cur->val);
    const std::vector<TrieNode*>& childrens = cur->childrens;
    for (const TrieNode* children : childrens)
    {
        if (children)
        {
            Dfs(children, word, suggs);
        }
    }
    word.pop_back();
}

void Trie::ReleaseTrie(const TrieNode* root)
{
    if (!root) return;
    bool no_children = true;
    const std::vector<TrieNode*>& childrens = root->childrens;
    for (const TrieNode* children : childrens)
    {
        if (children)
        {
            no_children = false;
            ReleaseTrie(children);
        }
    }

    if (no_children)
    {
        delete root;
    }
}

void Trie::PrintSuggs(const std::string& word) const
{
    const auto quoted_string = [] (const std::string str)
    {
        return "\"" + str + "\"";
    };

    const std::vector<std::string>& suggs = Suggest(word);
    if (suggs.empty())
    {
        std::cout << "No suggs for " << quoted_string(word) << std::endl;
        return;
    }

    std::cout << "Suggs for " << quoted_string(word) << " are :\n";
    for (const std::string& sugg : suggs)
    {
        std::cout << quoted_string(sugg) << " ";
    }
    std::cout << std::endl;
}

} // ns trie

测试

测试数据

测试数据如下,本例中存储在文件trie_data

中国人民
中午
中国人
中国梦
伟大复兴
2020中国制造
中国制造2020
军工etf
北京
北京天安门
天气
天气预报
北京天气预报
beijing
beijing tiananmen
汉语
韩国人
韩国
韩范
美国热
东京热
苍井空
苍老师
机器学习
机器人
机器猫
机器狗
美团网
美团外卖
美团平台
美团酒旅
美团生鲜
美团大象
be搜搜

测试代码

test_trie.cpp

/*************************************************************************
    > File Name: test_trie.cpp
    > Author: ce39906
    > Mail: ce39906@163.com
    > Created Time: 2018-07-19 19:53:22
 ************************************************************************/
#include "trie.hpp"

#include <iostream>
#include <fstream>
#include <cstdlib>
#define NDEBUG
#include <cassert>

using namespace trie;

/// Prints a usage hint for the program to std::cout and terminates the
/// process with EXIT_FAILURE.
///
/// @param bin name of the executable, printed in front of the hint.
void usage(const char* bin)
{
    const char* const hint = " : Need a filename as a parameter.\n";
    std::cout << bin << hint;
    std::exit(EXIT_FAILURE);
}

/// Reads `file` line by line and appends every line to `vec`.
///
/// A trailing '\r' is stripped from each line so that files with CRLF line
/// endings yield the same words as files with LF line endings. Blank lines
/// are kept as empty strings.
///
/// @param file path of the text file to read.
/// @param vec  receives the lines in file order; existing content is kept.
///             When the file cannot be opened a message is written to
///             std::cerr and `vec` is left untouched.
void readFile2Vector(const std::string& file, std::vector<std::string>& vec)
{
    std::ifstream infile(file);
    if (!infile.is_open())
    {
        // Report the problem instead of silently producing an empty dictionary.
        std::cerr << "Cannot open file \"" << file << "\"\n";
        return;
    }

    std::string line;
    while (std::getline(infile, line))
    {
        if (!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        vec.push_back(line);
    }
}

/// Test driver: loads a dictionary file into a trie, verifies that every
/// loaded word can be looked up, then prints suggestions for a few prefixes.
///
/// @param argc number of command line arguments.
/// @param argv argv[1] is the path of the dictionary file, one word per line.
/// @return 0 on success, EXIT_FAILURE when a loaded word cannot be found.
int main(int argc, char* argv[])
{
    if (argc < 2)
    {
        // usage() terminates the process.
        usage(argv[0]);
    }

    const std::string data_file(argv[1]);
    std::vector<std::string> dict;
    readFile2Vector(data_file, dict);

    Trie trie;
    trie.Init(dict);

    // test trie lookup function
    // The check is done explicitly instead of with assert(): this file defines
    // NDEBUG before including <cassert>, which compiles every assert() away.
    for (const std::string& word : dict)
    {
        if (word.empty())
        {
            // Insert ignores empty words, so Lookup cannot find them.
            continue;
        }
        if (!trie.Lookup(word))
        {
            std::cerr << "Lookup failed for \"" << word << "\"\n";
            return EXIT_FAILURE;
        }
    }

    trie.PrintSuggs("美");
    trie.PrintSuggs("be");
    trie.PrintSuggs("中");
    trie.PrintSuggs("苍");
    trie.PrintSuggs("null");

    return 0;
}

编译

g++ --std=c++11 -O2 trie.cpp test_trie.cpp -o trie

测试结果

执行

./trie trie_data

pic

github 地址

https://github.com/ce39906/self-practices/tree/master/cppcode/trie