# 82 lines · 3.0 KiB · Python
import tensorflow as tf
|
|
import numpy as np
|
|
|
|
|
|
class LogisticRegression(object):
    """Single fully connected layer built with the TensorFlow 1.x graph API.

    The graph holds an ``input_x`` placeholder of shape ``[None, in_channels]``,
    an ``input_y`` placeholder of shape ``[None, class_num]``, a linear layer
    ``out = input_x @ weight (+ bias)``, a mean-squared-error loss and a plain
    gradient-descent ``train_op``. The constructor creates a private
    ``tf.Session`` and initialises all global variables in it.

    ``to`` / ``state_dict`` / ``load_state_dict`` / ``trainable_variables``
    mirror the ``torch.nn.Module`` method names, presumably so this model can
    be driven by code written for PyTorch models.

    NOTE(review): despite the class name, no sigmoid is applied to ``out`` and
    the loss is MSE, so this trains as a linear regression — confirm intent.
    """

    def __init__(self, in_channels, class_num, use_bias=True,
                 learning_rate=0.001):
        """Build the graph, open a session and initialise the variables.

        Args:
            in_channels: number of input features.
            class_num: number of output units.
            use_bias: whether the linear layer adds a bias vector.
            learning_rate: step size of the gradient-descent optimizer.
        """
        self.input_x = tf.placeholder(tf.float32, [None, in_channels],
                                      name='input_x')
        # Labels share the prediction shape, as required by the MSE loss.
        self.input_y = tf.placeholder(tf.float32, [None, class_num],
                                      name='input_y')

        self.out = self.fc_layer(input_x=self.input_x,
                                 in_channels=in_channels,
                                 class_num=class_num,
                                 use_bias=use_bias)

        with tf.name_scope('loss'):
            self.losses = tf.losses.mean_squared_error(predictions=self.out,
                                                       labels=self.input_y)

        with tf.name_scope('train_op'):
            self.optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
            self.train_op = self.optimizer.minimize(self.losses)

        self.sess = tf.Session()
        self.graph = tf.get_default_graph()
        # tensor name -> (value placeholder, assign op). Built lazily and reused
        # so repeated load_state_dict calls do not keep adding graph nodes.
        self._assign_ops = {}

        with self.graph.as_default():
            with self.sess.as_default():
                tf.global_variables_initializer().run()

    def fc_layer(self, input_x, in_channels, class_num, use_bias=True):
        """Create a dense layer ``input_x @ weight (+ bias)`` under scope ``fc``.

        The ``weight`` variable has shape ``[in_channels, class_num]`` and is
        drawn from a truncated normal with stddev 0.1; the optional ``bias``
        variable has shape ``[class_num]`` and starts at zero.

        Returns:
            The layer output; ``[batch, class_num]`` for a 2-D ``input_x``.
        """
        with tf.name_scope('fc'):
            fc_w = tf.Variable(tf.truncated_normal([in_channels, class_num],
                                                   stddev=0.1),
                               name='weight')
            if use_bias:
                fc_b = tf.Variable(tf.constant(0.0, shape=[class_num]),
                                   name='bias')
                fc_out = tf.nn.bias_add(tf.matmul(input_x, fc_w), fc_b)
            else:
                fc_out = tf.matmul(input_x, fc_w)

        return fc_out

    def to(self, device):
        """No-op; ``device`` is ignored and nothing is returned."""
        pass

    def trainable_variables(self):
        """Return the trainable variables collected in this model's graph."""
        with self.graph.as_default():
            return tf.trainable_variables()

    def state_dict(self):
        """Return every global variable of the graph as a numpy array.

        Keys are variable names with the ``:0`` suffix dropped and ``/``
        replaced by ``.`` (``fc/weight:0`` -> ``fc.weight``). Any variable whose
        name contains ``'weight'`` is transposed, turning the
        ``[in_channels, class_num]`` kernel into ``[class_num, in_channels]``
        (presumably to match the PyTorch ``nn.Linear`` layout — confirm).

        NOTE(review): ``tf.global_variables()`` covers the whole graph, so
        variables of other models built in the same graph are included too.
        """
        with self.graph.as_default():
            variables = tf.global_variables()
            # One session run fetches every value.
            values = self.sess.run(variables)

        model_dict = {}
        for var, value in zip(variables, values):
            if 'weight' in var.name:
                value = np.transpose(value, (1, 0))
            model_dict[var.name.split(':')[0].replace('/', '.')] = value
        return model_dict

    def load_state_dict(self, model_para, strict=False):
        """Assign values from ``model_para`` into the graph's variables.

        Args:
            model_para: mapping keyed like ``state_dict`` output (``.``
                separators, no ``:0`` suffix). Values whose key contains
                ``'weight'`` are transposed back before assignment.
            strict: accepted for signature compatibility; currently ignored.

        Every key is resolved before anything is assigned, so a key that does
        not name a tensor in the graph raises (whatever
        ``Graph.get_tensor_by_name`` raises) without a partial update.
        """
        with self.graph.as_default():
            feed_dict = {}
            assign_ops = []
            for name, new_param in model_para.items():
                value_ph, assign_op = self._get_assign_op(
                    name.replace('.', '/') + ':0')
                if 'weight' in name:
                    new_param = np.transpose(new_param, (1, 0))
                feed_dict[value_ph] = new_param
                assign_ops.append(assign_op)
            if assign_ops:
                self.sess.run(assign_ops, feed_dict=feed_dict)

    def _get_assign_op(self, tensor_name):
        """Return a cached ``(placeholder, assign_op)`` pair for ``tensor_name``.

        Must be called with ``self.graph`` as the default graph. The ops are
        created only on the first request for a given tensor name.
        """
        cached = self._assign_ops.get(tensor_name)
        if cached is None:
            ref = self.graph.get_tensor_by_name(tensor_name)
            value_ph = tf.placeholder(ref.dtype.base_dtype,
                                      shape=ref.get_shape())
            cached = (value_ph, tf.assign(ref, value_ph))
            self._assign_ops[tensor_name] = cached
        return cached
|