<dl id="2ki44"><tbody id="2ki44"></tbody></dl>
  • <dfn id="2ki44"><pre id="2ki44"></pre></dfn>
  • <pre id="2ki44"><cite id="2ki44"></cite></pre>
  • <pre id="2ki44"></pre>
  • DeepMind 新作 AlphaDev ---- 強(qiáng)化學(xué)習(xí)探索更優(yōu)排序算法

    2023-06-22 08:53:23 來(lái)源: 程序員客棧

    前言

    DeepMind 最近在 Nature 發(fā)表了一篇論文 AlphaDev[2, 3],一個(gè)利用強(qiáng)化學(xué)習(xí)來(lái)探索更優(yōu)排序算法的AI系統(tǒng)。

    AlphaDev 系統(tǒng)直接從 CPU 匯編指令的層面入手去探索更優(yōu)的排序算法,因?yàn)橄鄬?duì)于高級(jí)編程語(yǔ)言來(lái)說(shuō),在匯編指令層級(jí)對(duì)存儲(chǔ)和寄存器的操作可以更加的靈活,所以能發(fā)現(xiàn)更多潛在的調(diào)優(yōu)策略。


    【資料圖】

    在 AlphaDev 的論文中,只關(guān)注探索短序列排序:

    定長(zhǎng)序列排序(比如 sort3 算法只能對(duì)長(zhǎng)度為3的序列進(jìn)行排序)

    變長(zhǎng)序列排序(比如 variable sort5 算法可以對(duì)長(zhǎng)度為1~5的變長(zhǎng)序列進(jìn)行排序)

    而對(duì)于長(zhǎng)序列的排序,可以被分解為短序列的排序。

    DeepMind 通過(guò) AlphaDev 發(fā)現(xiàn)了比目前人工調(diào)優(yōu)算法更優(yōu)的定長(zhǎng)短序列排序算法 sort3,sort4 和 sort5 ,并且已經(jīng)將代碼提交到了 LLVM 標(biāo)準(zhǔn) C++ 庫(kù)[4]。

    簡(jiǎn)單來(lái)說(shuō),AlphaDev 將探索更高效排序算法的過(guò)程,建模為一個(gè)單玩家的匯編游戲(single-player game, AssemblyGame)。

    游戲的過(guò)程就是玩家從 CPU 匯編指令集合中,選取一系列的指令組合得到一個(gè)新的排序算法。不過(guò)這個(gè)過(guò)程是非常有挑戰(zhàn)的,玩家需要考慮,匯編指令的組合空間并最終得得到一個(gè)正確和高效的算法。

    該游戲主要包括以下難點(diǎn):

    匯編游戲的搜索空間和圍棋類似(10^700)

    只要有一條指令沒(méi)弄對(duì),可能就會(huì)導(dǎo)致整個(gè)算法錯(cuò)誤

    AlphaDev 系統(tǒng)詳解將排序算法表示為 CPU 匯編指令

    首先來(lái)看一個(gè)簡(jiǎn)單的變長(zhǎng)(variable sort2)短排序函數(shù)的 C 代碼實(shí)現(xiàn),排序結(jié)果從小到大:

    voidvariable_sort_2(intlength,int*a){switch(length){case0:case1:return;case2:inttmp=a[0];//a[0]保存兩者之間的最小值a[0]=(a[1]

    通過(guò) gcc生成對(duì)應(yīng)的匯編代碼,我用的 gcc版本是 11.3.0,命令 gcc -S -O1 -o sort2.s sort2.c

    匯編代碼只保留了核心部分,生成的結(jié)果和論文中的示例有些許不同但是原理是一致的:

    variable_sort_2:  .LFB0:; %edi 寄存器保存參數(shù) length 的值; cmpl 指令對(duì)比 %edi 和 常量 2cmpl$2, %edi ; 相等就跳轉(zhuǎn)到 .L3 標(biāo)簽處,        ; 對(duì)應(yīng) C 代碼的 case 2je.L3.L1:; 不等于 2 就直接返回,        ; 對(duì)應(yīng) C 代碼 case 0 和 1ret .L3:; 將 a[0] 賦值給寄存器 %edx movl(%rsi), %edx; 將 a[1] 賦值給寄存器 %eax movl4(%rsi), %eax; 對(duì)比 %edx 和 %eaxcmpl%edx, %eax; 將 %edx 賦值給 %ecxmovl%edx, %ecx; cmov 是條件移動(dòng)指令根據(jù) cmpl ; 指令的結(jié)果判斷是否執(zhí)行; 如果 %eax <= %edx ; 則將 %eax 賦值給 %ecxcmovle%eax, %ecx; 此時(shí) %ecx 保存了最小值; 將 %ecx 賦值給 a[0]movl%ecx, (%rsi); 如果 %eax 小于 %edx; 則將 %edx 賦值給 %eaxcmovl%edx, %eax; 此時(shí) %eax 保存了最大值; 將 %eax 賦值給 a[1]movl%eax, 4(%rsi)jmp.L1

    一般來(lái)說(shuō)匯編程序所做的事情基本都是,將內(nèi)存的值復(fù)制到寄存器,然后對(duì)寄存器的值作修改,再將寄存器的值寫回到內(nèi)存中。

    而 AlphaDev 系統(tǒng)只關(guān)注 x86 處理器架構(gòu)所支持的匯編指令集合的一個(gè)子集。

    每條匯編指令的格式均為:操作碼<操作數(shù)A, 操作數(shù)B>比如:

    mov移動(dòng)指令,表示將 A 的值賦值給 B

    cmp比較指令,相當(dāng)于 執(zhí)行 A - B 操作,但是不會(huì)對(duì) A 和 B 做修改,而是根據(jù)相減的結(jié)果設(shè)置特殊的 flag 寄存器,更多內(nèi)容可以參考[5]

    cmovX條件移動(dòng)指令,根據(jù) X和 flag 寄存器的值判斷是否執(zhí)行將 A 賦值給 B 的操作,一般都是出現(xiàn)在 cmp指令之后。X可以是 L(是否滿足小于條件), G(是否滿足大于條件),LE(是否滿足小于或等于條件),GE(是否滿足大于等于條件)。

    jX條件跳轉(zhuǎn)指令,根據(jù) X和 flag 寄存器的值判斷是否執(zhí)行跳轉(zhuǎn)到指定標(biāo)記位置操作,A 可以是匯編程序代碼中的標(biāo)記位置,如上面所示匯編代碼的 .L1和 .L3。X可以是 NE(是否不等于),E(是否等于)或者可以填表示無(wú)條件跳轉(zhuǎn)。

    將探索更優(yōu)排序算法表示為強(qiáng)化學(xué)習(xí)問(wèn)題

    AlphaDev 將 CPU 匯編指令層面的算法優(yōu)化過(guò)程轉(zhuǎn)化為一個(gè)單玩家的游戲。

    游戲每一步的狀態(tài)定義為 : St =

    其中, Pt表示游戲到至今為止所生成的算法,Zt則表示在給定輸入的前提下執(zhí)行完 Pt里的指令之后,內(nèi)存和寄存器的狀態(tài)。

    如上圖所示,在時(shí)間步 t,AlphaDev 接受到當(dāng)前狀態(tài) St和 所要執(zhí)行的動(dòng)作 at(比如 mov),也就是往當(dāng)前生成的算法 Pt中添加的合法匯編指令。

    在添加完指令之后,就是計(jì)算獎(jiǎng)勵(lì)分?jǐn)?shù) rt(包括評(píng)估算法的正確性和延遲)。

    算法正確性評(píng)估

    正確性評(píng)估就是將 N組測(cè)試序列輸入到算法 Pt中,得到N組輸出,和正確的排序結(jié)果最比較來(lái)計(jì)算獎(jiǎng)勵(lì)分?jǐn)?shù)。

    論文中給出了3種正確性評(píng)估函數(shù),首先定義 P為輸入序列長(zhǎng)度, PCt為在時(shí)間步 t序列中,位置正確的值的個(gè)數(shù),這里我理解應(yīng)該是和正確的排序結(jié)果逐個(gè)位置對(duì)比,統(tǒng)計(jì)相等的個(gè)數(shù)。

    三個(gè)函數(shù)分別定義如下:

    func1 = (P - PCt) / P

    func2 = sqrt(func1)

    func3 = sqrt(PCt)

    論文中提到采用第三個(gè)函數(shù)效果最好。

    延遲評(píng)估

    延遲分?jǐn)?shù)的計(jì)算可以是:

    對(duì)系統(tǒng)增加代碼長(zhǎng)度計(jì)算懲罰,因?yàn)榇a的長(zhǎng)度一般都是和耗時(shí)高度相關(guān)

    直接計(jì)算算法的真實(shí)耗時(shí)

    整個(gè)強(qiáng)化學(xué)習(xí)的游戲在執(zhí)行有限步驟之后就會(huì)被終止。只有生成正確而又低延遲的匯編代碼才算贏得游戲。而不管是生成了錯(cuò)誤的代碼還是正確但低效的實(shí)現(xiàn)都視為游戲輸了。

    AlphaDev 采用的強(qiáng)化學(xué)習(xí)算法是對(duì) AlphqaZero 算法的擴(kuò)展,也是采用深度神經(jīng)網(wǎng)絡(luò)來(lái)引導(dǎo)蒙特卡洛樹(shù)搜索(MCTS)的規(guī)劃過(guò)程。網(wǎng)絡(luò)模型的輸入是 St,輸出是對(duì)動(dòng)作策略和獎(jiǎng)勵(lì)的預(yù)測(cè)。

    整個(gè)游戲過(guò)程簡(jiǎn)單來(lái)說(shuō)就是,用一個(gè)固定參數(shù)的網(wǎng)絡(luò)模型,通過(guò)給定的當(dāng)前狀態(tài)執(zhí)行一個(gè)蒙特卡洛樹(shù)搜索過(guò)程,然后采取下一步動(dòng)作。然后可以用生成的游戲過(guò)程(包含每一步的狀態(tài)和獎(jiǎng)勵(lì))去訓(xùn)練和更新網(wǎng)絡(luò)的參數(shù)。

    網(wǎng)絡(luò)模型結(jié)構(gòu)

    模型包含兩部分:

    一個(gè) Transformer 編碼器模塊,用于建模算法,輸入是至今為止生成的匯編指令序列

    一個(gè) CPU 狀態(tài)編碼器 MLP 模塊,輸入當(dāng)前寄存器和內(nèi)存的狀態(tài)

    兩個(gè)網(wǎng)絡(luò)的輸出 embedding 會(huì)合并在一起來(lái)表示當(dāng)前的狀態(tài)。

    網(wǎng)絡(luò)模型整體的結(jié)構(gòu)如下:

    Transformer 編碼器模塊具體圖示

    如上圖所示,把當(dāng)前生成的匯編代碼序列的每一條指令的操作碼和操作數(shù)都轉(zhuǎn)換為 one-hot 編碼序列,然后輸入到網(wǎng)絡(luò)中。

    但是具體的 one-hot 編碼規(guī)則、詞表怎么設(shè)置、還有對(duì)于 CPU 狀態(tài)編碼網(wǎng)絡(luò)寄存器和內(nèi)存的狀態(tài)是怎么表示為網(wǎng)絡(luò)的輸入的等等,這些細(xì)節(jié)我在論文里沒(méi)找到。

    然后兩個(gè)網(wǎng)絡(luò)的輸出 embedding 會(huì)合并到一起接著輸入到幾個(gè)函數(shù)頭里計(jì)算,分別是預(yù)測(cè)下一步策略的函數(shù)頭,預(yù)測(cè)算法正確性的函數(shù)頭和預(yù)測(cè)算法真實(shí)延遲的函數(shù)頭。

    網(wǎng)絡(luò)參數(shù)超參設(shè)置

    論文的補(bǔ)充資料中提供了網(wǎng)絡(luò)的參數(shù)和三個(gè)函數(shù)頭的具體配置。

    而對(duì)于策略的預(yù)測(cè),論文中提到為了簡(jiǎn)化問(wèn)題和提高收斂性,而對(duì)動(dòng)作空間做了一些限制,規(guī)則如下:

    必須按照升序方式讀取內(nèi)存

    寄存器按照升序分配

    cmp和 cmovX指令的操作數(shù)不能出現(xiàn)內(nèi)存地址

    對(duì)每個(gè)內(nèi)存位置,只能讀取和寫入一次

    每個(gè)寄存器在使用之前,必須初始化

    不能連續(xù)調(diào)用 cmp指令

    訓(xùn)練細(xì)節(jié)

    AlphaDev 的訓(xùn)練采用了 TPU v3,每個(gè) TPU 核的 batch size 是 1024 ,總共用了 16 個(gè) TPU 核,總共訓(xùn)練了 100 萬(wàn)次迭代。而在對(duì)于玩游戲積累訓(xùn)練數(shù)據(jù)來(lái)說(shuō),則是在 TPU v4 上進(jìn)行,總共用了 512 個(gè) TPU 核。

    實(shí)驗(yàn)結(jié)果表明,最多只需2天模型就能訓(xùn)收斂。

    實(shí)驗(yàn)結(jié)果生成的算法和人工調(diào)優(yōu)對(duì)比

    從實(shí)驗(yàn)結(jié)果表格可以看到,對(duì)于短序列排序算法 AlphaDev 生成的代碼長(zhǎng)度更短,而且平均耗時(shí)也更低。

    對(duì)生成算法延遲的評(píng)估方式,比如對(duì)于 sort3則是在 100 臺(tái)機(jī)器上做評(píng)估,每臺(tái)機(jī)器隨機(jī)生成 1000 條 3個(gè)數(shù)的序列,然后每條序列輸入到算法中,對(duì)這 1000 次評(píng)估取第5百分位數(shù)作為最終的評(píng)估結(jié)果(排除 cache miss 和 任務(wù)搶占 等因素)。

    耗時(shí)采用的是 CPU_CLK_UNHALTED.CORE這個(gè)計(jì)數(shù)器結(jié)果, 其計(jì)數(shù)值表示在一個(gè)特定時(shí)間段內(nèi),處理器內(nèi)核的時(shí)鐘周期數(shù)。這個(gè)值越高,意味著處理器內(nèi)核在該時(shí)間段內(nèi)執(zhí)行了更多的指令。

    AlphaDev 發(fā)現(xiàn)新的算法

    對(duì)于定長(zhǎng)序列排序,當(dāng)應(yīng)用到排序網(wǎng)絡(luò)算法[6](sorting network algorithm)的時(shí)候 AlphaDev 生成的代碼中包含了一些有趣指令序列,相對(duì)于原始指令序列可以減少一條匯編指令,論文中稱之為:

    AlphaDev swap move

    AlphaDev copy move

    啥是排序網(wǎng)絡(luò)算法?

    排序網(wǎng)絡(luò)算法(Sorting Network Algorithm)是一種能夠?qū)σ唤M輸入數(shù)據(jù)進(jìn)行排序的并行算法,其具有較好的并行性能適用于多處理器或多核心系統(tǒng)。

    該算法的特點(diǎn)是,它將所有的比較和交換操作預(yù)先規(guī)劃好形成一個(gè)固定的結(jié)構(gòu),然后將輸入數(shù)據(jù)按照這個(gè)結(jié)構(gòu)進(jìn)行排序。

    排序網(wǎng)絡(luò)由比較器(comparator)和線(wire)組成,如下圖所示:

    水平線表示 wire,每條水平線持有一個(gè)待排序的值。兩條 wire 之間的垂直線段就表示一個(gè)比較器,比較器對(duì)比兩條水平線的值,如果比較器下方的值小于上方的值則交換兩條橫線的值,否則則不交換。

    一個(gè)優(yōu)化過(guò)的排序網(wǎng)絡(luò)可以以最少的比較器,并將這些比較器放置在特定位置上,來(lái)實(shí)現(xiàn)對(duì)任意序列進(jìn)行排序。

    下圖是對(duì)一個(gè)構(gòu)造好的排序網(wǎng)絡(luò),輸入真實(shí)待排序序列的例子:

    可見(jiàn)初始輸入是 [2, 3, 1, 4],這些隨機(jī)數(shù)從左到右按順序經(jīng)過(guò)這些比較器之后,就得到了排序好的序列 [1, 2, 3, 4]。

    AlphaDev swap move

    先來(lái)看這個(gè)排序網(wǎng)絡(luò),只看紅圈部分的功能就是對(duì)給定的輸入 [A, B, C]將其轉(zhuǎn)換為 [min(A,B,C), max(min(A,C),B), max(A,C)]。

    然后經(jīng)過(guò) AlphaDev 優(yōu)化之后,可以將第一個(gè)輸出的 min(A,B,C)改為只計(jì)算 min(A,B),原因是因?yàn)榍懊娴?B和 C橫線之間經(jīng)過(guò)比較器之后已經(jīng)有了前置條件 B <= C。

    而通過(guò)這個(gè)優(yōu)化就能省去一條匯編指令,下圖是紅圈部分的偽代碼實(shí)現(xiàn):

    左邊是原始偽代碼實(shí)現(xiàn),右邊是經(jīng)過(guò) AlphaDev 優(yōu)化之后的實(shí)現(xiàn),可以看到少了一條匯編指令 mov S P。

    AlphaDev copy move

    接下來(lái)看對(duì)4個(gè)元素進(jìn)行排序的排序網(wǎng)絡(luò),是在對(duì) sort8這個(gè)算法優(yōu)化過(guò)程中發(fā)現(xiàn)的。該排序網(wǎng)絡(luò)對(duì)于輸入序列 [A, B, C, D]轉(zhuǎn)換為 [min(A, B, C, D), max(B, min(A, C, D), max(C, min(A, D)), max(A, D) ]。

    該排序網(wǎng)絡(luò)是 sort8的一個(gè)子排序網(wǎng)絡(luò),而根據(jù)比較器的放置位置來(lái)看,A和 D比較之后后續(xù)就不再和其他元素比較了,所以D出來(lái)的結(jié)果就是四個(gè)元素中最大的,所以隱含了一個(gè)條件就是 D >= min(A, C)。

    因此對(duì)第二個(gè)輸出元素的計(jì)算可以從 max(B, min(A, C, D))改為 max(B, min(A, C)),就可以節(jié)省一條匯編指令。

    偽代碼如下:

    左邊是原始偽代碼實(shí)現(xiàn),右邊是經(jīng)過(guò) AlphaDev 優(yōu)化之后的實(shí)現(xiàn),可以看到少了一條匯編指令 mov P T。

    總結(jié)

    這篇文章只是對(duì) AlphaDev 論文中的主要內(nèi)容作解讀,對(duì)于更多的內(nèi)容和細(xì)節(jié)感興趣的讀者可以查閱原論文和論文的補(bǔ)充資料 [2,3],DeepMind 也也開(kāi)源了一份偽代碼實(shí)現(xiàn) [7]。

    參考資料

    [1] https://ee.usc.edu/~redekopp/cs356/slides/CS356Unit5_x86_Control

    [2] https://www.nature.com/articles/s41586-023-06004-9#MOESM1

    [3] https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-023-06004-9/MediaObjects/41586_2023_6004_MOESM1_ESM.pdf

    [4] ? D118029 Introduce branchless sorting functions for sort3, sort4 and sort5. (llvm.org)

    [5] 小信豬的原始部落: PC Assembly Language 學(xué)習(xí)筆記(5) - Control Structures (godleon.blogspot.com)

    [6] https://en.wikipedia.org/wiki/Sorting_network#:~:text=as%20the%20contrapositive.-,Constructing%20sorting%20networks,are%20often%20used%20in%20practice.

    [7] https://github.com/deepmind/alphadev

    標(biāo)簽:

    相關(guān)熱詞搜索:

    相關(guān)閱讀

    最近更新

    国产精品无打码在线播放9久,91高清在线视频,极品主播的慰在线播放,国产在线播放不卡
    <dl id="2ki44"><tbody id="2ki44"></tbody></dl>
  • <dfn id="2ki44"><pre id="2ki44"></pre></dfn>
  • <pre id="2ki44"><cite id="2ki44"></cite></pre>
  • <pre id="2ki44"></pre>
  • 主站蜘蛛池模板: 亚洲色欲久久久综合网| 日本黄色影院在线观看| 亚洲精品无码久久久久秋霞| 久久精品国产99国产精品亚洲| 成人短视频完整版在线播放| 国产熟女一区二区三区五月婷 | 毛片亚洲AV无码精品国产午夜| 岛国免费在线观看| 午夜精品久久久久久中宇| 亚洲AV无码一区二区一二区| 18成禁人视频免费网站| 欧美国产精品久久| 国产精品久关键词| 亚洲一区二区影院| 国产1000部成人免费视频| 欧美野外疯狂做受xxxx高潮| 国内揄拍国内精品| 亚洲处破女AV日韩精品| free性欧美另类高清| 最近中文字幕在线中文视频 | 无翼乌日本漫画| 国产精品久久久久aaaa| 亚洲av午夜成人片| 高校饥渴男女教室野战| 无码日韩精品一区二区三区免费| 叶山豪是真吃蓝燕奶| xxxx日本在线| 欧美黑人粗硬大在线看| 国产精品久久久久久搜索| 久久精品国产99久久久古代| 草莓视频秋葵视频在线观看ios| 成人毛片免费网站| 亚洲色图古典武侠| 无遮挡1000部拍拍拍免费凤凰| 欧美精品blacked中文字幕| 国产精品久久久小说| 久久精品人人槡人妻人人玩AV| 草草久久久无码国产专区 | 国产在线jyzzjyzz免费麻豆| 中日韩精品视频在线观看| 男生女生差差差很痛|