欢迎访问 生活随笔!

ag凯发k8国际

当前位置: ag凯发k8国际 > 人工智能 > pytorch >内容正文

pytorch

深度学习模型保存-ag凯发k8国际

发布时间:2024/10/14 pytorch 34 豆豆
ag凯发k8国际 收集整理的这篇文章主要介绍了 深度学习模型保存_web服务部署深度学习模型 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

本文的目的是介绍如何使用web服务快速部署深度学习模型,虽然tf有tfserving可以进行模型部署,但是对于pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦);因此,本文在这里介绍一种使用web服务部署深度学习的方法(简单有效,不喜勿喷)。

本文以简单的新闻分类模型来举例,模型:bert;数据来源:清华新闻语料(地址:

thuctc: 一个高效的中文文本分类工具),清华新闻语料共有14个类别,分别是体育,娱乐,家居,彩票,房产,教育,时尚,时政,星座,游戏,社会,科技,股票和财经。为了快速训练模型,本人在每个类别中分别随机挑选1000个作为训练集,200个作为验证集。数据预处理、模型训练和pb模型保存代码见:新闻分类模型训练github地址。(非重点,不过多介绍了,github上有详细的使用说明,有问题可留言。)

为了使web服务部署变得简洁,因此本人构造一个方法类,方便加载pb模型,对传入文本进行数据预处理以及进行模型预测。

模型初始化代码如下:

import bert_tokenization import tensorflow as tf from tensorflow.python.platform import gfile import numpy as np import osclass classificationmodel(object):def __init__(self):self.tokenizer = noneself.sess = noneself.is_train = noneself.input_ids = noneself.input_mask = noneself.segment_ids = noneself.predictions = noneself.max_seq_length = noneself.label_dict = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']

其中,tokenizer 为分词器;sess为tf的session模块;is_train、input_ids、input_mask和segment_ids分别是pb模型的输入;predictions为pb模型的输出;max_seq_length为模型的最大输入长度;label_dict为新闻分类标签。

加载pb模型代码如下:

def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):os.environ['cuda_device_order'] = 'pci_bus_id'os.environ['cuda_visible_devices'] = gpu_idself.tokenizer = bert_tokenization.fulltokenizer(vocab_file=vocab_file, do_lower_case=true)gpu_options = tf.gpuoptions(per_process_gpu_memory_fraction=gpu_memory_fraction)sess_config = tf.configproto(gpu_options=gpu_options)self.sess = tf.session(config=sess_config)with gfile.fastgfile(model_path, "rb") as f:graph_def = tf.graphdef()graph_def.parsefromstring(f.read())self.sess.graph.as_default()tf.import_graph_def(graph_def, name="")self.sess.run(tf.global_variables_initializer())self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")self.max_seq_length = max_seq_length

其中,gpu_id为使用gpu的序号;vocab_file为bert模型所使用的字典路径;gpu_memory_fraction为使用gpu时所占用的比例;model_path为pb模型的路径;max_seq_length为bert模型的最大长度。

将传入文本转化成模型所需格式代码如下:

def convert_fearture(self, text):max_seq_length = self.max_seq_lengthmax_length_context = max_seq_length - 2content_token = self.tokenizer.tokenize(text)if len(content_token) > max_length_context:content_token = content_token[:max_length_context]tokens = []segment_ids = []tokens.append("[cls]")segment_ids.append(0)for token in content_token:tokens.append(token)segment_ids.append(0)tokens.append("[sep]")segment_ids.append(0)input_ids = self.tokenizer.convert_tokens_to_ids(tokens)input_mask = [1] * len(input_ids)while len(input_ids) < max_seq_length:input_ids.append(0)input_mask.append(0)segment_ids.append(0)assert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthinput_ids = np.array(input_ids)input_mask = np.array(input_mask)segment_ids = np.array(segment_ids)return input_ids, input_mask, segment_ids

预测代码如下:

def predict(self, text):input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)feed = {self.is_train: false,self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}[label] = self.sess.run([self.predictions], feed)label_name = self.label_dict[label[0]]return label[0], label_name

其中,输入是一个新闻文本,输出为类别序号以及对应的标签名称。详细完整代码见github:

classificationmodel.py文件。


(划重点)上面介绍的都是如何方便简洁地加载模型,下面开始使用web服务挂起模型。通俗地讲,其实本人就是通过flask框架,搭建了一个web服务,来获取外部的输入;并且使用挂载的模型进行预测;最后将预测结果通过web服务传出。

from gevent import monkey monkey.patch_all() from flask import flask, request from gevent import wsgi import json from classificationmodel import classificationmodeldef start_sever(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):model = classificationmodel()model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)print("load model ending!")app = flask(__name__)@app.route('/')def index():return "this is news classification model server"@app.route('/news-classification', methods=['get', 'post'])def response_request():if request.method == 'post':text = request.form.get('text')else:text = request.args.get('text')label, label_name = model.predict(text)d = {"label": str(label), "label_name": label_name}print(d)return json.dumps(d, ensure_ascii=false)server = wsgi.wsgiserver((str(http_id), port), app)server.serve_forever()

其中,http_id为web服务的地址;port为端口号;gpu_id、vocab_file、gpu_memory_fraction、model_path和max_seq_length为上面介绍的加载模型所需要的参数,详细见上文。

index函数用于检验web服务是否畅通。如图1所示。

图1

response_request函数为响应函数。定义了两种请求数据的方式,get和post。当使用get方法获取web输入时,获取命令为request.args.get('text');当使用post方法获取web输入时,获取命令为request.form.get('text')。

当web服务起起来之后,就可以调用啦!!!

浏览器调用如图2所示。

图2

code调用如下:

import requestsdef http_test(text):url = 'http://127.0.0.1:5555/news-classification'raw_data = {'text': text}res = requests.post(url, raw_data)result = res.json()return resultif __name__ == "__main__":text = "姚明在nba打球,很强。"result = http_test(text)print(result["label_name"])

以上就是通过web服务部署深度学习模型的全部内容,喜欢的同学还请多多点赞~~~~~


推荐几篇本人之前写的一些文章:

刘聪nlp:短文本相似度算法研究

刘聪nlp:阅读笔记:开放域检索问答(orqa)

刘聪nlp:论文阅读笔记:文本蕴含之bimpm

喜欢的同学,可以关注一下专栏,关注一下作者,还请多多点赞~~~~~~

总结

以上是ag凯发k8国际为你收集整理的深度学习模型保存_web服务部署深度学习模型的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得ag凯发k8国际网站内容还不错,欢迎将ag凯发k8国际推荐给好友。

  • 上一篇:
  • 下一篇:
网站地图