AI开发平台ModelArts-TensorFlow 2.1:推理代码

时间:2023-11-01 16:20:34

推理代码

在模型代码推理文件customize_service.py中,需要添加一个子类,该子类继承对应模型类型的父类,各模型类型的父类名称和导入语句请参考表1。

import logging
import threading

import numpy as np
import tensorflow as tf
from PIL import Image

from model_service.tfserving_model_service import TfServingBaseService

logger = logging.getLogger()
logger.setLevel(logging.INFO)


class MnistService(TfServingBaseService):
    """Custom inference service that serves a TensorFlow SavedModel.

    The model is loaded on a background thread started from ``__init__``;
    ``_inference`` waits for that thread before running a prediction.
    """

    def __init__(self, model_name, model_path):
        """Record the model location and start loading it in the background.

        Args:
            model_name: Name of the model, stored on the instance.
            model_path: Directory passed to ``tf.saved_model.load``.
        """
        # NOTE(review): the parent __init__ is not called, as in the original
        # sample; presumably intentional so the parent does not load the model
        # itself -- confirm against TfServingBaseService.
        self.model_name = model_name
        self.model_path = model_path
        self.model = None
        self.predict = None

        # A label file can be loaded here and used in the post-processing
        # function. label.txt lives in the OBS / model package directory.
        # (Enabling this also requires importing os and json.)
        # with open(os.path.join(self.model_path, 'label.txt')) as f:
        #     self.label = json.load(f)

        # Load the saved_model without blocking the constructor, to avoid a
        # startup timeout. The thread handle is kept so that _inference can
        # wait for loading to finish instead of calling a None predictor.
        self._load_thread = threading.Thread(target=self.load_model)
        self._load_thread.start()

    def load_model(self):
        """Load the SavedModel and select the signature used for prediction.

        When the model exposes exactly one signature it is used; otherwise a
        warning is logged and the default serving signature key is used.
        """
        # Load a model in saved_model format.
        self.model = tf.saved_model.load(self.model_path)
        signature = list(self.model.signatures.keys())

        # Only one signature is allowed.
        if len(signature) == 1:
            model_signature = signature[0]
        else:
            logger.warning("signatures more than one, use serving_default signature from %s", signature)
            model_signature = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY

        self.predict = self.model.signatures[model_signature]

    def _preprocess(self, data):
        """Convert the uploaded images into a single float32 tensor.

        Args:
            data: Nested mapping; every value of every inner mapping is opened
                with ``PIL.Image.open``.

        Returns:
            A ``tf.float32`` tensor built from the list of per-image arrays,
            each resized in place to ``(28, 28, 1)``.
        """
        images = []
        for files in data.values():
            for file_content in files.values():
                image = np.array(Image.open(file_content), dtype=np.float32)
                # ndarray.resize changes the buffer shape in place (truncating
                # or zero-padding); it does not resample the picture.
                # NOTE(review): inputs are presumably already 28x28
                # single-channel images -- confirm against the callers.
                image.resize((28, 28, 1))
                images.append(image)
        return tf.convert_to_tensor(images, dtype=tf.dtypes.float32)

    def _inference(self, data):
        """Run the selected signature on the pre-processed tensor.

        Blocks until the background load started in ``__init__`` has finished.

        Raises:
            RuntimeError: If no predictor is available after loading ended.
        """
        load_thread = getattr(self, "_load_thread", None)
        if load_thread is not None:
            # Returns immediately once the loader thread has completed.
            load_thread.join()
        if self.predict is None:
            raise RuntimeError("model is not loaded; check the load_model logs for the failure")
        return self.predict(data)

    def _postprocess(self, data):
        """Return the arg-max index of the first row of ``data["output"]``.

        Only element 0 of the output batch is reported.
        """
        return {
            "result": int(data["output"].numpy()[0].argmax())
        }
support.huaweicloud.com/inference-modelarts/inference-modelarts-0080.html