Tensorflow로 선형 회귀 모델(linear regression model) 구현하기

실습 환경은 Python 3.6과 tensorflow 1.5이다.

 

선형 회귀(linear regression)란?

사진 출처: 위키백과

선형 회귀란, x와 y의 값(데이터)가 주어졌을 때, 두 데이터 간 관계의 규칙성을 찾아 새로운 x값(input)에 대해 적절한 y값(output)을 예측하는 기법이다.

 

선행 개념

  • 플레이스홀더
  • 세션과 지연 실행

앞의 두 가지 개념을 아직 들어본 적이 없다면 아래의 게시물을 참고하자.

2020/07/14 - [개발 일지/ML] - 텐서(tensor), 플레이스홀더(placeholder)와 변수

 

 

1. 데이터셋 설정

# y = 2x 의 선형 그래프
x_data = [1, 2, 3]
y_data = [2, 4, 6]

다음과 같이 데이터셋을 설정하였다.

 

2. 가중치와 편향 탐색

# 균등분포(uniform distribution)을 가진 무작위 변수
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.random_uniform([1], -1.0, 1.0))

W는 가중치, b는 편향에 해당한다. input을 통해 적절한 output을 도출할 W와 b를 찾아내는 것이 선형 회귀 기법의 목적이기도 하다.

 

3. 플레이스홀더 설정

# X, Y 라는 이름의 플레이스홀더 선언
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

플레이스홀더를 선언할 때 name attribute를 지정해주면, Tensor 형태로 출력될 때 이름이 부여된다.

설정한 경우: Tensor("X:0", dtype=float32)
설정하지 않은 경우: Tensor("Placeholder:0", dtype=float32)

 

4. 상관관계 분석을 위한 수식 작성

# 선형 관계 분석을 위한 수식
# hypothesis: Y(함수의 치역)에 해당
hypothesis = W * X + b

 

5. 손실 함수 작성

# 비용(cost) 계산
cost = tf.reduce_mean(tf.square(hypothesis - Y))

손실 함수(loss function)란, 한 쌍의 데이터에 대한 손실값을 계산하는 함수이다.

주로 예측값(hypothesis)과 실제 값(Y) 간의 거리를 손실 비용으로 설정한다.

모델은 반복적으로 학습을 거쳐 이 손실을 최소화하는 W와 b를 도출한다.

tf.reduce_mean 함수는 모든 데이터에 대해 손실 간 거리를 구하고 제곱한 뒤 평균을 도출한다.

 

6. 경사하강법 최적화 함수를 이용한 손실값 최소화

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(cost)

최적화 과정에서는 앞서 설명했듯 가중치와 편향 값(W와 b)을 변경해가며 손실값을 최소화하는 W와 b를 찾는다.

여기에서는 함수의 기울기를 구하고 기울기가 낮은 쪽으로 이동하며 최적의 값을 탐색하는 경사하강법을 이용하였다.

 

7. 대망의 학습

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 학습
    # 학습 횟수: 100번
    for step in range(100):
        _, cost_val = sess.run([train_op, cost], feed_dict={X: x_data, Y: y_data})
        print(step, cost_val, sess.run(W), sess.run(b))

    print("***결과***")
    print("X: {}, Y: {}".format(5, sess.run(hypothesis, feed_dict={X: 5})))
    print("X: {}, Y: {}".format(2.5, sess.run(hypothesis, feed_dict={X: 2.5})))

with절을 이용함으로써 Session의 생성과 종료를 간편하게 처리할 수 있다. 코드가 with절을 벗어나는 순간 Session이 자동으로 종료되기 때문이다.

주어진 데이터셋을 통해 100번의 학습을 수행하고, 결과를 확인하기 위해 X 값으로 5와 2.5를 입력하였다.

 

결과는 다음과 같다.

X: 5, Y: [9.822843]
X: 2.5, Y: [4.985267]

y = 2x 그래프와 아주 근접한 결과를 도출해내는 것을 확인할 수 있다!

 

+ 반복적인 학습이 이루어지며 cost의 값이 점점 작아지는 것을 확인할 수 있었다.

학습 결과에 대한 출력을 확인하려면 아래 문장을 누르면 된다.

더보기

