RT

背景

目前 TensorFlow 的主流应用模式是:使用 Python 训练模型,再使用 C++ 或 Java 加载并运行训练好的模型。上篇博客介绍了如何在工程中使用 TensorFlow 动态库,本篇介绍如何在工程中使用 TensorFlow 静态库。

编译静态链接库

clone tensorflow git 仓库

# Clone the TensorFlow source repository and enter it.
# NOTE(review): tensorflow/contrib (including contrib/makefile used below) only
# exists on TensorFlow 1.x branches; check out a matching r1.x branch before
# building — confirm which release this guide was written against.
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow

进入 tensorflow 仓库中的 tensorflow/contrib/makefile 目录

# Run from the repository root: enter the makefile-based build directory.
cd tensorflow/contrib/makefile

运行编译脚本

# This script is used on both macOS and Linux
./build_all_linux.sh

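脚本执行完成后,我们得到 TensorFlow 静态库以及相应的头文件。
下一步,我们将 TensorFlow 静态库、头文件以及 TensorFlow 所依赖的第三方库和头文件整理到统一路径下,这样其他 C++ 工程就可以直接使用这些库文件及头文件。注意:下面的命令使用的都是相对于 tensorflow 仓库根目录的路径,执行前需先从 tensorflow/contrib/makefile 回到仓库根目录(例如执行 cd ../../..)。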

# Run from the root of the tensorflow repository: every source path below is
# relative to it.
# Collect the TensorFlow headers and libraries under this directory.
mkdir -p ~/tensorflow_libs/tensorflow
# include/tensorflow holds the main TensorFlow headers
mkdir -p ~/tensorflow_libs/tensorflow/include/tensorflow
# lib holds the compiled static library
mkdir -p ~/tensorflow_libs/tensorflow/lib
# Copy the main TensorFlow headers into include/tensorflow/, so that
# #include "tensorflow/core/..." resolves when include/ is on the search path.
cp -r tensorflow/core ~/tensorflow_libs/tensorflow/include/tensorflow
# Copy the TensorFlow static library
cp tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a ~/tensorflow_libs/tensorflow/lib
# Copy TensorFlow's third_party headers
mkdir -p ~/tensorflow_libs/tensorflow/tensorflow_third_party
cp -r third_party ~/tensorflow_libs/tensorflow/tensorflow_third_party
# Copy the generated files from gen/
cp -r tensorflow/contrib/makefile/gen/host_obj ~/tensorflow_libs/tensorflow/
cp -r tensorflow/contrib/makefile/gen/proto ~/tensorflow_libs/tensorflow/
cp -r tensorflow/contrib/makefile/gen/protobuf ~/tensorflow_libs/tensorflow/
cp -r tensorflow/contrib/makefile/gen/proto_text ~/tensorflow_libs/tensorflow/

# Copy the downloaded dependencies
# eigen3
mkdir -p ~/tensorflow_libs/tensorflow/eigen3
cp -r tensorflow/contrib/makefile/downloads/eigen/Eigen ~/tensorflow_libs/tensorflow/eigen3
cp -r tensorflow/contrib/makefile/downloads/eigen/unsupported ~/tensorflow_libs/tensorflow/eigen3
# absl
cp -r tensorflow/contrib/makefile/downloads/absl ~/tensorflow_libs/tensorflow/
# nsync — the directory must be spelled "nsync" to match the
# tensorflow/nsync/include and tensorflow/nsync/lib paths in CMakeLists.txt.
mkdir -p ~/tensorflow_libs/tensorflow/nsync/include
mkdir -p ~/tensorflow_libs/tensorflow/nsync/lib
# Copy the nsync headers. The public headers live in nsync's public/
# directory, so copy that directory's contents directly into include/.
cp -r tensorflow/contrib/makefile/downloads/nsync/public/* ~/tensorflow_libs/tensorflow/nsync/include
# Copy the nsync library. The build directory name depends on the host
# platform (e.g. default.macos.c++11 or default.linux.c++11), so match it with
# a glob. It is always installed as libnsync.a, the name CMakeLists.txt links.
# NOTE(review): on some platforms the archive may be named nsync.a rather than
# libnsync.a; the *nsync.a pattern accepts either — confirm on your host.
cp tensorflow/contrib/makefile/downloads/nsync/builds/default.*.c++11/*nsync.a ~/tensorflow_libs/tensorflow/nsync/lib/libnsync.a

应用静态链接库

示例python model

我们使用一个简单的model作为示例,该model实现了两个字符串拼接的功能。模型保存到test_model.pb。代码如下

#-*- coding: utf-8 -*-
"""
File Name: test_model.py
Author: ce39906
mail: ce39906@163.com
Created Time: 2018-09-11 17:15:32

Build a tiny graph that concatenates two string variables ("a" + "b" ->
"result"), print the result, and write the graph definition to
./test_model.pb in binary (non-text) form.

Only sess.graph_def is written, so the file holds the graph structure, not
trained variable values; the C++ loader feeds "a" and "b" itself.
Uses the TensorFlow 1.x API (tf.Session). NOTE(review): under TF 2.x this
would need tf.compat.v1 with eager execution disabled.
"""
from __future__ import print_function

import tensorflow as tf

a = tf.Variable("hello ", name = "a")
b = tf.Variable("tensorflow", name = "b")
result = tf.add(a, b, name = "result")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.write_graph(sess.graph_def, '.', 'test_model.pb', as_text = False)
    # print() is valid on both Python 2 and 3; the bare print statement was 2-only.
    print(result.eval())

c++工程加载model

首先将整理好的 TensorFlow 头文件及静态库文件全部拷贝到 C++ 工程目录下。本例中,这些头文件和库文件存放在工程目录下的 tensorflow 子目录中,下文 CMakeLists.txt 中的路径也以此为准。

# Copy the collected TensorFlow headers and libraries into the C++ project directory
cp -r ~/tensorflow_libs/tensorflow .

示例c++ 代码如下

/*************************************************************************
    > File Name: load_model.cpp
    > Author: ce39906
    > Mail: ce39906@163.com
    > Created Time: 2018-09-08 08:28:51
 ************************************************************************/
// Loads the binary GraphDef stored in test_model.pb into a TensorFlow session,
// feeds scalar string tensors to "a" and "b", fetches "result", and prints it.
//
// kModelPath is relative, so the program must be run from the directory that
// contains test_model.pb. Returns 0 on success and -1 on any failure.
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

const static std::string kModelPath = "test_model.pb";

int main()
{
    using namespace tensorflow;

    // NewSession hands back an owning raw pointer; keep it in a unique_ptr so
    // the session is released on every return path, including the error ones.
    std::unique_ptr<Session> session(NewSession(SessionOptions()));
    if (!session)
    {
        std::cerr << "Tensorflow session create failed.\n";
        return -1;
    }
    std::cout << "Tensorflow session create success.\n";

    // Read in the protobuf graph we exported
    GraphDef graph_def;
    Status status = ReadBinaryProto(Env::Default(), kModelPath, &graph_def);
    if (!status.ok())
    {
        std::cerr << "Error reading graph definition from " << kModelPath
            << ": " << status.ToString() << "\n";
        return -1;
    }
    std::cout << "Read graph def success.\n";

    // Add the graph to the session
    status = session->Create(graph_def);
    if (!status.ok())
    {
        std::cerr << "Error creating graph: " << status.ToString() << "\n";
        return -1;
    }
    std::cout << "Create graph success.\n";

    // Model inputs: two scalar string tensors. The local names avoid
    // shadowing the `tensorflow` namespace.
    Tensor hello_tensor(DT_STRING, TensorShape());
    hello_tensor.scalar<string>()() = "hello";

    Tensor suffix_tensor(DT_STRING, TensorShape());
    suffix_tensor.scalar<string>()() = " tensorflow";

    // Feed the tensors to the graph nodes "a" and "b".
    const std::vector<std::pair<string, Tensor>> inputs =
    {
        { "a", hello_tensor },
        { "b", suffix_tensor },
    };

    // Fetch "result"; no additional target nodes are run.
    std::vector<Tensor> outputs;
    status = session->Run(inputs, {"result"}, {}, &outputs);
    if (!status.ok())
    {
        std::cerr << status.ToString() << std::endl;
        return -1;
    }
    std::cout << "Run session successfully" << std::endl;

    // Guard the index below rather than relying on Run's output contract.
    if (outputs.empty())
    {
        std::cerr << "Session returned no output for \"result\"" << std::endl;
        return -1;
    }
    const auto result = outputs[0].scalar<string>()();
    std::cout << "Result value: " << result << std::endl;

    status = session->Close();
    if (!status.ok())
    {
        std::cerr << "Session close failed: " << status.ToString() << std::endl;
        return -1;
    }

    return 0;
}

CMakeLists.txt 如下

# Builds load_model.cpp against the static TensorFlow library produced by
# tensorflow/contrib/makefile. All headers and archives are expected in the
# "tensorflow" directory next to this CMakeLists.txt.
#
# CMake >= 4.0 rejects cmake_minimum_required() versions below 3.5, and every
# top-level project needs a project() call, so both are declared explicitly.
cmake_minimum_required(VERSION 3.5...3.29)
project(load_model LANGUAGES CXX)

set(load_model_tf_root "${PROJECT_SOURCE_DIR}/tensorflow")

# --- TensorFlow core -------------------------------------------------------
set(TENSORFLOW_INCLUDE_PATH "${load_model_tf_root}/include")
set(TENSORFLOW_LIBRARY "${load_model_tf_root}/lib/libtensorflow-core.a")

# --- protobuf built by the TensorFlow makefile -----------------------------
set(TENSORFLOW_PROTOBUF_INCLUDE_PATH "${load_model_tf_root}/protobuf/include")
set(TENSORFLOW_PROTOBUF_LIBRARY_PATH "${load_model_tf_root}/protobuf/lib")
set(TENSORFLOW_PROTOBUF_LIBRARY "${TENSORFLOW_PROTOBUF_LIBRARY_PATH}/libprotobuf.a")
set(TENSORFLOW_PROTOBUF_LITE_LIBRARY "${TENSORFLOW_PROTOBUF_LIBRARY_PATH}/libprotobuf-lite.a")
# libprotoc is the protobuf compiler library. load_model never invokes the
# compiler, so the path is kept for reference but is not linked.
set(TENSORFLOW_PROTOC_LIBRARY "${TENSORFLOW_PROTOBUF_LIBRARY_PATH}/libprotoc.a")

# --- nsync -----------------------------------------------------------------
set(TENSORFLOW_NSYNC_INCLUDE_PATH "${load_model_tf_root}/nsync/include")
set(TENSORFLOW_NSYNC_LIBRARY_PATH "${load_model_tf_root}/nsync/lib")
set(TENSORFLOW_NSYNC_LIBRARY "${TENSORFLOW_NSYNC_LIBRARY_PATH}/libnsync.a")

# --- generated and header-only dependencies --------------------------------
set(TENSORFLOW_PROTO_INCLUDE_PATH "${load_model_tf_root}/proto")
set(TENSORFLOW_PROTO_TEXT_INCLUDE_PATH "${load_model_tf_root}/proto_text")
set(TENSORFLOW_HOST_OBJ_INCLUDE_PATH "${load_model_tf_root}/host_obj")
set(TENSORFLOW_EIGEN_INCLUDE_PATH "${load_model_tf_root}/eigen3")
set(TENSORFLOW_ABSL_INCLUDE_PATH "${load_model_tf_root}/absl")
set(TENSORFLOW_THIRD_PARTY_INCLUDE_PATH "${load_model_tf_root}/tensorflow_third_party")

foreach(load_model_var
    TENSORFLOW_INCLUDE_PATH
    TENSORFLOW_LIBRARY
    TENSORFLOW_PROTOBUF_INCLUDE_PATH
    TENSORFLOW_PROTOBUF_LIBRARY_PATH
    TENSORFLOW_NSYNC_INCLUDE_PATH
    TENSORFLOW_NSYNC_LIBRARY_PATH
    TENSORFLOW_PROTO_INCLUDE_PATH
    TENSORFLOW_PROTO_TEXT_INCLUDE_PATH
    TENSORFLOW_HOST_OBJ_INCLUDE_PATH
    TENSORFLOW_EIGEN_INCLUDE_PATH
    TENSORFLOW_ABSL_INCLUDE_PATH
    TENSORFLOW_THIRD_PARTY_INCLUDE_PATH)
  message(STATUS "${load_model_var} ${${load_model_var}}")
endforeach()

# Fail at configure time with a readable message instead of a linker error
# when one of the collected archives is missing.
foreach(load_model_lib
    TENSORFLOW_LIBRARY
    TENSORFLOW_PROTOBUF_LIBRARY
    TENSORFLOW_NSYNC_LIBRARY)
  if(NOT EXISTS "${${load_model_lib}}")
    message(FATAL_ERROR "${load_model_lib} not found: ${${load_model_lib}}")
  endif()
endforeach()

add_executable(load_model load_model.cpp)

# Replaces the hand-written -std=c++11 / -fPIC flags.
set_target_properties(load_model PROPERTIES
  CXX_STANDARD 11
  CXX_STANDARD_REQUIRED ON
  CXX_EXTENSIONS OFF
  POSITION_INDEPENDENT_CODE ON)

# The bundled protobuf headers come first; all -I paths are searched before
# the compiler's system include directories.
target_include_directories(load_model PRIVATE
  "${TENSORFLOW_PROTOBUF_INCLUDE_PATH}"
  "${TENSORFLOW_INCLUDE_PATH}"
  "${TENSORFLOW_PROTO_INCLUDE_PATH}"
  "${TENSORFLOW_PROTO_TEXT_INCLUDE_PATH}"
  "${TENSORFLOW_HOST_OBJ_INCLUDE_PATH}"
  "${TENSORFLOW_EIGEN_INCLUDE_PATH}"
  "${TENSORFLOW_ABSL_INCLUDE_PATH}"
  "${TENSORFLOW_NSYNC_INCLUDE_PATH}"
  "${TENSORFLOW_THIRD_PARTY_INCLUDE_PATH}")

target_compile_options(load_model PRIVATE -O3 -Wall -finline-functions)

# -msse4.1 is x86-only and -march=native ties the binary to the build machine,
# so both sit behind an option that defaults to ON only on x86 hosts.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|i[3-6]86)$")
  set(load_model_native_default ON)
else()
  set(load_model_native_default OFF)
endif()
option(LOAD_MODEL_NATIVE_ARCH
  "Compile load_model with -msse4.1 -march=native"
  ${load_model_native_default})
if(LOAD_MODEL_NATIVE_ARCH)
  target_compile_options(load_model PRIVATE -msse4.1 -march=native)
endif()

find_package(Threads REQUIRED)

# TensorFlow registers ops and kernels from static initializers, so every
# object in libtensorflow-core.a must be kept even though nothing references
# it directly. The original used -all_load, which is Apple-ld-only.
if(APPLE)
  set(load_model_tensorflow_link "-Wl,-force_load,${TENSORFLOW_LIBRARY}")
  # NOTE(review): kept from the original build. It defers unresolved symbols
  # to run time and can hide missing libraries at link time.
  set(load_model_platform_link_flags "-Wl,-undefined,dynamic_lookup")
else()
  set(load_model_tensorflow_link
    "-Wl,--whole-archive" "${TENSORFLOW_LIBRARY}" "-Wl,--no-whole-archive")
  set(load_model_platform_link_flags "")
endif()

message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")

# GNU ld resolves static archives left to right, so TensorFlow comes before
# the libraries it depends on (protobuf, nsync). Compile flags are no longer
# passed through the link line.
target_link_libraries(load_model PRIVATE
  ${load_model_tensorflow_link}
  "${TENSORFLOW_PROTOBUF_LIBRARY}"
  "${TENSORFLOW_NSYNC_LIBRARY}"
  Threads::Threads
  ${CMAKE_DL_LIBS}
  ${load_model_platform_link_flags})

编译

# Out-of-source build in ./build
mkdir -p build && cd build
cmake ..
make
# Copy the built binary up to the parent (project) directory
# (presumably so it runs next to test_model.pb, which it opens by relative path — verify)
cp load_model ..

执行

# load_model opens "test_model.pb" by relative path, so run it from the directory containing that file
./load_model

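输出如下(原文此处为截图;根据示例代码,运行成功时的输出应为):

Tensorflow session create success.
Read graph def success.
Create graph success.
Run session successfully
Result value: hello tensorflow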