Hike News
Hike News

深度學習-tensorflow基礎-會話(Session)

Introduction

tensorflow分為前端系統及後端系統

  • 前端系統:定義程序的圖;定義程序的架構
  • 後端系統:運算圖結構

會話

session會去解析使用者定義的graph,並進行運算

  1. 運行圖結構
  2. 分配資源(CPU, GPU)進行計算
  3. 掌握資源(變量, 隊列, 線程…等資源),相當於總管,決定開啟或釋放資源
    • 在tensorflow中是真正的多線程(numpy釋放了GIL)

tf.Session()

運行TensorFlow中操作圖的類,使用默認註冊的圖(但也可指定運行圖)

  • 運行了指定的圖or默認的圖之後,不能去調用其他圖結構的任何op

會話資源

會話可能擁有很多資源(如tf.Variable, tf.QueueBase和tf.ReaderBase等),會話結束後需要進行資源釋放

Method I

  • 須主動close()
1
2
3
4
5
6
7
8
9
10
import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

result = tf.add(a, b)

sess = tf.Session()
print(sess.run(result)) # run方法:啟動整個圖
sess.close() # 釋放資源

Method II 使用上下文管理器(with)

常用且方便,會自動釋放資源

1
2
3
4
5
6
7
8
9
import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

result = tf.add(a, b)

with tf.Session() as sess:
print(sess.run(result))
  • 只要有上下文環境也能使用eval()函數(比起run更為方便)
1
2
with tf.Session() as sess:
print(result.eval())

指定服務器地址運行(target)

tf.Session(target='')

  • target參數留空(預設值),會話將僅使用本地電腦中的設備運行
  • 可以指定 grpc://網址,以便指定Tensorflow服務器的地址
    • 使得會話可以訪問該服務器控制的電腦上的所有設備

指定圖運行

在tf.Session(graph=指定圖)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import tensorflow as tf

g = tf.Graph()
print("g.graph:",g)

with g.as_default():
c = tf.constant(10.0)
print("c.graph in g_graph",c.graph)

graph = tf.get_default_graph()
print("get_default_graphL:",graph)

# 只能運行一個圖
with tf.Session(graph=g) as sess: # 於Session()的graph參數決定運行的圖
print(sess.run(c))
print('sess.graph:',sess.graph)

Result

1
2
3
4
5
g.graph: <tensorflow.python.framework.ops.Graph object at 0x100ff3780>
c.graph in g_graph <tensorflow.python.framework.ops.Graph object at 0x100ff3780>
get_default_graphL: <tensorflow.python.framework.ops.Graph object at 0x102830128>
10.0
sess.graph: <tensorflow.python.framework.ops.Graph object at 0x100ff3780>
  • 預設的圖和指定的圖內存位置不一樣,要是不從graph參數指定圖,默認會運行預設的

查看運行設備情況

Session中的參數config

  • 可知道當前graph所使用的資源及設備
  • 可知道當前graph使用到哪些tensor及operation
  • tf.Session(config=tf.ConfigProto(log_device_placement=True))
1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

result = tf.add(a, b)

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
print(sess.run(result))
print('result.graph:',result.graph)
print('sess.graph:',sess.graph)

Result

1
2
3
4
5
6
7
Device mapping: no known devices.
Add: (Add): /job:localhost/replica:0/task:0/device:CPU:0
Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
Const_1: (Const): /job:localhost/replica:0/task:0/device:CPU:0
3.0
result.graph: <tensorflow.python.framework.ops.Graph object at 0x101ebe160>
sess.graph: <tensorflow.python.framework.ops.Graph object at 0x101ebe160>

交互式Session

  • 一般於命令行中進行操作
  • 輸入tf.InteractiveSession()開啟交互式的會話(以含有上下文環境)
  • 可在接下來內文中持續run圖結構;或是使用eval()

Session的run方法

run(fetches, feed_dict=None, graph=None)

  • fetches:嵌套列表、元組、namedtuple、dict或OrderedDict,運行operator和計算tensor(重載的運算符也能運行)

    • 主要參數,要運行的Tensor 或是 Operation
      • 不能是一般float或是int類型,會報錯
        如:(報錯)
        1
        2
        3
        4
        5
        var1 = 2.0
        var2 = 3.0
        sum = var1 + var2
        with tf.Session() as sess:
        print(sess.run(sum))
    • 若要是同時運行多個物件則要放入一個list,如[a,b,result]
    • 有重載機制
      • 當operation與一般類型進行計算時,預設會重載成operation類型
        如:(正常執行)
        1
        2
        3
        4
        5
        var1 = 2.0
        a = tf.constant(3.0)
        sum = var1 + a
        with tf.Session() as sess:
        print(sess.run(sum))
  • feed_dict:允許調用者覆蓋graph中指定張量的值,結合placeholder使用

    • 在程序執行的時候,不確定輸入的是甚麼,提前佔個位,再用feed_dict指定參數
    • placeholder提供佔位符
      • 提供相對應數據shape大小的空間,但沒有具體數據
      • tf.placeholder(dtype, shape=None, name=None)
        • dtype:為數據的類型
        • shape:輸入數據的shape
        • name:暫不討論,於tensorboard中使用
    • 在訓練模型時,傳進去的樣本數不一定都為固定值,因此需實時的提供數據去進行訓練
    • 接收的為一個字典
      • 參數fetches所接收的tf.placeholder變量

Example I 固定筆數的數據

1
2
3
4
5
6
7
8
import tensorflow as tf
import numpy as np

plt = tf.placeholder(tf.float32,shape=[2,3]) #shape為兩行三列的數據

with tf.Session() as sess:
data = np.array([[1,2,3],[4,5,6]])
print(sess.run(plt, feed_dict={plt:data}))

Result

1
2
[[1. 2. 3.]
[4. 5. 6.]]

Example II 不知道筆數的數據

  • 往往實際的情況是,在固定特徵數的條件下,我們不知道樣本的數目
1
2
3
4
5
6
7
8
9
10
import tensorflow as tf
import numpy as np

data = np.array([[1,2,3],[4,5,6],range(3),range(4,7),range(9,12)])

plt = tf.placeholder(tf.float32,shape=[None,3]) #shape不知道為多少列,但知道多少行(特徵)
print(plt)

with tf.Session() as sess:
print(sess.run(plt, feed_dict={plt:data}))

Result

1
2
3
4
5
6
Tensor("Placeholder:0", shape=(?, 3), dtype=float32)
[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 0. 1. 2.]
[ 4. 5. 6.]
[ 9. 10. 11.]]

Session異常

返回值異常案例:

  • RuntimeError:如果它Session處於無效狀態(例如已關閉)
  • TypeError:如果fetches或feed_dict鍵不是合適類型
  • ValueError:如果fetches或feed_dict鍵無效,或引用Tensor不存在