Python: OCR recognition model (MNN model) inference (TrOCR, multiple inputs)

This walkthrough runs a TrOCR text-recognition model exported to MNN using the MNN Python API: the encoder takes the preprocessed image and produces hidden states, and the decoder is then called with multiple inputs (attention_mask, encoder_hidden_states, input_ids) in a greedy loop until the end-of-sequence token is produced.

import MNN.nn as nn
import MNN.cv as cv
import MNN.numpy as np
import MNN.expr as expr
import cv2

# Configure the execution backend, thread count, precision, etc.
# See the MNN API docs for the available key-value options.
config = {}
config['precision'] = 'low'  # use fp16 inference when the hardware supports it (armv8.2)
config['backend'] = 0        # CPU
config['numThread'] = 4      # number of threads

def decode_text(tokens, vocab, vocab_inp):
    # Decode TrOCR token ids into text, skipping the special tokens.
    s_start = vocab.get('<s>')
    s_end = vocab.get('</s>')
    unk = vocab.get('<unk>')
    pad = vocab.get('<pad>')
    text = ''
    for tk in tokens:
        if tk == s_end:
            break
        if tk not in [s_end, s_start, pad, unk]:
            text += vocab_inp[tk]
    return text
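
A quick sanity check of decode_text with a toy vocabulary (the mapping below is made up purely for illustration; the real one comes from the model's vocab file):

toy_vocab = {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, 'A': 4, 'B': 5}
toy_inp = {v: k for k, v in toy_vocab.items()}
print(decode_text([0, 4, 5, 2], toy_vocab, toy_inp))  # -> 'AB' (start token skipped, stops at '</s>')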

def do_norm(x):
    # Normalize to [-1, 1]: scale to [0, 1], then subtract the mean and divide by the std per channel.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    x = x / 255.0
    x[:, :, 0] -= mean[0]
    x[:, :, 1] -= mean[1]
    x[:, :, 2] -= mean[2]
    x[:, :, 0] /= std[0]
    x[:, :, 1] /= std[1]
    x[:, :, 2] /= std[2]
    return x
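
Since mean and std are both 0.5 for every channel, the per-channel steps above collapse to a single expression; a minimal equivalent sketch:

def do_norm_fast(x):
    # (x / 255 - 0.5) / 0.5 == x / 127.5 - 1.0, i.e. map [0, 255] to [-1, 1]
    return x / 127.5 - 1.0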

##### Inspect the model's input/output node names #####
# m = expr.load_as_dict(mnn_model_path)
# inputs_outputs = expr.get_inputs_and_outputs(m)
# for key in inputs_outputs[0].keys():
#     print('input names:\t', key)
# for key in inputs_outputs[1].keys():
#     print('output names:\t', key)

rt = nn.create_runtime_manager((config,))

# Load the models and create the _Module instances
encoder_model_path = './trocr-chinese-main/onnx/general_en/encoder_model.mnn'
decoder_model_path = './trocr-chinese-main/onnx/general_en/decoder_model.mnn'
image_path = './test/20240808141905.png'
print('image path:', image_path)

encoder_net = nn.load_module_from_file(encoder_model_path, ['pixel_values'], ['last_hidden_state'], runtime_manager=rt)
decoder_net = nn.load_module_from_file(decoder_model_path, ['attention_mask', 'encoder_hidden_states', 'input_ids'], ['logits'], runtime_manager=rt)

# Truncated token-to-id / id-to-token maps; the special-token indices shown here are typical
# and must match the model's own vocab file.
vocab = {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, '<mask>': 4, '0': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, ......}
vocab_inp = {0: '<s>', 1: '<pad>', 2: '</s>', 3: '<unk>', 4: '<mask>', 5: '0', 6: '1', 7: '2', 8: '3', 9: '4', 10: '5', 11: '6', 12: '7', ......}
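
Rather than hard-coding the (truncated) mapping, the vocabulary can be loaded from the model's vocab file and inverted to build vocab_inp. A minimal sketch, assuming a plain token-to-id JSON file; the path below is an assumption, adjust it to your export directory:

import json

vocab_json_path = './trocr-chinese-main/cust-data/vocab.json'  # assumed location of the vocab file
with open(vocab_json_path, 'r', encoding='utf-8') as f:
    vocab = json.load(f)                      # token -> id
vocab_inp = {v: k for k, v in vocab.items()}  # id -> token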

image_data = cv2.imread(image_path)
image_data = image_data[..., ::-1]  # BGR to RGB
image_data = cv2.resize(image_data, (384, 384))
image_data = do_norm(image_data)
image_data = np.transpose(image_data, (2, 0, 1))  # HWC to CHW
input_var = np.expand_dims(image_data, 0)
# print(input_var.shape)  # [1, 3, 384, 384]

# NCHW to NC4HW4 (MNN's internal layout)
input_var = expr.convert(input_var, expr.NC4HW4)

# Run the encoder
output_var = encoder_net.forward(input_var)
# print(output_var.shape)  # [1, 578, 384]

# NC4HW4 to NHWC (uncomment if the output needs converting back)
# output_var = expr.convert(output_var, expr.NHWC)
# print(output_var.shape)

# Greedy autoregressive decoding: start from the <s> token and append the argmax token each step
ids = [vocab['<s>']]
mask = [1]
for i in range(50):
    input_ids = np.array([ids])
    attention_mask = np.array([mask])
    # The input list order must match the input names passed to load_module_from_file:
    # ['attention_mask', 'encoder_hidden_states', 'input_ids']
    output_var_decoder = decoder_net.forward([attention_mask, output_var, input_ids])
    # print(output_var_decoder[0][0].shape)  # [1, 98]
    pred = output_var_decoder[0][0]
    pred = pred.argmax(axis=1)
    # print(pred)
    if pred[-1] == vocab['</s>']:  # stop once the end-of-sequence token is predicted
        break
    ids.append(pred[-1])
    mask.append(1)

text = decode_text(ids, vocab, vocab_inp)
print(text)
