背景
最近在工作中遇到了TensorFlow,需要用go调用处理好的模型。因为在上家公司也有用到过TensorFlow,当时用来根据天气情况预测光伏发电的趋势,整体来说还是不错的,不过版本更新太快了,都已经更新到2.x版本了,本来以为没啥问题,结果还是想的太简单了。在这记录一下使用的过程和遇到的一些坑吧。
TensorFlow
简单介绍一下Tensorflow吧,TensorFlow 是由 Google Brain 团队开发的一个开源机器学习框架。它可用于构建各种人工智能应用程序,包括图像和语音识别、自然语言处理、推荐系统等等。TensorFlow 支持多种编程语言,包括 Python、C++、Java、Go 等等,并且可以在多种计算平台上运行,包括 CPU、GPU、TPU 等等。
TensorFlow 的核心是一个数据流图计算引擎,它使用数据流图模型来表示计算任务,其中节点表示操作,边表示数据流。TensorFlow 提供了丰富的高级接口和工具来简化机器学习任务的开发过程,包括 Keras、Estimator、TensorBoard 等等。同时,它也允许用户使用低级接口和自定义操作来实现更高级别的功能和更灵活的控制。
由于 TensorFlow 具有开源、跨平台、高度灵活和强大的功能等优点,因此它已成为最受欢迎的机器学习框架之一,被广泛应用于学术研究和商业开发领域。感兴趣的同学可以去官网学习一下。
使用
在这里我们只介绍golang对于tensorflwo的使用。通过官网我们了解到,对于TensorFlow2.x版本已经不支持go的api了,对于go的支持最高到1.15.5,并且在官方文档安装Go版TensorFlow这一节已经提示我们:
注意:TensorFlow Go API 不在 TensorFlow API 稳定性保障的涵盖范围内。
这就很头疼了,不过也不能官方说啥就是啥吧,硬着头皮也得上。
使用go的api需要先安装tensorflow的C库,具体安装请参考官方文档。当我们一切准备就绪的时候,使用官方提供的库进行测试:
1go get github.com/tensorflow/tensorflow/tensorflow/go
2
3go test github.com/tensorflow/tensorflow/tensorflow/go
得到了这样的结果
搜了一下确实是因为缺少对2.x版本的支持导致的。官方🐮🍺,诚不欺我。
不过工作还得继续,在pkg.go.dev
上找到了一个库:
1go get github.com/galeone/tensorflow/tensorflow/go
使用起来跟官方很像,我们的需求是加载已经保存好的模型并进行预测得到结果。
1model, err := tf.LoadSavedModel("./model", []string{"serve"}, nil)
2if err != nil {
3 panic(err)
4}
5
6inputTensor, _ := tf.NewTensor(black)
7_, err = model.Session.Run(
8 map[tf.Output]*tf.Tensor{
9 model.Graph.Operation("serving_default_input_1").Output(0): inputTensor,
10 },
11 []tf.Output{model.Graph.Operation("StatefulPartitionedCall").Output(0)},
12 nil)
13if err != nil {
14 panic(err)
15}
首先我们调用LoadSavedModel
方法加载模型,然后直接调用Run方法即可。需要注意⚠️的是,我们要填写正确模型的输入和输出的名称,如果跟保存的模型对不上的话会报错。如果你不知道模型名称,可以运行下面的命令:
1saved_model_cli show --dir modelpath --tag_set serve --signature_def serving_default
因为tensorflow需要一些性能和环境,所以我直接将代码放到了自己租的一台服务器上,运行竟然成功了。
从运行结果看,24万的数据从模型加载到预测用了35s左右,租用的服务器一共12个核全部占满了,如果不加限制的话,在真实的线上环境很可能会导致我们其他功能不正常,所以需要对cpu进行限制。
通过询问ChatGPT我得到了如下结果:
实践的过程发现Config
类型根本不匹配,翻阅了源码发现关于运行的配置在https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto这个文件内
在跟进这个库的源码发现在protobuf目录下存在config.pb.go文件,文件中有个ConfigProto
结构体,正好能够跟上面网站上对应上,因此大胆猜测这个结构就是控制运行参数的。
我们将这个包导入,并修改模型的加载:
1import tp "github.com/galeone/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
1var cp tp.ConfigProto
2cp.InterOpParallelismThreads = 2
3cp.IntraOpParallelismThreads = 2
4cp.UsePerSessionThreads = true
5bytes, _ := proto.Marshal(&cp)
6model, err := tf.LoadSavedModel("./model", []string{"serve"}, &tf.SessionOptions{
7 Config: bytes,
8})
9if err != nil {
10 panic(err)
11}
运行再次查看cpu的使用,确实成功降到了2个核。
总结
-
TensorFlow Go Api官方只支持2.0版本以下
-
推荐使用github.com/galeone/tensorflow/tensorflow/go库
-
限制cpu占用可以使用github.com/galeone/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto来配置
注: TensorFlow 版本需要的 glibc 版本如下:
TensorFlow GLIBC >= 2.0 >= 2.3 >= 1.5 >= 2.17 <= 1.4 >= 2.12 如果版本不匹配可能会报:
/lib64/libc.so.6: version `GLIBC_2.34’ not found