import MNN.nn as nn
import MNN.cv as cv
import MNN.numpy as np
import MNN.expr as expr
import cv2
# Execution backend, thread count, precision and related runtime options;
# see the MNN API documentation for the supported key/value pairs.
config = {
    'precision': 'low',  # use fp16 inference when the hardware supports it (ARMv8.2)
    'backend': 0,        # CPU
    'numThread': 4,      # number of worker threads
}
def decode_text(tokens, vocab, vocab_inp):
    """Turn a sequence of TrOCR token ids into a text string.

    Args:
        tokens: Iterable of token ids, e.g. the ids produced by the greedy
            decoding loop.
        vocab: Mapping from token string to id; only used here to look up
            the special-token ids.
        vocab_inp: Mapping from id to token string.

    Returns:
        The concatenation of ``vocab_inp[id]`` for every id that precedes the
        first end-of-sequence id, skipping start, end, padding and unknown ids.

    Raises:
        KeyError: if a non-special id is missing from ``vocab_inp``.
    """
    # NOTE(review): these are the standard TrOCR/RoBERTa special-token names;
    # confirm they match the model's vocab file. A name missing from `vocab`
    # yields None, which never equals an integer id, so it is simply ignored.
    s_start = vocab.get('<s>')
    s_end = vocab.get('</s>')
    unk = vocab.get('<unk>')
    pad = vocab.get('<pad>')
    # Built once rather than re-creating the list for every token.
    skip = (s_end, s_start, pad, unk)
    pieces = []
    for tk in tokens:
        if tk == s_end:
            break
        if tk not in skip:
            pieces.append(vocab_inp[tk])
    # join avoids the quadratic cost of repeated string concatenation.
    return ''.join(pieces)
def do_norm(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Scale pixel values to [0, 1] and normalize each channel.

    Computes ``(x / 255 - mean[c]) / std[c]`` for every channel ``c`` of an
    image indexed as ``x[row, col, channel]``. With the defaults (0.5 for all
    three channels) values in 0..255 are mapped to -1..1.

    Args:
        x: Image array indexed ``[H, W, C]`` (e.g. as returned by OpenCV).
        mean: Per-channel value subtracted after scaling to [0, 1].
        std: Per-channel value the shifted result is divided by.

    Returns:
        The normalized array (float for NumPy input), same shape as ``x``.

    Raises:
        ValueError: if ``mean`` and ``std`` have different lengths.
    """
    if len(mean) != len(std):
        raise ValueError('mean and std must have the same length')
    # For NumPy input the division creates a new float array, so the caller's
    # image is not modified by the in-place updates below.
    x = x / 255.0
    for c, (m, s) in enumerate(zip(mean, std)):
        x[:, :, c] -= m
        x[:, :, c] /= s
    return x
##### Inspect the model's input/output node names #########
# Uncomment to list the names expected by load_module_from_file below
# (run it after the model paths are defined; use decoder_model_path for the decoder).
# m = expr.load_as_dict(encoder_model_path)
# inputs_outputs = expr.get_inputs_and_outputs(m)
# for key in inputs_outputs[0].keys():
#     print('input names:\t', key)
# for key in inputs_outputs[1].keys():
#     print('output names:\t', key)
# One runtime manager, built from `config`, shared by both modules below.
rt = nn.create_runtime_manager((config,))
# Load the models and create the _Module objects
encoder_model_path = './trocr-chinese-main/onnx/general_en/encoder_model.mnn'
decoder_model_path = './trocr-chinese-main/onnx/general_en/decoder_model.mnn'
image_path = './test/20240808141905.png'
print('image path:', image_path)
# Encoder: input 'pixel_values' -> output 'last_hidden_state'.
encoder_net = nn.load_module_from_file(encoder_model_path, ['pixel_values'], ["last_hidden_state"], runtime_manager=rt)
# Decoder: inputs (attention_mask, encoder_hidden_states, input_ids) -> 'logits'.
# NOTE(review): forward() is later called with inputs in this same order,
# presumably because the order is positional — confirm in the MNN docs.
decoder_net = nn.load_module_from_file(decoder_model_path, ["attention_mask", "encoder_hidden_states", "input_ids"], ["logits"], runtime_manager=rt)
# vocab: token string -> id; vocab_inp: id -> token string.
# NOTE(review): both literals are truncated in this copy (unterminated strings),
# and the special-token keys appear to have been stripped, leaving '' keys
# (presumably '<s>', '</s>', '<pad>', '<unk>'). TODO: regenerate both maps
# from the model's vocab file.
vocab = {'': 0, '': 2, '
vocab_inp = {0: '', 1: '', 3: '
# Load the image, convert OpenCV's BGR channel order to RGB, and resize to the
# 384x384 resolution fed to the encoder.
image_data = cv2.imread(image_path)
if image_data is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; fail here with a clear message rather than a TypeError later.
    raise FileNotFoundError('could not read image: ' + image_path)
image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
image_data = cv2.resize(image_data, (384, 384))
image_data = do_norm(image_data)
# HWC -> CHW, then add the batch axis: [1, 3, 384, 384].
image_data = np.transpose(image_data, (2, 0, 1))
input_var = np.expand_dims(image_data, 0)
# The data is laid out [N, C, H, W] here; convert it to MNN's NC4HW4 layout.
input_var = expr.convert(input_var, expr.NC4HW4)
# Run the encoder (a debug print showed an output shape of [1, 578, 384]).
output_var = encoder_net.forward(input_var)
ids = [vocab[""]]
mask = [1]
for i in range(50):
input_ids = np.array([ids])
attention_mask = np.array([mask])
output_var_decoder = decoder_net.forward([attention_mask,output_var,input_ids])
# print(output_var_decoder[0][0].shape) #[1, 98]
pred = output_var_decoder[0][0]
pred = pred.argmax(axis=1)
# print(pred)
if pred[-1] == vocab[""]:
break
ids.append(pred[-1])
mask.append(1)
text = decode_text(ids, vocab, vocab_inp)
print(text)