0 31.929964 [1.5158815] [1.7476043]
1 0.7636607 [1.2686837] [1.5917307]
2 0.37321535 [1.3145533] [1.565911]
3 0.3512566 [1.327939] [1.5269076]
4 0.33452117 [1.344433] [1.4903505]
5 0.31863073 [1.3601553] [1.4545072]
6 0.30349526 [1.3755407] [1.4195436]
7 0.28907922 [1.3905519] [1.3854185]
8 0.27534777 [1.4052027] [1.3521141]
9 0.26226842 [1.4195012] [1.3196101]
10 0.24981058 [1.4334561] [1.2878876]
11 0.23794417 [1.4470754] [1.2569276]
12 0.22664173 [1.4603673] [1.226712]
13 0.21587622 [1.4733397] [1.1972227]
14 0.20562176 [1.4860003] [1.1684424]
15 0.19585471 [1.4983563] [1.1403538]
16 0.1865514 [1.5104156] [1.1129405]
17 0.17769022 [1.5221848] [1.0861862]
18 0.16924979 [1.5336711] [1.060075]
19 0.16121037 [1.5448815] [1.0345916]
20 0.15355256 [1.5558221] [1.0097207]
21 0.14625876 [1.5664998] [0.9854477]
22 0.13931133 [1.5769209] [0.9617582]
23 0.13269405 [1.5870914] [0.9386382]
24 0.12639092 [1.5970175] [0.916074]
25 0.12038725 [1.606705] [0.89405215]
26 0.114668764 [1.6161594] [0.8725597]
27 0.10922194 [1.6253868] [0.851584]
28 0.1040338 [1.6343921] [0.83111244]
29 0.099092156 [1.6431812] [0.8111331]
30 0.09438518 [1.6517589] [0.791634]
31 0.089901835 [1.6601304] [0.77260363]
32 0.08563141 [1.6683006] [0.75403076]
33 0.0815638 [1.6762744] [0.7359044]
34 0.07768948 [1.6840565] [0.71821374]
35 0.07399917 [1.6916516] [0.70094836]
36 0.07048416 [1.6990641] [0.68409806]
37 0.0671361 [1.7062984] [0.6676528]
38 0.06394712 [1.7133589] [0.6516029]
39 0.06090958 [1.7202495] [0.6359388]
40 0.058016326 [1.7269745] [0.62065125]
41 0.05526048 [1.7335378] [0.6057312]
42 0.05263556 [1.7399434] [0.59116983]
43 0.050135363 [1.746195] [0.57695854]
44 0.04775384 [1.7522962] [0.56308883]
45 0.045485567 [1.758251] [0.54955256]
46 0.04332499 [1.7640624] [0.53634167]
47 0.04126696 [1.7697341] [0.52344835]
48 0.03930674 [1.7752696] [0.51086503]
49 0.037439626 [1.780672] [0.49858415]
50 0.03566123 [1.7859445] [0.48659855]
51 0.03396729 [1.7910903] [0.47490108]
52 0.032353777 [1.7961122] [0.46348473]
53 0.03081696 [1.8010136] [0.45234293]
54 0.029353136 [1.8057971] [0.4414689]
55 0.02795889 [1.8104657] [0.4308563]
56 0.026630832 [1.8150219] [0.42049876]
57 0.025365809 [1.8194686] [0.41039026]
58 0.024160897 [1.8238084] [0.40052474]
59 0.023013303 [1.828044] [0.39089644]
60 0.021920117 [1.8321776] [0.38149953]
61 0.020878866 [1.836212] [0.37232858]
62 0.019887133 [1.8401494] [0.36337805]
63 0.018942477 [1.8439921] [0.3546427]
64 0.018042697 [1.8477424] [0.3461173]
65 0.017185653 [1.8514025] [0.33779684]
66 0.01636932 [1.8549747] [0.32967645]
67 0.015591749 [1.858461] [0.32175124]
68 0.014851148 [1.8618636] [0.31401658]
69 0.01414568 [1.8651843] [0.30646783]
70 0.013473757 [1.8684251] [0.29910052]
71 0.012833789 [1.8715882] [0.29191038]
72 0.012224153 [1.8746752] [0.28489304]
73 0.01164351 [1.8776878] [0.27804434]
74 0.011090425 [1.8806281] [0.27136037]
75 0.010563609 [1.8834977] [0.26483706]
76 0.010061818 [1.8862983] [0.25847054]
77 0.009583862 [1.8890316] [0.25225714]
78 0.009128624 [1.8916992] [0.246193]
79 0.008695045 [1.8943027] [0.24027473]
80 0.008282009 [1.8968437] [0.23449868]
81 0.007888626 [1.8993236] [0.22886151]
82 0.007513887 [1.9017437] [0.22335978]
83 0.0071569937 [1.9041058] [0.2179904]
84 0.0068169995 [1.9064108] [0.21274999]
85 0.0064932047 [1.9086607] [0.20763564]
86 0.006184763 [1.9108565] [0.20264427]
87 0.005890979 [1.9129993] [0.1977728]
88 0.005611153 [1.9150908] [0.19301851]
89 0.005344613 [1.917132] [0.18837848]
90 0.005090752 [1.9191241] [0.18384998]
91 0.0048489454 [1.9210683] [0.17943035]
92 0.004618598 [1.9229658] [0.17511694]
93 0.004399232 [1.9248177] [0.17090724]
94 0.00419025 [1.926625] [0.16679874]
95 0.0039912295 [1.928389] [0.16278902]
96 0.0038016343 [1.9301103] [0.15887564]
97 0.0036210532 [1.9317905] [0.1550564]
98 0.0034490328 [1.9334301] [0.15132894]
99 0.0032852108 [1.9350305] [0.14769113]