import tensorflow as tf
import numpy as np


class LogisticRegression(object):
    """Single fully connected layer trained with a mean-squared-error loss.

    Written against the TensorFlow 1.x graph API: the constructor builds the
    placeholders, the layer, the loss and the training op in the current
    default graph, opens a session and initializes the variables.

    ``state_dict`` / ``load_state_dict`` exchange parameters as numpy arrays
    keyed by dotted variable names (``fc/weight`` becomes ``fc.weight``).
    Any parameter whose name contains ``weight`` is transposed on the way
    out and transposed back on the way in.
    """

    def __init__(self, in_channels, class_num, use_bias=True):
        """Build the graph, create the session and initialize variables.

        Args:
            in_channels: Size of the second dimension of ``input_x``.
            class_num: Number of output units of the fully connected layer.
            use_bias: Whether the layer adds a bias term.
        """
        self.input_x = tf.placeholder(tf.float32, [None, in_channels],
                                      name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, 1], name='input_y')

        self.out = self.fc_layer(input_x=self.input_x,
                                 in_channels=in_channels,
                                 class_num=class_num,
                                 use_bias=use_bias)

        with tf.name_scope('loss'):
            self.losses = tf.losses.mean_squared_error(
                predictions=self.out, labels=self.input_y)

        with tf.name_scope('train_op'):
            self.optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=0.001)
            self.train_op = self.optimizer.minimize(self.losses)

        self.sess = tf.Session()
        self.graph = tf.get_default_graph()

        # Cache of tensor name -> (placeholder, assign op), filled lazily by
        # load_state_dict so that repeated loads reuse the same graph nodes.
        self._assign_ops = {}

        with self.graph.as_default():
            with self.sess.as_default():
                tf.global_variables_initializer().run()

    def fc_layer(self, input_x, in_channels, class_num, use_bias=True):
        """Create a fully connected layer under the ``fc`` name scope.

        Args:
            input_x: Input tensor that is matrix-multiplied with the weight.
            in_channels: Number of rows of the weight matrix.
            class_num: Number of columns of the weight matrix (and length of
                the bias vector).
            use_bias: When true, a zero-initialized bias is added.

        Returns:
            The output tensor ``input_x @ weight`` (plus bias if enabled).
        """
        with tf.name_scope('fc'):
            fc_w = tf.Variable(
                tf.truncated_normal([in_channels, class_num], stddev=0.1),
                name='weight')
            fc_out = tf.matmul(input_x, fc_w)
            if use_bias:
                fc_b = tf.Variable(tf.constant(0.0, shape=[class_num]),
                                   name='bias')
                fc_out = tf.nn.bias_add(fc_out, fc_b)
        return fc_out

    def to(self, device):
        """Do nothing; ``device`` is accepted and ignored."""
        pass

    def trainable_variables(self):
        """Return the trainable variables registered in this model's graph."""
        # Scope the lookup to self.graph so the result does not depend on
        # which graph happens to be the default when this is called.
        with self.graph.as_default():
            return tf.trainable_variables()

    def state_dict(self):
        """Return the current values of the graph's global variables.

        Returns:
            dict mapping each variable name (without the ``:0`` suffix and
            with ``/`` replaced by ``.``) to a numpy array. Arrays of
            variables whose name contains ``weight`` are transposed.
        """
        with self.graph.as_default():
            variables = tf.global_variables()
            # One session call fetches every variable at once.
            values = self.sess.run(variables)

        model_dict = {}
        for var, param in zip(variables, values):
            if 'weight' in var.name:
                param = np.transpose(param, (1, 0))
            key = var.name.split(':')[0].replace("/", '.')
            model_dict[key] = param
        return model_dict

    def load_state_dict(self, model_para, strict=False):
        """Write the given parameter values into the graph's variables.

        Args:
            model_para: Mapping from dotted parameter name (as produced by
                ``state_dict``) to the new value. Values whose name contains
                ``weight`` are transposed before being assigned.
            strict: Accepted but not used by this implementation.

        Raises:
            KeyError: If a name does not correspond to a tensor in the graph.
        """
        assign_ops = []
        feed_dict = {}
        with self.graph.as_default():
            for name, new_param in model_para.items():
                tensor_name = name.replace('.', '/') + ':0'
                value_ph, assign_op = self._get_assign_op(tensor_name)
                if 'weight' in name:
                    new_param = np.transpose(new_param, (1, 0))
                assign_ops.append(assign_op)
                feed_dict[value_ph] = new_param

        if assign_ops:
            self.sess.run(assign_ops, feed_dict=feed_dict)

    def _get_assign_op(self, tensor_name):
        """Return the cached ``(placeholder, assign_op)`` for a variable.

        The pair is created on first use. Calling ``tf.assign`` with a fresh
        value on every load would add a new op and a new constant to the
        graph each time, so the graph would keep growing; feeding a reusable
        placeholder avoids that.
        """
        cached = self._assign_ops.get(tensor_name)
        if cached is None:
            target = self.graph.get_tensor_by_name(tensor_name)
            with self.graph.as_default():
                value_ph = tf.placeholder(target.dtype.base_dtype,
                                          shape=target.shape)
                cached = (value_ph, tf.assign(target, value_ph))
            self._assign_ops[tensor_name] = cached
        return cached