摘要:剛剛舉行的深度學(xué)習(xí)開(kāi)發(fā)者峰會(huì)上,發(fā)布了版本,這一版新增了等一系列并行算法。專注于游戲智能少兒趣味編程兩大領(lǐng)域。有了貝爾曼最優(yōu)方程,我們就可以通過(guò)純粹貪心的策略來(lái)確定,即僅僅把最優(yōu)動(dòng)作的概率設(shè)置為,其他所有非最優(yōu)動(dòng)作的概率都設(shè)置為。
剛剛舉行的 WAVE SUMMIT 2019 深度學(xué)習(xí)開(kāi)發(fā)者峰會(huì)上,PaddlePaddle 發(fā)布了 PARL 1.1 版本,這一版新增了 IMPALA、A3C、A2C 等一系列并行算法。作者重新測(cè)試了一遍內(nèi)置 example,發(fā)現(xiàn)卷積速度也明顯加快,從 1.0 版本的訓(xùn)練一幀需大約 1 秒優(yōu)化到了 0.15 秒(配置:win8,i5-6200U,GeForce-940M,batch-size=32)。
嘿嘿,超級(jí)本實(shí)現(xiàn)游戲智能的時(shí)代終于來(lái)臨!廢話不多說(shuō),我們趕緊試試 PARL 的官方 DQN 算法,玩一玩 Flappy-Bird。
關(guān)于作者:曹天明(kosora),2011 年畢業(yè)于天津科技大學(xué),7 年的 PHP+Java 經(jīng)驗(yàn)。于2018年9月報(bào)名加入光環(huán)國(guó)際人工智能周末轉(zhuǎn)型班進(jìn)行學(xué)習(xí)提升。個(gè)人研究方向——融合 CLRS 與 DRL 兩大技術(shù)體系,并行刷題和模型訓(xùn)練。專注于游戲智能、少兒趣味編程兩大領(lǐng)域。
模擬環(huán)境
相信大家對(duì)于這個(gè)游戲并不陌生,我們需要控制一只小鳥(niǎo)向前飛行,只有飛翔、下落兩種操作,小鳥(niǎo)每穿過(guò)一根柱子,總分就會(huì)增加。由于柱子是高低不平的,所以需要想盡辦法躲避它們。一旦碰到了柱子,或者碰到了上、下邊緣,都會(huì)導(dǎo)致 game-over。下圖展示了未經(jīng)訓(xùn)練的小笨鳥(niǎo),可以看到,他處于人工智障的狀態(tài),經(jīng)常撞柱子或者撞草地:
▲ 未經(jīng)訓(xùn)練的小笨鳥(niǎo)
先簡(jiǎn)要分析一下環(huán)境 Environment 的主要代碼。
BirdEnv.py 繼承自 gym.Env,實(shí)現(xiàn)了 init、reset、reward、render 等標(biāo)準(zhǔn)接口。init 函數(shù),用于加載圖片、聲音等外部文件,并初始化得分、小鳥(niǎo)位置、上下邊緣、水管位置等環(huán)境信息:
def __init__(self):
if not hasattr(self,"IMAGES"): print("InitGame!") self.beforeInit() self.score = self.playerIndex = self.loopIter = 0 self.playerx = int(SCREENWIDTH * 0.3) self.playery = int((SCREENHEIGHT - self.PLAYER_HEIGHT) / 2.25) self.baseShift = self.IMAGES["base"].get_width() - self.BACKGROUND_WIDTH newPipe1 = getRandomPipe(self.PIPE_HEIGHT) newPipe2 = getRandomPipe(self.PIPE_HEIGHT) #...other code
step 函數(shù),執(zhí)行兩個(gè)動(dòng)作,0 表示不采取行動(dòng)(小鳥(niǎo)會(huì)自動(dòng)下落),1 表示飛翔;step 函數(shù)有四個(gè)返回值,image_data 表示當(dāng)前狀態(tài),也就是游戲畫(huà)面,reward 表示本次 step 的即時(shí)獎(jiǎng)勵(lì),terminal 表示是否是吸收狀態(tài),{} 表示其他信息:
def step(self, input_action=0):
pygame.event.pump() reward = 0.1 terminal = False if input_action == 1: if self.playery > -2 * self.PLAYER_HEIGHT: self.playerVelY = self.playerFlapAcc self.playerFlapped = True
#...other code
image_data=self.render()
return image_data, reward, terminal,{}
獎(jiǎng)勵(lì) reward;初始獎(jiǎng)勵(lì)是 +0.1,表示小鳥(niǎo)向前飛行一小段距離;穿過(guò)柱子,獎(jiǎng)勵(lì) +1;撞到柱子,獎(jiǎng)勵(lì)為 -1,并且到達(dá) terminal 狀態(tài):
飛行一段距離,獎(jiǎng)勵(lì)+0.1reward = 0.1
...other codeplayerMidPos = self.playerx + self.PLAYER_WIDTH / 2
for pipe in self.upperPipes:
pipeMidPos = pipe["x"] + self.PIPE_WIDTH / 2 #穿過(guò)一個(gè)柱子獎(jiǎng)勵(lì)加1 if pipeMidPos <= playerMidPos < pipeMidPos + 4: self.score += 1 reward = self.reward(1)...other code
if isCrash:
#撞到邊緣或者撞到柱子,結(jié)束,并且獎(jiǎng)勵(lì)為-1 terminal = True reward = self.reward(-1)
reward 函數(shù),返回即時(shí)獎(jiǎng)勵(lì) r:
def reward(self,r):
return r
reset 函數(shù),調(diào)用 init,并執(zhí)行一次飛翔操作,返回 observation,reward,isOver:
def reset(self,mode="train"):
self.__init__() self.mode=mode action0 = 1 observation, reward, isOver,_ = self.step(action0) return observation,reward,isOver
render 函數(shù),渲染游戲界面,并返回當(dāng)前畫(huà)面:
def render(self):
image_data = pygame.surfarray.array3d(pygame.display.get_surface()) pygame.display.update() self.FPSCLOCK.tick(FPS) return image_data
至此,強(qiáng)化學(xué)習(xí)所需的狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)等功能均定義完畢。接下來(lái)簡(jiǎn)單推導(dǎo)一下 DQN (Deep-Q-Network) 算法的原理。
DQN的發(fā)展過(guò)程
DQN 的進(jìn)化歷史可謂源遠(yuǎn)流長(zhǎng),從最開(kāi)始 Bellman 在 1956 年提出的動(dòng)態(tài)規(guī)劃,到后來(lái) Watkins 在 1989 年提出的的 Q-learning,再到 DeepMind 的 Nature-2015 穩(wěn)定版,最后到 Dueling DQN、Priority Replay Memory、Parameter Noise 等優(yōu)化算法,橫跨整整一個(gè)甲子,凝聚了無(wú)數(shù)專家、教授們的心血。如今的我們站在先賢們的肩膀上,從以下角度逐步分析:
貝爾曼(最優(yōu))方程與 VQ 樹(shù)
Q-learning
參數(shù)化逼近
DQN 算法框架
貝爾曼 (最優(yōu)) 方程與VQ樹(shù)
我們從經(jīng)典的表格型強(qiáng)化學(xué)習(xí)(Tabular Reinforcement Learning)開(kāi)始,回憶一下馬爾可夫決策(MDP)過(guò)程,MDP 可由五元組 (S,A,P,R,γ) 表示,其中:
S 狀態(tài)集合,維度為 1×|S|
A 動(dòng)作集合,維度為 1×|A|
P 狀態(tài)轉(zhuǎn)移概率矩陣,經(jīng)常寫(xiě)成,其維度為 |S|×|A|×|S|
R 回報(bào)函數(shù),如果依賴于狀態(tài)值函數(shù) V,維度為 1×|S|,如果依賴于狀態(tài)-動(dòng)作值函數(shù) Q,則維度為 |S|×|A|
γ 折扣因子,用來(lái)計(jì)算帶折扣的累計(jì)回報(bào) G(t),維度為 1
S、A、R、γ 均不難理解,可能部分同學(xué)對(duì)有疑問(wèn)——既然 S 和 A 確定了,下一個(gè)狀態(tài) S" 不是也確定了嗎?為什么會(huì)有概率轉(zhuǎn)移矩陣呢?
其實(shí)我初學(xué)的時(shí)候也曾經(jīng)被這個(gè)問(wèn)題困擾過(guò),不妨通過(guò)如下兩個(gè)例子以示區(qū)別:
恒等于 1.0 的情況。如圖 1 所示,也就是上一次我們?cè)诓呗蕴荻人惴ㄖ兴褂玫拿詫m,假設(shè)機(jī)器人處于左上角,這時(shí)候你命令機(jī)器人向右走,那么他轉(zhuǎn)移到紅框所示位置的概率就是 1.0,不會(huì)有任何異議:
▲ 圖1. 迷宮尋寶
不等于 1.0 的情況。假設(shè)現(xiàn)在我們下一個(gè)飛行棋,如圖 2 所示。有兩種骰子,第一種是普通的正方體骰子,可以投出 1~6,第二種是正四面體的骰子,可以投出 1~4。現(xiàn)在飛機(jī)處于紅框所示的位置,現(xiàn)在我們選擇投擲第二種骰子這個(gè)動(dòng)作,由于骰子本身具有均勻隨機(jī)性,所以飛機(jī)轉(zhuǎn)移到終點(diǎn)的概率僅僅是 0.25。這就說(shuō)明,在某些環(huán)境中,給定 S、A 的情況下,轉(zhuǎn)移到具體哪一個(gè) S" 其實(shí)是不確定的:
▲ 圖2. 飛行棋
除了經(jīng)典的五元組外,為了研究長(zhǎng)期回報(bào),還經(jīng)常加入三個(gè)重要的元素,分別是:
策略 π(a∣s),維度為 |S|×|A|
狀態(tài)值函數(shù),維度為 1×|S|,表示當(dāng)智能體采用策略 π 時(shí),累積回報(bào)在狀態(tài) s 處的期望值:
▲ 圖3. 狀態(tài)值函數(shù)
狀態(tài)-行為值函數(shù),也叫狀態(tài)-動(dòng)作值函數(shù),維度為 |S|×|A|,表示當(dāng)智能體采取策略 π 時(shí),累計(jì)回報(bào)在狀態(tài) s 處并執(zhí)行動(dòng)作 a 時(shí)的期望值:
▲ 圖4. 狀態(tài)-行為值函數(shù)
知道了 π、v、q 的具體含義后,我們來(lái)看一個(gè)重要的概念,也就是 V、Q 的遞歸展開(kāi)式。
學(xué)過(guò)動(dòng)態(tài)規(guī)劃的同學(xué)都知道,動(dòng)態(tài)規(guī)劃本質(zhì)上是一個(gè) bootstrap(自舉)問(wèn)題,它包含最優(yōu)子結(jié)構(gòu)與重疊子問(wèn)題兩個(gè)性質(zhì),也就是說(shuō),通常有兩種方法解決動(dòng)態(tài)規(guī)劃:
將總問(wèn)題劃分為 k 個(gè)子問(wèn)題,遞歸求解這些子問(wèn)題,然后將子問(wèn)題進(jìn)行合并,得到總問(wèn)題的最優(yōu)解;對(duì)于重復(fù)的子問(wèn)題,我們可以將他們進(jìn)行緩存(記憶搜索 MemorySearch,請(qǐng)回憶 f(n)=f(n-1)+f(n-2) 這個(gè)遞歸程序);
計(jì)算最小的子問(wèn)題,合并這些子問(wèn)題產(chǎn)生一個(gè)更大的子問(wèn)題,不斷的自底向上計(jì)算,隨著子問(wèn)題的規(guī)模越來(lái)越大,我們會(huì)得到最終的總問(wèn)題的最優(yōu)解(打表 DP,請(qǐng)回憶楊輝三角中的 dp[i-1,j-1]+dp[i-1,j]=dp[i,j])。
這兩種切題技巧,對(duì)于有過(guò) ACM 或者 LeetCode 刷題經(jīng)驗(yàn)的同學(xué),可以說(shuō)是老朋友了,那么能否把以上思想遷移到強(qiáng)化學(xué)習(xí)呢?答案是肯定的!
分別考慮 v、q 的展開(kāi)式:
處在狀態(tài) s 時(shí),由于有策略 π 的存在,故可以把狀態(tài)值函數(shù) v 展開(kāi)成以下形式:
▲ 圖5. v展開(kāi)成q
這個(gè)公式表示:在狀態(tài) s 處的值函數(shù),等于采取策略 π 時(shí),所有狀態(tài)-行為值函數(shù)的總和。
處在狀態(tài) s、并執(zhí)行動(dòng)作 a,可以把狀態(tài)-行為值函數(shù) q 展開(kāi)成以下形式:
▲ 圖6. q展開(kāi)成v
這個(gè)公式表示:在狀態(tài) s 采用動(dòng)作 a 的狀態(tài)行為值函數(shù),等于回報(bào)加上后序可能產(chǎn)生的的狀態(tài)值函數(shù)的總和。
我們可以看到:v 可以展開(kāi)成 q,同時(shí) q 也可以展開(kāi)成 v。
所以可以用以下 v、q 節(jié)點(diǎn)相隔的樹(shù)來(lái)表示以上兩個(gè)公式,這顆樹(shù)比純粹的公式更容易理解,我習(xí)慣上把它叫做 V-Q 樹(shù),它顯然是一個(gè)遞歸的結(jié)構(gòu):
▲ 圖7. V-Q樹(shù)
注意畫(huà)紅圈中的兩個(gè)節(jié)點(diǎn),體現(xiàn)了重疊子問(wèn)題特性。如何理解這個(gè)性質(zhì)呢?不妨回憶一下上文提到的飛行棋,假設(shè)飛機(jī)處在起點(diǎn)位置 1,那么無(wú)論投擲 1 號(hào)骰子還是 2 號(hào)骰子,都是有機(jī)會(huì)可以到達(dá)位置 3 的,這就是重疊子問(wèn)題的一個(gè)例子。
有了這棵遞歸樹(shù)之后,就不難推導(dǎo)出 v 和 v",以及 q 和 q" 自身的遞歸展開(kāi)式:
▲ 圖8. 狀態(tài)值函數(shù)v自身的遞歸展開(kāi)式
▲ 圖9. 狀態(tài)-行為值函數(shù)q自身的遞歸展開(kāi)式
其實(shí)無(wú)論是 v 還是 q,都擁有最優(yōu)子結(jié)構(gòu)特性。不妨利用反證法加以證明:
假設(shè)要求總問(wèn)題 V(s) 的最優(yōu)解,那么它包含的每個(gè)子問(wèn)題 V(s") 也必須是最優(yōu)解;否則,如果某個(gè)子問(wèn)題 V(s") 不是最優(yōu),那么必然有一個(gè)更優(yōu)的子問(wèn)題 V"(s") 存在,使得總問(wèn)題 V"(s) 比原來(lái)的總問(wèn)題 V(s) 更優(yōu),與我們的假設(shè)相矛盾,故最優(yōu)子結(jié)構(gòu)性質(zhì)得證,q(s) 的最優(yōu)子結(jié)構(gòu)性質(zhì)同理。
計(jì)算值函數(shù)的目的是為了構(gòu)建學(xué)習(xí)算法得到最優(yōu)策略,每個(gè)策略對(duì)應(yīng)著一個(gè)狀態(tài)值函數(shù),最優(yōu)策略自然也對(duì)應(yīng)著最優(yōu)狀態(tài)值函數(shù),故而定義如下兩個(gè)函數(shù):
最優(yōu)狀態(tài)值函數(shù),表示在所有策略中最大的值函數(shù),即:
▲ 圖10. 最優(yōu)狀態(tài)值函數(shù)
最優(yōu)狀態(tài)-行為值函數(shù),表示在所有策略中最大的狀態(tài)-行為值函數(shù):
▲ 圖11. 最優(yōu)狀態(tài)-行為值函數(shù)
結(jié)合上文的遞歸展開(kāi)式和最優(yōu)子結(jié)構(gòu)性質(zhì),可以得到 v 與 q 的貝爾曼最優(yōu)方程:
▲ 圖12. v的貝爾曼最優(yōu)方程
▲ 圖13. q的貝爾曼最優(yōu)方程
重點(diǎn)理解第二個(gè)公式,也就是關(guān)于 q 的貝爾曼最優(yōu)方程,它是今天的主角 Q-learning 以及 DQN 的理論基礎(chǔ)。
有了貝爾曼最優(yōu)方程,我們就可以通過(guò)純粹貪心的策略來(lái)確定 π,即:僅僅把最優(yōu)動(dòng)作的概率設(shè)置為 1,其他所有非最優(yōu)動(dòng)作的概率都設(shè)置為 0。這樣做的好處是:當(dāng)算法收斂的時(shí)候,策略 π(a|s) 必然是一個(gè) one-hot 型的矩陣。用數(shù)學(xué)公式表達(dá)如下:
▲ 圖14. 算法收斂時(shí)候的策略π
強(qiáng)化學(xué)習(xí)中的動(dòng)態(tài)規(guī)劃方法實(shí)質(zhì)上是一種 model-based(模型已知)方法,因?yàn)?MDP 五元組是已知的,特別是狀態(tài)轉(zhuǎn)移概率矩陣是已知的。
也就是說(shuō),所有的環(huán)境信息對(duì)于我們來(lái)說(shuō)是 100% 完備的,故而可以對(duì)整個(gè)解空間樹(shù)進(jìn)行全局搜索,下圖展示了動(dòng)態(tài)規(guī)劃方法的示意圖,在確定根節(jié)點(diǎn)狀態(tài) S(t) 的最優(yōu)值的時(shí)候,必須遍歷他所有的 S(t+1) 子節(jié)點(diǎn)并選出最優(yōu)解:
▲ 圖15. 動(dòng)態(tài)規(guī)劃方法的解空間搜索過(guò)程
不過(guò),和傳統(tǒng)的刷題動(dòng)態(tài)規(guī)劃略有不同,強(qiáng)化學(xué)習(xí)往往是利用值迭代(Value Iteration)、策略迭代(Policy Iteration)、策略改善(Policy Improve)等方式使 v、q、π 等元素達(dá)到收斂狀態(tài),當(dāng)然也有直接利用矩陣求逆計(jì)算解析解的方法,有興趣的同學(xué)可以參考相關(guān)文獻(xiàn),這里不再贅述。
Q-learning
上文提到的動(dòng)態(tài)規(guī)劃方法是一種 model-based 方法,僅僅適用于已知的情況。若狀態(tài)轉(zhuǎn)移概率矩陣未知,model-free(無(wú)模型)方法就派上用場(chǎng)了,上一期的 MCPG 算法就是一種典型的 model-free 方法。它搜索解空間的方式更像是 DFS(深度優(yōu)先搜索),而且一條道走到黑,沒(méi)有指針回溯的操作,下圖展示了蒙特卡洛算法的求解示意圖:
▲ 圖16. MC系列方法的解空間搜索過(guò)程
雖然每次只能走一條分支,但隨機(jī)數(shù)發(fā)生器會(huì)幫助算法遍歷整個(gè)解空間,再通過(guò)大量的迭代,所有節(jié)點(diǎn)也會(huì)收斂到最優(yōu)解。
不過(guò),MC 類方法有兩個(gè)小缺點(diǎn):
使用作為訓(xùn)練標(biāo)簽,其本身就是值函數(shù)準(zhǔn)確的無(wú)偏估計(jì)。但是,這也正是它的缺點(diǎn),因?yàn)?MC 方法會(huì)經(jīng)歷很多隨機(jī)的狀態(tài)和動(dòng)作,使得每次得到的 G(t) 隨機(jī)性很大,具有很高的方差。
由于采用的是一條道走到黑的方式從根節(jié)點(diǎn)遍歷到葉子節(jié)點(diǎn),所以必須要等到 episode 結(jié)束才能進(jìn)行訓(xùn)練,而且每輪 episode 產(chǎn)生的數(shù)據(jù)只訓(xùn)練一次,每輪 episode 產(chǎn)生數(shù)據(jù)的 batch-size 還不一定相同,所以在訓(xùn)練過(guò)程中,MC 方法的 loss 函數(shù)(或者 TD-Error)的波動(dòng)幅度較大,而數(shù)據(jù)利用效率不高。
那么,能否邊產(chǎn)生數(shù)據(jù)邊訓(xùn)練呢?可以!時(shí)序差分(Temporal-Difference-Learning,簡(jiǎn)稱 TD)算法應(yīng)運(yùn)而生了。
時(shí)序差分學(xué)習(xí)是模擬(或者經(jīng)歷)一段序列,每行動(dòng)一步(或者幾步)就根據(jù)新?tīng)顟B(tài)的價(jià)值估計(jì)當(dāng)前執(zhí)行的狀態(tài)價(jià)值。大致可以分為兩個(gè)小類:
TD(0) 算法,只向后估計(jì)一個(gè) step。其值函數(shù)更新公式為:
▲ 圖17. TD(0)算法的更新公式
其中,α 為學(xué)習(xí)率,稱為 TD 目標(biāo),MC 方法中的 G(t) 也可以叫做 TD 目標(biāo),稱為 TD-Error,當(dāng)模型收斂時(shí),TD-Error 會(huì)無(wú)限接近于 0。
Sarsa(λ) 算法,向后估計(jì) n 步,n 為有限值,還有一個(gè)衰減因子 λ。其值函數(shù)的更新公式為:
▲ 圖18. Sarsa(λ)算法的更新公式
▲ 圖19. 的計(jì)算方法
與 MC 方法相比,TD 方法只用到了一步或者有限步隨機(jī)狀態(tài)和動(dòng)作,因此它是一個(gè)有偏估計(jì)。不過(guò),由于 TD 目標(biāo)的隨機(jī)性比 MC 方法的 G(t) 要小,所以方差也比 MC 方法小的多,值函數(shù)的波動(dòng)幅度較小,訓(xùn)練比較穩(wěn)定。
看一下 TD 方法的解空間搜索示意圖,紅框表示 TD(0),藍(lán)框表示 Sarsa(λ)。雖然每次估計(jì)都有一定的偏差,但隨著算法的不斷迭代,所有的節(jié)點(diǎn)也會(huì)收斂到最優(yōu)解:
▲ 圖20. TD方法的解空間搜索過(guò)程
有了 TD 的框架,既然我們要求狀態(tài)值函數(shù) v、狀態(tài)-行為值函數(shù) q 的最優(yōu)解,那么是否能直接選擇最優(yōu)的 TD 目標(biāo)作為 Target 呢?答案是肯定的,這也是 Q-Learning 算法的基本思想,其公式如下所示:
▲ 圖21. Q-learning算法的學(xué)習(xí)公式
其中,動(dòng)作 a 由 ε-greedy 策略選出,從而在狀態(tài) s 處執(zhí)行 a 之后產(chǎn)生了另一個(gè)狀態(tài) s",接下來(lái)選出狀態(tài) s" 處最大的狀態(tài)-行為值函數(shù) q(s",a"),這樣,TD 目標(biāo)就可以確定為 R+γmax[a′]Q(s′,a′)。這種思想很像貪心算法中的總是選擇在當(dāng)前看來(lái)最優(yōu)的決策,它一開(kāi)始可能會(huì)得到一個(gè)局部最優(yōu)解,不過(guò)沒(méi)關(guān)系,隨著算法的不斷迭代,整個(gè)解空間樹(shù)也會(huì)收斂到全局最優(yōu)解。
以下是 Q-learning 算法的偽代碼,和 on-policy 的 MC 方法對(duì)應(yīng),它是一種 off-policy(異策略)方法:
define maxEpisode=65535 //定義最大迭代輪數(shù) define maxStep=1024 //定義每一輪最多走多少步initialize Q_table[|S|,|A|] //初始化Q矩陣
for i in range(0,maxEpisode):
s=env.reset() //初始化狀態(tài)s for j in range(0,maxStep): //用ε-greedy策略在s行選一個(gè)動(dòng)作a choose action a using ε-greedy from Q_table[s] s",R,terminal,_=env.step(a) //執(zhí)行動(dòng)作a,得到下一個(gè)狀態(tài)s",獎(jiǎng)勵(lì)R,是否結(jié)束terminal max_s_prime_action=np.max(Q_table[s",:]) //選s"對(duì)應(yīng)的最大行為值函數(shù) td=R+γ*max_s_prime_action //計(jì)算TD目標(biāo) Q_table[s,a]= Q_table[s,a]+α*(td-Q_table[s,a]) //學(xué)習(xí)Q(s,a)的值 s=s" //更新s,注意,和sarsa算法不同,這里的a不用更新 if terminal: break
Q-learning 是一種優(yōu)秀的算法,不僅簡(jiǎn)單直觀,而且平均速度比 MC 快。在 DRL 未出現(xiàn)之前,它在強(qiáng)化學(xué)習(xí)中的地位,差不多可以媲美 SVM 在機(jī)器學(xué)習(xí)中的地位。
參數(shù)化逼近
有了 Q-learning 算法,是否就能一招吃遍天下鮮了呢?答案是否定的,我們看一下它存在的問(wèn)題。
上文所提到的,無(wú)論是 DP、MC 還是 TD,都是基于表格(tabular)的方法,當(dāng)狀態(tài)空間比較小的時(shí)候,計(jì)算機(jī)內(nèi)存完全可以裝下,表格式型強(qiáng)化學(xué)習(xí)是完全適用的。但遇到高階魔方(三階魔方的總變化數(shù)是)、圍棋()這類問(wèn)題時(shí),S、V、Q、P 等表格均會(huì)出現(xiàn)維度災(zāi)難,早就超出了計(jì)算機(jī)內(nèi)存甚至硬盤(pán)容量。這時(shí)候,參數(shù)化逼近方法就派上用場(chǎng)了。
所謂參數(shù)化逼近,是指值函數(shù)可以由一組參數(shù) θ 來(lái)近似,如 Q-learning 中的 Q(s,a) 可以寫(xiě)成 Q(s,a|θ) 的形式。這樣,不但降低了存儲(chǔ)維度,還便于做一些額外的特征工程,而且 θ 更新的同時(shí),Q(s,a|θ) 會(huì)進(jìn)行整體更新,不僅避免了過(guò)擬合情況,還使得模型的泛化能力更強(qiáng)。
既然有了可訓(xùn)練參數(shù),我們就要研究損失函數(shù)了,Q-Learning 的損失函數(shù)是什么呢?
先看一下 Q-Learning 的優(yōu)化目標(biāo)——使得 TD-Error 最小:
▲ 圖22. Q-Learning的優(yōu)化目標(biāo)
加入?yún)?shù) θ 之后,若將 TD 目標(biāo)作為標(biāo)簽 target,將 Q(s,a) 作為模型的輸出 y,則問(wèn)題轉(zhuǎn)化為:
▲ 圖23. 帶參數(shù)的優(yōu)化目標(biāo)
這是我們所熟悉的監(jiān)督學(xué)習(xí)中的回歸問(wèn)題,顯然 loss 函數(shù)就是 mse,故而可以用梯度下降算法最小化 loss,從而更新參數(shù) θ:
▲ 圖24. loss函數(shù)的梯度下降公式
注意到,TD 目標(biāo)是標(biāo)簽,所以 Q(s",a"|θ) 中的 θ 是不能更新的,這種方法并非完全的梯度法,只有部分梯度,稱為半梯度法,這是 NIPS-2013 的雛形。
后來(lái),DeepMind 在 Nature-2015 版本中將 TD 網(wǎng)絡(luò)多帶帶分開(kāi),其參數(shù)為 θ",它本身并不參與訓(xùn)練,而是每隔固定步數(shù)將值函數(shù)逼近的網(wǎng)絡(luò)參數(shù) θ 拷貝給 θ",這樣保證了 DQN 的訓(xùn)練更加穩(wěn)定:
▲ 圖25. 含有目標(biāo)網(wǎng)絡(luò)參數(shù)θ"的梯度下降公式
至此,DQN 的 Loss 函數(shù)、梯度下降公式推導(dǎo)完畢。
DQN算法框架
接下來(lái),還要解決兩個(gè)問(wèn)題——數(shù)據(jù)從哪里來(lái)?如何采集?
針對(duì)以上兩個(gè)問(wèn)題,DeepMind 團(tuán)隊(duì)提出了深度強(qiáng)化學(xué)習(xí)的全新訓(xùn)練方法:經(jīng)驗(yàn)回放(experience replay)。
在強(qiáng)化學(xué)習(xí)過(guò)程中,智能體將數(shù)據(jù)存儲(chǔ)到一個(gè) ReplayBuffer 中(任何一種集合,可以是哈希表、數(shù)組、隊(duì)列,也可以是數(shù)據(jù)庫(kù)),然后利用均勻隨機(jī)采樣的方法從 ReplayBuffer 中抽取數(shù)據(jù),這些數(shù)據(jù)就可以進(jìn)行 Mini-Batch-SGD,這樣就打破了數(shù)據(jù)之間的相關(guān)性,使得數(shù)據(jù)之間盡量符合獨(dú)立同分布原則。
DQN 的基本網(wǎng)絡(luò)結(jié)構(gòu)如下:
▲ 圖26. DQN的基本網(wǎng)絡(luò)結(jié)構(gòu)
要特別注意:
與參數(shù) θ 做線性運(yùn)算 (wx+b) 的僅僅是輸入狀態(tài) s,這一步?jīng)]有動(dòng)作 a 的參與;
output_1 的維度為 |A|,表示神經(jīng)網(wǎng)絡(luò) Q(s,θ) 的輸出;
輸入動(dòng)作 a 是 one-hot,與 output_1 作哈達(dá)馬積后產(chǎn)生的 output_2 是一個(gè)數(shù)字,作為損失函數(shù)中的 Q(s,a|θ),也就是 y。
以下是 DQN 算法的偽代碼:
Deep-Q-Network,Nature 2015 version 定義為一個(gè)雙端隊(duì)列D,作為經(jīng)驗(yàn)回放區(qū)域,最大長(zhǎng)度為max_sizeInitialize replay_memory D as a deque,mas_size=50000
初始化狀態(tài)-行為值函數(shù)Q的神經(jīng)網(wǎng)絡(luò),權(quán)值隨機(jī)Initialize action-value function Q(s,a|θ) as Neural Network with random-weights-initializer
初始化TD目標(biāo)網(wǎng)絡(luò),初始權(quán)值和θ相等Initialize target action-value function Q(s,a|θ) with weights θ"=θ
迭代max_episode個(gè)輪次for episode in range(0,max_episode=65535):
#重置環(huán)境env,得到初始狀態(tài)s s=env.reset() #循環(huán)事件的每一步,最多迭代max_step_limit個(gè)step for step in range(0,max_step_limit=1024): #通過(guò)ε-greedy的方式選出一個(gè)動(dòng)作action With probability ε select a random action a or select a=argmax(Q(s,θ)) #在env中執(zhí)行動(dòng)作a,得到下一個(gè)狀態(tài)s",獎(jiǎng)勵(lì)R,是否終止terminal s",R,terminal,_=env.step(a) #將五元組(s,a,s",R,terminal)壓進(jìn)隊(duì)尾 D.addLast(s,a,s",R,terminal) #如果隊(duì)列滿,彈出隊(duì)頭元素 if D.isFull(): D.removeFirst() #更新?tīng)顟B(tài)s s=s" #從隊(duì)列中進(jìn)行隨機(jī)采樣 batch_experience[s,a,s",R,terminal]=random_select(D,batch_size=32) #計(jì)算TD目標(biāo) target = R + γ*(1- terminal) * np.max(Q(s",θ")) #對(duì)loss函數(shù)執(zhí)行Gradient-decent,訓(xùn)練參數(shù)θ θ=θ+α*(target-Q(s,a|θ))▽Q(s,a|θ) #每隔C步,同步θ與θ"的權(quán)值 Every C steps set θ"=θ #是否結(jié)束 if terminal: break
我們玩的游戲 Flappy-Bird,它的輸入是一幀一幀的圖片,所以,經(jīng)典的 Atari-CNN 模型就可以派上用場(chǎng)了:
▲ 圖27. Atari游戲的CNN網(wǎng)絡(luò)結(jié)構(gòu)
網(wǎng)絡(luò)的輸入是被處理成灰度圖的最近 4 幀 84*84 圖像(4 是經(jīng)驗(yàn)值),經(jīng)過(guò)若干 CNN 和 FullyConnect 后,輸出各個(gè)動(dòng)作所對(duì)應(yīng)的狀態(tài)-行為值函數(shù) Q。以下是每一層的具體參數(shù),由于 atari 游戲最多有 18 個(gè)動(dòng)作,所以最后一層的維度是 18:
▲ 圖28. 神經(jīng)網(wǎng)絡(luò)的具體參數(shù)
至此,理論部分推導(dǎo)完畢。下面,我們分析一下 PARL 中的 DQN 部分的源碼,并實(shí)現(xiàn) Flappy-Bird 的游戲智能。
代碼實(shí)現(xiàn)
依次分析 env、model、algorithm、agent、replay_memory、train 等模塊。
BirdEnv.py,環(huán)境;上文已經(jīng)分析過(guò)了。
BirdModel.py,神經(jīng)網(wǎng)絡(luò)模型;使用三層 CNN+兩層 FC,CNN 的 padding 方式都是 valid,最后輸出狀態(tài)-行為值函數(shù) Q,維度為 |A|。注意輸入圖片歸一化,并按照官方模板填入代碼:
class BirdModel(Model):
def __init__(self, act_dim): self.act_dim = act_dim #padding方式為valid p_valid=0 self.conv1 = layers.conv2d( num_filters=32, filter_size=8, stride=4, padding=p_valid, act="relu") self.conv2 = layers.conv2d( num_filters=64, filter_size=4, stride=2, padding=p_valid, act="relu") self.conv3 = layers.conv2d( num_filters=64, filter_size=3, stride=1, padding=p_valid, act="relu") self.fc0=layers.fc(size=512) self.fc1 = layers.fc(size=act_dim) def value(self, obs): #輸入歸一化 obs = obs / 255.0 out = self.conv1(obs) out = self.conv2(out) out = self.conv3(out) out = layers.flatten(out, axis=1) out = self.fc0(out) out = self.fc1(out) return out
dqn.py,算法層;官方倉(cāng)庫(kù)已經(jīng)提供好了,我們無(wú)需自己再寫(xiě),直接復(fù)用算法庫(kù)(parl.algorithms)里邊的 DQN 算法即可。
簡(jiǎn)單分析一下 DQN 的源碼實(shí)現(xiàn)。
define_learn 函數(shù),用于神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)。接收 [狀態(tài) obs, 動(dòng)作 action, 即時(shí)獎(jiǎng)勵(lì) reward, 下一個(gè)狀態(tài) next_obs, 是否終止 terminal] 這樣一個(gè)五元組,代碼實(shí)現(xiàn)如下:
根據(jù)obs以及參數(shù)θ計(jì)算狀態(tài)-行為值函數(shù)pred_value,對(duì)應(yīng)偽代碼中的Q(s,θ)pred_value = self.model.value(obs)
根據(jù)next_obs以及參數(shù)θ"計(jì)算目標(biāo)網(wǎng)絡(luò)的狀態(tài)-行為值函數(shù)next_pred_value,對(duì)應(yīng)偽代碼中的Q(s",θ")next_pred_value = self.target_model.value(next_obs)
選出next_pred_value的最大值best_v,對(duì)應(yīng)偽代碼中的np.max(Q(s",θ"));注意θ"不參與訓(xùn)練,所以要stop_gradientbest_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
target = reward + (1.0 - layers.cast(terminal, dtype="float32")) self.gamma best_v
輸入的動(dòng)作action與pred_value作哈達(dá)瑪積,選出要評(píng)估的狀態(tài)-行為值函數(shù)pred_action_value,對(duì)應(yīng)偽代碼中的 Q(s,a|θ)action_onehot = layers.one_hot(action, self.action_dim)
action_onehot = layers.cast(action_onehot, dtype="float32")
pred_action_value = layers.reduce_sum(layers.elementwise_mul(action_onehot, pred_value), dim=1)
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3)
optimizer.minimize(cost)
sync_target 函數(shù)用于同步網(wǎng)絡(luò)參數(shù):
def sync_target(self, gpu_id):
""" sync parameters of self.target_model with self.model """ self.model.sync_params_to(self.target_model, gpu_id=gpu_id)
BirdAgent.py,智能體。其中,build_program 函數(shù)封裝了 algorithm 中的 define_predict 和 define_learn,sample 函數(shù)以 ε-greedy 策略選擇動(dòng)作,predict 函數(shù)以 100% 貪心的策略選擇 argmax 動(dòng)作,learn 函數(shù)接收五元組 (obs, act, reward, next_obs, terminal) 完成學(xué)習(xí)功能,這些函數(shù)和 Policy-Gradient 的寫(xiě)法類似。
除了這些常用功能之外,由于游戲的訓(xùn)練時(shí)間比較長(zhǎng),所以附加了兩個(gè)函數(shù),save_params 用于保存模型,load_params 用于加載模型:
保存模型def save_params(self, learnDir,predictDir):
fluid.io.save_params( executor=self.fluid_executor, dirname=learnDir, main_program=self.learn_programs[0]) fluid.io.save_params( executor=self.fluid_executor, dirname=predictDir, main_program=self.predict_programs[0])加載模型
def load_params(self, learnDir,predictDir):
fluid.io.load_params( executor=self.fluid_executor, dirname=learnDir, main_program=self.learn_programs[0]) fluid.io.load_params( executor=self.fluid_executor, dirname=predictDir, main_program=self.predict_programs[0])
另外,還有四個(gè)超參數(shù),可以進(jìn)行微調(diào):
每訓(xùn)練多少步更新target網(wǎng)絡(luò),超參數(shù)可調(diào)self.update_target_steps = 5000
初始探索概率ε,超參數(shù)可微調(diào)self.exploration = 0.8
每步探索的衰減程度,超參數(shù)可微調(diào)self.exploration_dacay=1e-6
最小探索概率,超參數(shù)可微調(diào)self.min_exploration=0.05
replay_memory.py,經(jīng)驗(yàn)回放單元。雙端隊(duì)列 _context 是一個(gè)滑動(dòng)窗口,用來(lái)記錄最近 3 幀(再加上新產(chǎn)生的 1 幀就是 4 幀);state、action、reward 等用 numpy 數(shù)組存儲(chǔ),因?yàn)?numpy 的功能比雙端隊(duì)列更豐富,max_size 表示 replay_memory 的最大容量:
self.state = np.zeros((self.max_size, ) + state_shape, dtype="int32")
self.action = np.zeros((self.max_size, ), dtype="int32")
self.reward = np.zeros((self.max_size, ), dtype="float32")
self.isOver = np.zeros((self.max_size, ), dtype="bool")
self._context = deque(maxlen=context_len - 1)
其他的 append、recent_state、sample_batch 等函數(shù)并不難理解,都是基于 numpy 數(shù)組的進(jìn)一步封裝,略過(guò)一遍即可看懂。
Train_Test_Working_Flow.py,訓(xùn)練與測(cè)試,讓環(huán)境 evn 和智能體 agent 進(jìn)行交互。最重要的就是 run_train_episode 函數(shù),體現(xiàn)了 DQN 的主要邏輯,重點(diǎn)分析注釋部分與 DQN 偽代碼的對(duì)應(yīng)關(guān)系,其他都是編程細(xì)節(jié):
訓(xùn)練一個(gè)episodedef run_train_episode(env, agent, rpm):
global trainEpisode global meanReward total_reward = 0 all_cost = [] #重置環(huán)境 state,_, __ = env.reset() step = 0 #循環(huán)每一步 while True: context = rpm.recent_state() context.append(resizeBirdrToAtari(state)) context = np.stack(context, axis=0) #用ε-greedy的方式選一個(gè)動(dòng)作 action = agent.sample(context) #執(zhí)行動(dòng)作 next_state, reward, isOver,_ = env.step(action) step += 1 #存入replay_buffer rpm.append(Experience(resizeBirdrToAtari(state), action, reward, isOver)) if rpm.size() > MEMORY_WARMUP_SIZE: if step % UPDATE_FREQ == 0: #從replay_buffer中隨機(jī)采樣 batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(batchSize) batch_state = batch_all_state[:, :CONTEXT_LEN, :, :] batch_next_state = batch_all_state[:, 1:, :, :] #執(zhí)行SGD,訓(xùn)練參數(shù)θ cost=agent.learn(batch_state,batch_action, batch_reward,batch_next_state, batch_isOver) all_cost.append(float(cost)) total_reward += reward state = next_state if isOver or step>=MAX_Step_Limit: break if all_cost: trainEpisode+=1 #以滑動(dòng)平均的方式打印平均獎(jiǎng)勵(lì) meanReward=meanReward+(total_reward-meanReward)/trainEpisode print(" trainEpisode:{},total_reward:{:.2f}, meanReward:{:.2f} mean_cost:{:.3f}" .format(trainEpisode,total_reward, meanReward,np.mean(all_cost))) return total_reward, step
除了主要邏輯外,還有一些常見(jiàn)的優(yōu)化手段,防止訓(xùn)練過(guò)程中出現(xiàn) trick:
充滿replay-memory,使其達(dá)到warm-up-size才開(kāi)始訓(xùn)練MEMORY_WARMUP_SIZE = MEMORY_SIZE//20
一輪episode最多執(zhí)行多少次step,不然小鳥(niǎo)會(huì)無(wú)限制的飛下去,相當(dāng)于gym.env中的_max_episode_steps屬性MAX_Step_Limit=int(1<<12)
用一個(gè)雙端隊(duì)列記錄最近16次episode的平均獎(jiǎng)勵(lì)avgQueue=deque(maxlen=16)
另外,還有其他一些超參數(shù),比如學(xué)習(xí)率 LEARNING_RATE、衰減因子 GAMMA、記錄日志的頻率 log_freq 等等,都可以進(jìn)行微調(diào):
衰減因子GAMMA = 0.99
學(xué)習(xí)率LEARNING_RATE = 1e-3 * 0.5
記錄日志的頻率log_freq=10
main 函數(shù)在這里,輸入 train 訓(xùn)練網(wǎng)絡(luò),輸入 test 進(jìn)行測(cè)試:
if name == "__main__":
print("train or test ?") mode=input() print(mode) if mode=="train": train() elif mode=="test": test() else: print("Invalid input!")
這是模型在我本機(jī)訓(xùn)練的輸出日志,大概 3300 個(gè) episode、50 萬(wàn)步之后,模型就收斂了:
▲ 圖29. 模型訓(xùn)練的輸出日志
平均獎(jiǎng)勵(lì):
▲ 圖30. 最近16次平均獎(jiǎng)勵(lì)變化曲線
各位同學(xué)可以試著調(diào)節(jié)超參數(shù),或者修改網(wǎng)絡(luò)模型,看看能不能遇到一些坑?哪些因素會(huì)影響訓(xùn)練效率?如何提升收斂速度?
接下來(lái)就是見(jiàn)證奇跡的時(shí)刻,當(dāng)初懵懂的小笨鳥(niǎo),如今已修煉成精了!
▲ 訓(xùn)練完的FlappyBird
觀看 4 分鐘完整版:
https://www.bilibili.com/vide...
Github源碼:
https://github.com/kosoraYint...
參考文獻(xiàn)
[1] Bellman, R.E. & Dreyfus, S.E. (1962). Applied dynamic programming. RAND Corporation.
[2] Sutton, R.S. (1988). Learning to predict by the methods of temporal difference.Machine Learning, 3, pp. 9–44.
[3] V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, et al., "Human-level control through deep reinforcement learning," Nature, vol. 518(7540), pp. 529-533, 2015.
[4] https://leetcode.com/problems...
[5] https://leetcode.com/problems...
[6] https://github.com/yenchenlin...
[7] https://github.com/MorvanZhou...
文章版權(quán)歸作者所有,未經(jīng)允許請(qǐng)勿轉(zhuǎn)載,若此文章存在違規(guī)行為,您可以聯(lián)系管理員刪除。
轉(zhuǎn)載請(qǐng)注明本文地址:http://specialneedsforspecialkids.com/yun/20028.html
摘要:和的得分均未超過(guò)右遺傳算法在也表現(xiàn)得很好。深度遺傳算法成功演化了有著萬(wàn)自由參數(shù)的網(wǎng)絡(luò),這是通過(guò)一個(gè)傳統(tǒng)的進(jìn)化算法演化的較大的神經(jīng)網(wǎng)絡(luò)。 Uber 涉及領(lǐng)域廣泛,其中許多領(lǐng)域都可以利用機(jī)器學(xué)習(xí)改進(jìn)其運(yùn)作。開(kāi)發(fā)包括神經(jīng)進(jìn)化在內(nèi)的各種有力的學(xué)習(xí)方法將幫助 Uber 發(fā)展更安全、更可靠的運(yùn)輸方案。遺傳算法——訓(xùn)練深度學(xué)習(xí)網(wǎng)絡(luò)的有力競(jìng)爭(zhēng)者我們驚訝地發(fā)現(xiàn),通過(guò)使用我們發(fā)明的一種新技術(shù)來(lái)高效演化 DNN,...
摘要:可以想象,監(jiān)督式學(xué)習(xí)和增強(qiáng)式學(xué)習(xí)的不同可能會(huì)防止對(duì)抗性攻擊在黑盒測(cè)試環(huán)境下發(fā)生作用,因?yàn)楣魺o(wú)法進(jìn)入目標(biāo)策略網(wǎng)絡(luò)。我們的實(shí)驗(yàn)證明,即使在黑盒測(cè)試中,使用特定對(duì)抗樣本仍然可以較輕易地愚弄神經(jīng)網(wǎng)絡(luò)策略。 機(jī)器學(xué)習(xí)分類器在故意引發(fā)誤分類的輸入面前具有脆弱性。在計(jì)算機(jī)視覺(jué)應(yīng)用的環(huán)境中,對(duì)這種對(duì)抗樣本已經(jīng)有了充分研究。論文中,我們證明了對(duì)于強(qiáng)化學(xué)習(xí)中的神經(jīng)網(wǎng)絡(luò)策略,對(duì)抗性攻擊依然有效。我們特別論證了,...
摘要:摘要本文主要是講解了機(jī)器學(xué)習(xí)中的增強(qiáng)學(xué)習(xí)方法的基本原理,常用算法及應(yīng)用場(chǎng)景,最后給出了學(xué)習(xí)資源,對(duì)于初學(xué)者而言可以將其作為入門(mén)指南。下圖表示了強(qiáng)化學(xué)習(xí)模型中涉及的基本思想和要素。 摘要: 本文主要是講解了機(jī)器學(xué)習(xí)中的增強(qiáng)學(xué)習(xí)方法的基本原理,常用算法及應(yīng)用場(chǎng)景,最后給出了學(xué)習(xí)資源,對(duì)于初學(xué)者而言可以將其作為入門(mén)指南。 強(qiáng)化學(xué)習(xí)(Reinforcement Learning)是當(dāng)前最熱門(mén)的...
摘要:在這個(gè)問(wèn)題強(qiáng)化學(xué)習(xí)里,我遇到過(guò)很多人,他們始終不相信我們能夠通過(guò)一套算法,從像素開(kāi)始從頭學(xué)會(huì)玩游戲這太驚人了,我自己也曾經(jīng)這么想。基于像素的乒乓游戲乒乓游戲是研究簡(jiǎn)單強(qiáng)化學(xué)習(xí)的一個(gè)非常好的例子。 這是一篇早就應(yīng)該寫(xiě)的關(guān)于強(qiáng)化學(xué)習(xí)的文章。強(qiáng)化學(xué)習(xí)現(xiàn)在很火!你可能已經(jīng)注意到計(jì)算機(jī)現(xiàn)在可以自動(dòng)(從游戲畫(huà)面的像素中)學(xué)會(huì)玩雅達(dá)利(Atari)游戲[1],它們已經(jīng)擊敗了圍棋界的世界冠軍,四足機(jī)器人學(xué)會(huì)...
閱讀 2984·2021-10-19 11:46
閱讀 979·2021-08-03 14:03
閱讀 2934·2021-06-11 18:08
閱讀 2905·2019-08-29 13:52
閱讀 2744·2019-08-29 12:49
閱讀 480·2019-08-26 13:56
閱讀 924·2019-08-26 13:41
閱讀 849·2019-08-26 13:35