Caffe常用Python接口用法总结
概述
根据Caffe的日常使用情况,对常用的python接口进行了总结整理,以方便后续科研工作中查阅。
Caffe初始化设置
# ------------- 设置计算设备 ------------- #
caffe.set_mode_gpu() # 使用GPU,注意默认使用CPU,要使用GPU就不能缺少这一句
caffe.set_mode_cpu() # 使用CPU
caffe.set_device(device_id) # 不设置默认为0
Model训练
# --------------- Solver --------------- #
# ------- 加载Solver,有2种方法 -------
# 1. 无论Solver类型是什么,统一设置为SGD
solver = caffe.SGDSolver('path/to/solver.prototxt')
# 2. 根据solver.prototxt中指定的Solver类型读取,默认为SGD
solver = caffe.get_solver('path/to/solver.prototxt')
# ------- 前向传播 -------
solver.net.forward() # train net forward
solver.test_nets[0].forward() # test net forward, test net允许有多个(train net只能有1个)
# ------- 反向传播 -------
solver.net.backward()
# ------- 模型训练 -------
solver.step(n) # 模型进行n次forward和backward,完成n次训练
solver.solver() # 模型根据solver.prototxt中的设置,进行完整模型训练
# ------- 模型保存 -------
solver.net.save('name.caffemodel')
加载已训练Model
# ------- 从已有caffemodel中加载 -------
net = caffe.Net(
deploy_prototxt_path, # deploy网络定义prototxt文件
caffe_model_path, # 已训练模型的caffemodel文件
caffe.TEST # phase设置为Test
)
# ------- 从已有solverstate中恢复训练 -------
solver.restore('path/to/solver.solverstate')
net2.copy_from('path/to/net.caffemodel')
# ------- 从已有Net对象中共享得到 -------
net2.share_with(net1) # net2共享net1的权重(权重指针指向同一地址)
数据预处理
# ------- 读取均值文件 -------
mean_blob = caffe.proto.caffe_pb2.BlobProto()
mean_blob.ParseFromString(open('mean.binaryproto','rb').read())
mean_npy = caffe.io.blobproto_to_array(mean_blob)
# ------- 图片预处理 -------
from scipy.misc import imread, imresize
img = imread(filename)
img = imresize(img, dst_size)
if len(img.shape) == 3:
mean_val = np.mean(mean_npy, axis=(2,3))
img -= mean_val[:,np.newaxis,np.newaxis]*np.ones(img.shape) # ndarray broadcasting
img_batch = np.array([img.transpose((2,0,1))])
img_batch.astype(np.float32)
elif len(img.shape) == 2:
mean_val = np.mean(mean_npy)
img -= mean_val
img_batch = np.array([img[np.newaxis,:,:]])
img_batch.astype(np.float32)
net.blobs['data'].data[...] = img_batch
访问网络
# ------- 网络参数 -------
conv1_weight = net.params['conv1'][0].data # type: np.ndarray
conv1_bias = net.params['conv1'][1].data
# ------- 网络参数的梯度 -------
conv1_weight = net.params['conv1'][0].diff # type: np.ndarray
conv1_bias = net.params['conv1'][1].diff
# ------- feature map -------
conv1_feat = net.blobs['conv1'].data[0] # type: np.ndarray
conv1_feat = net.blobs['conv1'].diff[0] # 梯度
for layer_name, param in net.params.items():
print layer_name + '\t' + str(param[0].data.shape), str(param[1].data.shape)
for layer_name, blob in net.blobs.items():
print layer_name + '\t' + str(blob.data.shape)