Trust Region Policy Optimization

Abstract

  • policy optimization에서 monotonic improvement를 하기 위한 반복 프로시저.
  • 이론적으로 성립하는 절차에 여러 근사를 가정함.
  • Trust Region Policy Optimization(TRPO) 이라 하는 practical한 알고리즘을 개발함.
  • 이 알고리즘은 Natural Policy Gradient 메소드와 유사하고, 뉴럴넷과 같은 대규모 nonlinear policy의 최적화에 효과적임.
  • 우리의 실험은 다양한 종류의 태스크에서 알고리즘의 강건한 성능을 보여줌.
  • 근사치들이 이론에서 벗어나긴 해도, TRPO는 크게 하이퍼파라미터 튜닝 없이도 monotonic importvement를 보이는 경향이 있음.

Introduction

Background

flowchart TD
	subgraph GB1[Policy Iteration]
	GB1A[ADP]
	GB1B[CPI]
	end
	
	subgraph GB2[Policy Gradient]
	direction TB
	GB2A[REINFORCE]
	GB2B[NPG]
	GB2C[Actor-Critic]
	GB2A ~~~ GB2B
	GB2A ~~~ GB2C
	end
	
	subgraph DFO[Derivative-Free]
	DFO1[CEM]
	DFO2[CMA]
	end
	
	GB[Gradient-Based]
	GB --- GB1
	GB --- GB2
	
	A["Policy Optimization<br/>Algorithm"]
	A === GB
	A === DFO
  • Derivative-Free Stochastic Optimization, 예를 들어 CEM(cross-entropy)이나 CMA(covariance matrix)가 많이 쓰였음. 구현과 이해가 단순한 것에 비해 좋은 성능을 내기 때문에.
    • Tetris는 ADP(Approximate Dynamic Programming)의 고전적 벤치마크이지만, stochastic optimization을 이기기 어려움.
    • 연속 제어 문제에서는 CMA가 좋았음.
      • Continuous Control Problem: 연속행동공간, 연속상태공간을 다루는 문제.
        • e.g. 로봇 관절 제어
  • Gradient 기반 최적화 알고리즘이 그래디언트 없는 방법보다 훨씬 더 나은 샘플 복잡성 보장을 제공함에도, ADP와 Gradient 기반 방법이 Gradient-free 랜덤 검색을 일관되게 능가하지 못함.
    • Is ADP one of the gradient-based methods? ─ 아님. ADP는 Policy Iteration 계열인데, 얘도 Gradient를 씀. 어찌 됐건 Gradient를 쓰는 방법이 안 쓰는 방법을 못 이기는 게 문제 의식.
    • Why does sample complexity matter? ─ Sample Complexity란, 일정 수준에 도달하기 위해 필요한 샘플링(에이전트-환경 상호작용) 수. RL에서는 샘플링 수가 곧 비용. Gradient 기반 방법론은 이론적으로 더 적은 샘플링으로 정책 향상이 가능해야 함.
      • Why “complexity”? ─ 계산이론 용어인 듯. 자세한 건 모르겠네…
  • continuous gradient 기반 optimization은 대규모 모델의 지도 학습 과제에서 function approximation 태스크에 성공적.
    • 이것을 강화학습으로 확장하면 복잡하고 강력한 정책의 효율적 학습이 가능해질 것.
    • Insight Sources ─ 지금까진 지도학습에서 Gradient 기반 최적화가 수M~수B 파라미터의 대규모 비선형 함수 근사기(=뉴럴넷) 를 성공적으로 학습시킴. 그런데 RL은 대규모 비선형 정책 최적화가 어려워서, 저차원의 정책과 gradient-free 방법에 의존해왔음. SL의 성공을 RL로 adaptation한다면, 뉴럴넷 정책도 효율적으로 학습이 가능할 것.
      • Why is it so difficult to optimize large-scale policy in RL? ─ (1) RL의 목적 함수가 Non-stationary, (2) Policy가 데이터 분포 자체를 바꾸기 때문에 SGD를 적용하면 학습이 매우 불안정함. 최적화해야 하는 함수가 최적화 스텝마다 바뀌는 느낌. \(\rightarrow\) 핵심 Gap
        • delayed reward: credit assignment 문제의 원인
        • high variance: \(\sum_l \gamma^l r\) (MC return)의 추정이 trajectory 마다 크게 달라서 추정량 분산이 높음
        • catastrophic policy degradation: 스텝사이즈가 조금만 커도 붕괴함. bad policy \(\to\) bad data \(\to\) bad policy \(\to \cdots\) 악순환 발생

Preliminary

Notations

Symbol Definition Name 의미
\(\mathcal{S}\) \(\mathcal{S} = \{s: \text{state}\}\) Finite set of states 유한상태공간
\(\mathcal{A}\) \(\mathcal{A} = \{a: \text{action}\}\) Finite set of actions 유한행동공간
\(P(s' \mid s, a)\) \(P: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}\) Transition probability distribution 상태 \(s\) 에서 행동 \(a\) 후 상태 \(s'\) 로 전이 확률
\(r(s)\) \(r: \mathcal{S} \to \mathbb{R}\) Reward function 상태 \(s\) 의 보상
\(\rho_{0}(s)\) \(\rho_0: \mathcal{S} \to \mathbb{R}\) The distribution of the initial state 초기상태 \(s_0\) 의 분포
\(\gamma\) \(\gamma \in (0, 1)\) Discount factor 미래 보상 할인 인자
\(\pi(a \mid s)\) \(\pi: \mathcal{S} \times \mathcal{A} \to [0,1]\) Stochastic policy 상태 \(s\) 에서 행동 \(a\) 를 선택할 확률
\(\tilde{\cdot}\) target, sample, revised, \(\cdots\) - 문맥에 따름
\(\eta(\pi)\) \(\mathbb{E}_{s_0, a_0, \cdots}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t)\right]\) Expected discounted reward of stochastic policy \(\pi\) 정책 \(\pi\) 의 기대 할인 보상
\(Q_\pi(s_t,a_t)\) \(\mathbb{E}_{s_{t+1}, a_{t+1}, \cdots}\left[\sum_{l=0}^{\infty} \gamma^l r(s_{t+l})\right]\) State-action value function 상태 \(s_t\) 에서 특정 행동 \(a_t\) 에 대한 기대 보상
\(V_\pi(s_t)\) \(\mathbb{E}_{a_t, s_{t+1}, \cdots}\left[\sum_{l=0}^{\infty} \gamma^l r(s_{t+l})\right]\) \(=\) \(\sum_{a_t} \pi(a_t \mid s_t) Q_{\pi} (s_t, a_t)\) Value function 상태 \(s_t\) 에서 모든 행동에 대한 기대 보상 가중 평균
\(A_\pi(s,a)\) \(Q_\pi(s,a) - V_\pi(s)\) Advantage function 행동 \(a\) 가 평균 대비 얼마나 좋은가
\(\rho_\pi(s)\) \(\sum_{t=0}^{\infty} \gamma^t P(s_t = s)\) Discounted visitation frequency 할인이 적용된 상태 방문 빈도
\(\mathbb{D}_{TV}(p \;|\; q)\) \(\frac{1}{2} \sum_i \lvert{p_i - q_i}\rvert\) Total variation divergence 사건 확률 거리. 아예 안 겹치면 1
\(\mathbb{D}^{max }_{TV}(\pi, \tilde{\pi})\) \(\underset{s}{max}\; \mathbb{D}_{TV}(\pi(\cdot \mid s) \;|\; \tilde{\pi}(\cdot \mid s))\) Max total variation 모든 상태 중 최악(max)의 TV divergence
\(\mathbb{D}_{KL}(p \;|\; q)\) \(\sum_i p_i \log{\frac{p_i}{q_i}}\) KL divergence 분포 \(q\) 로 \(p\) 를 인코딩할 때 정보 손실
\(\mathbb{D}^{max }_{KL}(\pi, \tilde{\pi})\) \(\underset{s}{max}\; \mathbb{D}_{KL}(\pi(\cdot \mid s) \;|\; \tilde{\pi}(\cdot \mid s))\) Max KL divergence 모든 상태 중 최악(max)의 KL divergence
\(\bar{\mathbb{D}}^{\rho}_{KL}(\pi, \tilde{\pi})\) \(\mathbb{E}_{s \sim \rho}\left[\mathbb{D}_{KL}\bigl(\pi(\cdot \mid s) \;|\; \tilde{\pi}(\cdot\mid s)\bigr)\right]\) Average KL divergence 상태 분포 \(\rho\) 하의 평균 KL divergence
\(F_{ij}\) \(\mathbb{E}_{s \sim \rho_{\theta_{odd}}} \left[{\frac{{\partial}^2}{\partial\theta_i\partial\theta_j}} \mathbb{D}_{KL}\left(\pi_{\theta_{old}}(\cdot \mid s) \;|\; \pi_\theta(\cdot \mid s)\right)\right]\) Fisher information matrix (Hessian of KL) 해석적 FIM

Kullback-Leibler Divergence

  • Entropy \(H(p) = H(p, p) = -\sum_i p_i \log{p_i}\)
  • Cross-entropy \(H(p, q) = -\sum_i p_i \log{q_i}\)
  • KL divergence \(\mathbb{D}_{KL}(p\;\|\;q) = H(p, q) - H(p) = \sum_i p_i \log \frac{p_i}{q_i}\)
    • Intuition ─ 분포 \(q\) 로 \(p\) 를 encoding하려고 할 때 생기는 정보 손실. 이 손실이 클 수록, \(p\) 에서 나타나는 trajectory가 \(q\) 에서 발생하기 어려움. \(\to\) asymmetry. KL divergence가 distance가 아닌 divergence인 이유.
    • Observation. TRPO vs. RLHF
      • TRPO: \(\mathbb{D}_{KL} (\pi_{old} \;\|\; \pi_{new})\)
      • RLHF: \(\mathbb{D}_{KL} (\pi_{\theta} \;\|\; \pi_{ref})\)
      • Why? ─ TRPO의 경우, old policy가 높은 확률을 메기는 곳에서 new policy가 낮은 확률을 줄 때 penalty를 가하기 위함. 반면, RLHF는 reward hacking을 막아야 함. 기존 모델이 거의 생성하지 않는 토큰을 생성하는 경우(mode-seeking)를 방지하기 위함.
  • Properties of KL divergence
    • \(\mathbb{D}_{KL} (p \;\|\; q) \geq 0.\) 깁스 부등식에 의해.
    • \[\mathbb{D}_{KL} (p \;\|\; q) = 0 \implies p=q.\]
  • Relationship to TV divergence
    • \[\mathbb{D}^2_{TV} (p \;\|\; q) \leq \mathbb{D}_{KL} (p \;\|\; q)\]

Conservative Policy Iteration(CPI)

  • Background. traditional policy iteration은 정확한 가치 함수와 정책을 사용할 때만 수렴 보장이 가능. 보통 실용적으로는 정책을 근사해야 함. 스텝 사이즈를 제한하는 보수적인(conservative) PI 를 제안
  • current policy \(\pi_{old}\)
  • greedy policy \(\pi'\)
  • mixture policy \(\pi_{new}\) \(\pi_{new} = (1-\alpha)\pi_{old} + \alpha\pi'\)
  • Kakade-Langford bound \(\eta(\pi_{new}) \geq L_{\pi_{old}} (\pi_{new}) - \frac{2\epsilon\gamma}{(1-\gamma)^2} {\alpha}^2\) \(\quad\quad\quad\quad\quad\quad \text{where} \quad \epsilon = \underset{s}{max} \lvert{\mathbb{E}_{a \sim \pi'(a \mid s)}\left[A_\pi(s, a)\right]}\rvert\)
  • Key Result.
    • step size 에 대한 최초의 explicit lower bound
  • Limitations.
    • mixture policy 에만 적용.
    • \(\to\) TRPO의 Theorem 1 에서 일반적인 policy 로 확장.

Natural Policy Gradient (NPG)

  • The optimization problem \(\underset{\Delta\theta}{max} \;g^{T} \Delta\theta \quad \text{s.t.} \quad \lvert\lvert{\Delta\theta}\rvert\rvert^2 \leq \epsilon\)
  • Vanilla PG: \(\theta + \alpha g\)
  • The optimization problem \(\underset{\Delta\theta}{max} \;g^{T} \Delta\theta \quad \text{s.t.} \quad \Delta\theta^{T} F \Delta\theta \leq \epsilon\)
  • Natural PG: \(\theta + \alpha F^{-1}g\)
  • Key Result.
    • Policy Optimization 에 최초로 Fisher Information Matrix \(F\) 를 사용
    • 왜 FIM 인가?
      • KL divergence 의 2차 approximation 으로 동작. (분포 간 거리 제약)
  • Limitations.
    • No principled way to choose \(\alpha\)
    • \(\rightarrow\) TRPO의 Trust Region 이 해결

Overview

Surrogate objective function

\(L_\pi (\tilde{\pi}) = \eta (\pi) + \sum_{s} {\rho_{\pi} (s) \sum_{a} {\tilde{\pi} (a \vert s) A^{\pi}(s, a)}}\)

  • 대리 목적 함수를 최적화하면, 비자명한 스텝 사이즈로 정책 개선을 보장함. (to be proved)
    • 왜 surrogate? max 문제를 min 문제로 변환하는 느낌인가?
      • No. 원래의 목적 함수 자체가 closed-form 최적화 불가능함.

Expected return(\(\eta\)) of policy \(\tilde{\pi}\)

\(\eta(\tilde{\pi}) = \eta(\pi) + \sum_{s} {\rho_{\tilde{\pi}} (s) \sum_{a} {\tilde{\pi} (a \vert s) A^{\pi}(s, a)}}\)

  • 이게 원래의 목적 함수. \(\eta(\tilde{\pi})\) 를 계산하려면 \(\rho_{\tilde{\pi}}(s)\) 가 필요한데, \(\rho_{\tilde{\pi}}(s)\) 가 \(\tilde{\pi}\) 에 복잡하게 의존해버림(intractable). 그래서 \(\tilde{\pi}\) 를 \(\pi\) 로 대체한 surrogate objective function 을 생각.
    • \(\eta(\tilde{\pi})\) 와 \(L_{\pi} (\tilde{\pi})\) 는 \(\theta_0\) 근방에서 1차항까지 일치함.
      1. \[L_{\pi_{\theta_0}} (\pi_{\theta_0}) = \eta (\pi_{\theta_0})\]
      2. \[\nabla_\theta L_{\pi_{\theta_0}} (\pi_\theta) \vert_{\theta = \theta_0} = \nabla_\theta \eta(\pi_\theta) \vert_{\theta = \theta_0}\]
    • 충분히 작은 step size 를 가정하면, \(L_{\pi} (\tilde{\pi})\) 를 개선하는 것이 곧 \(\eta(\tilde{\pi})\) 의 개선이라고 볼 수 있음.
    • 이때, step size 의 범위를 Trust Region 이라고 함.

The Theoretically Justified Algorithm

  • 이론적으로 monotonic improvement 가 증명된 알고리즘을 소개함.
    1. 현재의 Policy \(\pi_{i}\) 에 대한 어드밴티지 함수 \(A_{\pi_{i}}\) 를 정확하게 evaluate
    2. surrogate function \(M_{i} (\pi) = L_{\pi_{i}} (\pi) - C \cdot \mathbb{D}^{max}_{KL} (\pi_{i}, \pi)\) 을 구성
    3. \(M_{i} (\pi)\) 를 최대화하는 \(\pi_{i+1}\) 을 계산
    4. Repeat. 단조증가 보장 \(\Rightarrow\) \(\eta(\pi_0) \leq \eta(\pi_1) \leq \cdots\)
  • 어떻게 나왔는가?
Minorization-Maximization (MM) algorithm
  • Background. 고전적인 optimization methods 는 high dimensional setting에 부적합(매우 큰 matrix를 다뤄야 함).
    • Other algorithms dealing with high dimensional setting? ─ 왜 하필 MM이었을까?
  • expectation을 최대화할 수 있는 방법
  • objective function \(f = \eta\)
  • surrogate function \(g = M_i\)
  • conditions
    • \[M_i(\pi_i) = \eta(\pi_i)\]
    • \(M_i(\pi) \leq \eta(\pi)\) for \(\forall \pi\)
  • MM Algorithm.
  • Key Result. 단조증가를 보장함.

Approximations

  • 위 알고리즘에 일련의 근사를 적용.
1. max KL \(\approx\) avg KL

\(\mathbb{D}^{max}_{KL} (\theta_{old}, \theta) \approx \bar{\mathbb{D}}^{\rho_{\theta_{old}}}_{KL}\) \(\max_s D_{KL}\bigl(\pi(\cdot\vert s) \| \tilde{\pi}(\cdot\vert s)\bigr) \approx \mathbb{E}_{s \sim \rho}\left[D_{KL}\bigl(\pi(\cdot\vert s) \| \tilde{\pi}(\cdot\vert s)\bigr)\right]\)

  • Why? ─ 모든 state에서 KL divergence 를 계산하고 제약해야 함. 평균 KL divergence 는 샘플로 쉽게 추정이 가능. 실험적으로 max KL 제약과 비슷한 성능.
2. KL penalty \(\to\) KL constraint

\(\underset{\theta}{\max}\left[L_{\theta_{old}}(\theta) - C \cdot D^{max}_{KL}(\theta_{old}, \theta)\right] \;\;\longrightarrow\;\; \underset{\theta}{\max}\; L_{\theta_{old}}(\theta) \;\;\text{subject to}\;\; \bar{D}^{\rho_{\theta_{old}}}_{KL}(\theta_{old}, \theta) \leq \delta\)

  • Why? ─ 이론에서 얻은 페널티 계수 \(C = \frac{4\epsilon\gamma}{(1-\gamma)^2}\) 은 매우 큼. 경험적으로도 \(C\) 를 robust하게 선택하기 어려움. 대신, \(KL \leq \delta\) 라는 hard constraint 를 사용하면 하이퍼파라미터 \(\delta\) 만 선택하면 되고, 더 큰 업데이트를 강건하게 수행할 수 있음.
    • Why is \(C\) so big? ─ Thm.1 증명에 쓰인 보수적인 가정(worst-case bound) 때문. discount factor \(\gamma \to 1\) 이면 발산.
    • Why is hard constraint robust? ─ 페널티 방식은 \(L(\theta) - C \cdot \mathbb{D}_{KL}\) 을 maximize, \(C\) 값에 민감함(안 움직이거나 catastrophic update). 반면, 제약 방식은 \(\mathbb{D}_{KL} \leq \delta\) 가 보장되므로 이 영역 안에서 \(L(\theta)\) 를 maxiimize하면 됨. Catastrophic update를 방지, \(\delta\) 값 자체에 덜 민감함.
    • Is this conversion heuristic method? ─ 완전한 휴리스틱은 아님. Lagrangian duality 로, 적절한 \(\delta\) 에 대응하는 \(C\) 가 존재하여 같은 해를 찾을 수 있음. (TODO: 직접 해보기)
3. \(A_{\pi}\) estimation via Monte-Carlo sampling
  • Objective: \(\sum_a \pi_{\theta} (a \mid s) A_{\pi_{old}}(s, a)\) \(\sum_a \pi_{\theta} (a \mid s) A_{\pi_{old}}(s, a) = \sum_a \pi_{\theta} (a \mid s) Q_{\pi_{old}}(s, a) - \cancelto{1}{\sum_a \pi_{\theta}(a \mid s)} V_{\pi_{old}}(s)\)
  • Surrogate: \(\sum_a \pi_{\theta} (a \mid s) Q_{\pi_{old}}(s, a)\)
  • 어드밴티지 대신 Q-value를 최대화, \(Q_{\pi_{old}}\) 를 몬테카를로 샘플링으로 추정.
    • Why? ─ 모든 state-action pairs에 대해 어드밴티지를 계산할 수 없으므로, trajectory 샘플로부터 Q-value 를 추정함.
Single-path

\(\hat{Q}(s_t, a_t) = \sum_{l=0}^{T-t} {\gamma}^{l} r(s_{t+l})\)

  • 하나의 trajectory 를 시뮬레이션
  • \((s, a)\) pairs 의 미래 할인 보상 합으로 Q-value 추정
  • model-free 설정에 적용 가능
Vine

\(\hat{Q}(s_n, a_{n, k}) = r(s_n) + \gamma r(s_1') + {\gamma}^2 r(s_2') + \cdots\)

  • 상태 \(s_n\) 에서 \(k\) 개의 행동(\(a_{n, k}\)) 에 대해 짧은 Roll-out
  • same random number sequence 사용
    • common random numbers(CRN)
    • Q-value differences 의 variance 를 줄임
      • 후속?연구. GAE (Schulman et al., 2016)
  • 시스템을 다시 특정 지점으로 복원해야 하므로 보통 시뮬레이션에서만 가능
  • finite action space 와 continuous action space 의 동작이 다르네? 이것도 정리해야 함

Trust Region Policy Optimization

  • 여러 근사를 적용한 알고리즘이 바로 TRPO임.
    1. 몬테카를로 샘플링으로 state-action pairs 에 대한 Q-value를 추정
    2. Objective function, KL constraint 추정.
      • \[\underset{\theta}{\max}\;\mathbb{E}_{s \sim \rho_{\theta_{old}},\, a \sim q}\!\left[\frac{\pi_\theta(a|s)}{q(a|s)}\, Q_{\theta_{old}}(s,a)\right]\]
      • subject to: $$\mathbb{E}{s \sim \rho{\theta_{old}}}!\left[D_{KL}!\left(\pi_{\theta_{old}}(\cdot s) | \pi_\theta(\cdot s)\right)\right] \leq \delta$$
    3. Conjugate gradient, line search 로 제약 최적화 문제를 근사하여 \(\theta\) 업데이트.
      • KL constraint 2차 근사
      • objective 1차 근사
      • natural gradient \(F^{-1}g\) \(\to\) conjugate gradient \(Fx=g\) 를 iterative하게(\(k \approx 10\)) 근사
      • backtracking line search 로 KL constraint 하에서 objective improvement 가 가능한 step size 결정
  • 이 알고리즘은 scalable함.
    • model-free policy 탐색의 큰 도전 과제였던, 수만 개의 파라미터를 가진 비선형 정책을 최적화할 수 있음.
      • model-free 는 전이확률분포 \(P(s' \mid s, a)\) 를 학습하지 않는 것. 정책의 파라미터 유무와는 상관 없음
  • 여러 실험을 수행함.
    • 수영/점프/보행
    • Atari 게임 플레이
    • single path vs. vine
    • TRPO vs. Natural Policy Gradient
      • Final performance
      • Computing time
      • Sample complexity

Full Derivation

Objective

  • Expected discounted reward of stochastic policy \(\pi\) (policy performance) \(\eta(\pi) = \mathbb{E}_{s_0, a_0, \cdots} \left[{\sum_{t=0}^{\infty} {\gamma}^t r (s_t)}\right]\)
  • An identity between policies \(\pi\), \(\tilde{\pi}\) \(\eta(\tilde{\pi}) = \eta(\pi) + \mathbb{E}_{s_0, a_0, \cdots \sim \tilde{\pi}} \left[{\sum_{t=0}^{\infty} {\gamma}^t A_{\pi} (s_t, a_t)}\right] \tag{1}\) \(= \eta(\pi) + \sum_{t=0}^{\infty} \sum_{s} P(s_t = s \mid \tilde{\pi}) \sum_{a} \tilde{\pi}(a \mid s)\; {\gamma}^t A_\pi (s, a)\)
  • Introduce (unnormalized) discounted visitation frequencies \(\rho_\tilde{\pi}\) \(\rho_{\pi}(s) = \sum_{t=0}^{\infty} {\gamma}^t P(s_t = s \mid \pi)\) \(\eta(\tilde{\pi}) = \eta(\pi) + \sum_s \rho_{\tilde{\pi}}(s) \sum_a \tilde{\pi}(a \mid s) A_{\pi} (s, a) \tag{2}\)
  • The condition required to improve policy performance \(\sum_s \rho_{\tilde{\pi}}(s) \sum_a \tilde{\pi}(a \mid s) A_{\pi} (s, a) \geq 0\)

Surrogate

  • The local approximation using the visitation frequency \(\rho_{\pi}\) instead of \(\rho_{\tilde{\pi}}\) \(L_{\pi}(\tilde{\pi}) = \eta(\pi) + \sum_{s} \rho_{\pi} (s) \sum_a \tilde{\pi}(a \mid s) A_{\pi} (s, a)\)
  • \(L_{\pi}(\tilde{\pi})\) and \(\eta(\tilde{\pi})\) agree up to first order at \(\pi_{\theta_0}\) for any parameter value \(\theta_0\) \(\Delta_{\eta_\mu} = \frac{\alpha}{1-\gamma} \mathbb{A}_{\pi, \mu} (\pi') + O({\alpha}^2) \tag{Kakade-Langford}\)
  • Let \(D(\theta) = L_{\pi}(\pi_{\theta}) - \eta(\pi_{\theta})\), \(\begin{aligned} D(\theta) &= \sum_{s} \underbrace{\left( \rho_{\pi} - \rho_{\pi_{\theta}} \right) (s)}_{F(\theta, s)} \overbrace{\sum_a \pi_{\theta}(a \mid s) A_{\pi} (s, a)}^{G(\theta, s)} \end{aligned}\)
  • if \(\theta = \theta_0\), then \(\pi_{\theta} = \pi\) \(\begin{aligned} F(\theta_0, s) &= \left( \rho_{\pi} - \rho_{\pi} \right) (s) \\ &= 0 \\ \\ G(\theta_0, s) &= \sum_a \pi(a \mid s) A_{\pi} (s, a) \\ &= \sum_a \pi(a \mid s) \left(Q_{\pi}(s, a) - V_{\pi}(s)\right) \\ &= 0 \end{aligned}\) \(D(\theta) \big|_{\theta = \theta_0} = \sum_{s} F(\theta_0, s)\;G(\theta_0, s) = 0\) \(\nabla_{\theta} D(\theta) \big|_{\theta = \theta_0} = \sum_{s} \left[ F(\theta, s) \nabla_{\theta} G(\theta, s) + G(\theta, s) \nabla_{\theta} F(\theta, s) \right] \bigg|_{\theta = \theta_0} = 0\)
  • This second-order term will be constrained by KL divergence. \(\begin{aligned} \nabla_{\theta}^2 D(\theta) \big|_{\theta = \theta_0} &= \sum_{s} \left[ F(\theta, s) \nabla_{\theta}^2 G(\theta, s) + 2 \nabla_{\theta} G(\theta, s) \cdot \nabla_{\theta} F^{T}(\theta, s) + G(\theta, s) \nabla_{\theta}^2 F(\theta, s) \right] \bigg|_{\theta = \theta_0} \\ &= 2 \sum_{s} \nabla_{\theta} G(\theta_0, s) \cdot \nabla_{\theta} F^{T}(\theta_0, s) \end{aligned}\)

Monotonic Improvement Guarantee for General Stochastic Policies

Definition (\(\alpha\)-coupling).

A policy pair \((\pi, \tilde{\pi})\) is \(\alpha\)-coupled if there exists a joint distribution \((a, \tilde{a}) \mid s\) with marginals \(\pi(\cdot \mid s)\), \(\tilde{\pi}(\cdot \mid s)\) such that \(P(a \ne \tilde{a} \mid s) \leq \alpha\)

Proposition.

(Levin et al.) If \(\mathbb{D}_{TV}^{max} (\pi, \tilde{\pi}) \leq \alpha\), then an \(\alpha\)-coupling exists.

Theorem 1. Policy Improvement Bound

Let \(\alpha = \mathbb{D}_{TV}^{max} (\pi_{old}, \pi_{new})\). Then the following bound holds: \(\eta({\pi_{new}}) \geq L_{\pi_{old}} (\pi_{new}) - \frac{4\epsilon\gamma}{(1-\gamma)^2} {\alpha}^2\) \(where \;\epsilon = \underset{s, a}{max} \lvert{A_\pi (s, a)}\rvert\) Proof. Taking expectation over trajectories \(\tau := (s_0, a_0, s_1, a_1, \cdots)\), \(\eta(\tilde{\pi}) = \eta(\pi) + \mathbb{E}_{\tau \sim \tilde{\pi}} \left[{\sum_{t=0}^{\infty} {\gamma}^t A_{\pi} (s_t, a_t)}\right]\) Define the expected advantage of \(\tilde{\pi}\) over \(\pi\) at state \(s\): \(\bar{A}(s) = \mathbb{E}_{a \sim \tilde{\pi}(\cdot \mid s)} \left[{A_{\pi}(s, a)}\right]\) Note that \(L_{\pi}\) can be written as \(L_{\pi}(\tilde{\pi}) = \eta(\pi) + \mathbb{E}_{\tau \sim \pi} \left[{\sum_{t=0}^{\infty} {\gamma}^t \bar{A} (s_t)}\right]\) The difference between \(\eta(\tilde{\pi})\) and \(L_\pi\) : \(\eta(\tilde{\pi}) - L_{\pi} (\tilde{\pi}) = \mathbb{E}_{\tau \sim \tilde{\pi}} \left[{\sum_{t=0}^{\infty} {\gamma}^t A_{\pi} (s_t, a_t)}\right] - \mathbb{E}_{\tau \sim \pi} \left[{\sum_{t=0}^{\infty} {\gamma}^t \bar{A} (s_t)}\right] \tag{*}\) To ensure monotonic improvement, the absolute value of \((*)\) should be bounded.

Lemma 1. \(\lvert \bar{A}(s) \rvert \leq 2\alpha\epsilon\)

\(\bar{A}(s) = \mathbb{E}_{\tilde{\pi}} \left[A\pi(s, a)\right] - \mathbb{E}_{\pi} \left[A\pi(s, a)\right]\) (TODO)

Lemma 2. \(\mathbb{D}_{TV} (P^{\tilde{\pi}}_t, P^{\pi}_t) \leq t\alpha\)

Construct paired process \((s_t, \tilde{s}_t)\) sharing transition randomness under \(\alpha\)-coupling (TODO)

Bound timestep \(t\) term of \((*)\) \(\begin{aligned} \biggl| \mathbb{E}P^{\tilde{\pi}}_t \left[\bar{A}(s)\right] - \mathbb{E}P^{\pi}_t \left[\bar{A}(s)\right] \biggr| &= \biggl| \sum_s \left({P^{\tilde{\pi}}_t(s) - P^{\pi}_t}(s)\right) \bar{A}(s) \biggr| \\ &\leq \underset{s}{max} \left|\bar{A}(s)\right| \;\cdot\; \sum_s \left|{P^{\tilde{\pi}}_t(s) - P^{\pi}_t}(s)\right| \\ &= \underset{s}{max} \left|\bar{A}(s)\right| \;\cdot\; 2 \mathbb{D}_{TV} (P^{\tilde{\pi}}_t, P^{\pi}_t) \\ &\leq (2\alpha\epsilon)\;\cdot\;(2t\alpha) = 4\epsilon\alpha^2t \end{aligned}\) Total timesteps: \(\begin{aligned} \lvert \eta(\tilde{\pi}) - L_{\pi} (\tilde{\pi}) \rvert &= \sum_{t=0}^{\infty} \biggl| \mathbb{E}_{s \sim P^{\tilde{\pi}}_t} \left[\bar{A}(s)\right] - \mathbb{E}_{s \sim P^{\pi}_t} \left[\bar{A}(s)\right] \biggr| \\ &\leq \sum_{t=0}^{\infty} \gamma^t \cdot 4\epsilon\alpha^2t = 4\epsilon\alpha^2 \cdot \frac{\gamma}{(1-\gamma)^2} \end{aligned}\) Therefore, \(\eta(\tilde{\pi}) \geq L_{\pi} (\tilde{\pi}) - \frac{4\epsilon\gamma}{(1-\gamma)^2} \alpha^2 \tag{**}\)

Minorization-Maximization

Definition (Minorizer). A function \(M_i(\pi_i)\) called minorizer such that \(M_i(\pi_i) = \eta(\pi_i)\) \(M_i(\pi) \leq \eta(\pi) \quad \text{for } \forall\pi\) Rearranging \((**)\), \(\begin{aligned} \eta(\tilde{\pi}) \geq \underbrace{L_{\pi} (\tilde{\pi}) - C\cdot\mathbb{D}_{KL}^{max}(\pi, \tilde{\pi})}_{M_i(\tilde{\pi})} \end{aligned}\) MM Algorithm

  1. construct a minorizer \(M_i(\pi)\) under the current policy \(\pi_i\).
  2. \[\pi_{i+1} = \underset{\pi}{argmax} M_i(\pi)\]
  3. repeat. \(\eta(\pi_{i+1}) \geq M_i(\pi_{i+1}) \geq M_i(\pi_i) = \eta(\pi_i)\)

Optimization of Parameterized Policies

  • 지금까지의 결과를 parameterized policy \(\pi_\theta\) 에 적용.
  • 이론적 알고리즘(MM)의 trust region 문제: \(\underset{\theta}{\max}\; L_{\theta_{old}}(\theta) \quad \text{subject to} \quad \bar{\mathbb{D}}^{\rho_{\theta_{old}}}_{KL}(\theta_{old}, \theta) \leq \delta\)
  • 이 제약 최적화 문제를 직접 풀기엔 계산량이 너무 많음. 근사가 필요.

KL Divergence의 2차 근사

  • KL divergence를 \(\theta_{old}\) 근방에서 테일러 전개. \(\bar{\mathbb{D}}^{\rho_{\theta_{old}}}_{KL}(\theta_{old}, \theta) \approx \frac{1}{2} (\theta - \theta_{old})^T F (\theta - \theta_{old})\) \(F_{ij} = \mathbb{E}_{s \sim \rho_{\theta_{old}}} \left[\frac{\partial^2}{\partial\theta_i \partial\theta_j} \mathbb{D}_{KL}\left(\pi_{\theta_{old}}(\cdot \mid s) \;\|\; \pi_\theta(\cdot \mid s)\right)\right] \Bigg|_{\theta = \theta_{old}}\)
  • \(F\) 는 Fisher Information Matrix(FIM).     - KL divergence의 Hessian이 곧 FIM이 되는 이유:         - \(\theta = \theta_{old}\) 에서 KL divergence는 0이고, gradient도 0.         - 1차항까지 소멸, 2차항이 leading term.         - \(\nabla^2_\theta \mathbb{D}_{KL}(\pi_{\theta_{old}} \| \pi_\theta) \big|_{\theta=\theta_{old}} = \mathbb{E}_{\pi_{\theta_{old}}} \left[\nabla_\theta \log \pi_\theta \; \nabla_\theta \log \pi_\theta^T \right] \big|_{\theta=\theta_{old}} = F\)     - Why does the Hessian of KL equal FIM? ─ \(\mathbb{D}_{KL}(p \| q_\theta) = -H(p) - \mathbb{E}_p[\log q_\theta]\). \(\theta = \theta_{old}\) 에서 \(p = q_\theta\) 이므로, Hessian 계산 시 \(\mathbb{E}_p[\nabla \log q \cdot \nabla \log q^T]\) 가 남음. 이것이 FIM의 정의.

Objective의 1차 근사

\(L_{\theta_{old}}(\theta) \approx L_{\theta_{old}}(\theta_{old}) + g^T (\theta - \theta_{old})\) \(g = \nabla_\theta L_{\theta_{old}}(\theta) \big|_{\theta = \theta_{old}}\)

  • \(g\) 는 surrogate objective의 policy gradient

제약 최적화의 해: Natural Gradient

  • 근사된 문제: \(\underset{\theta}{\max}\; g^T (\theta - \theta_{old}) \quad \text{subject to} \quad \frac{1}{2} (\theta - \theta_{old})^T F (\theta - \theta_{old}) \leq \delta\)
  • Lagrangian: \(\mathcal{L}(\theta, \lambda) = g^T (\theta - \theta_{old}) - \lambda \left[\frac{1}{2} (\theta - \theta_{old})^T F (\theta - \theta_{old}) - \delta\right]\)
  • KKT 1차 조건 \(\nabla_\theta \mathcal{L} = 0\): \(g - \lambda F (\theta - \theta_{old}) = 0 \implies \theta - \theta_{old} = \frac{1}{\lambda} F^{-1} g\)
  • 제약 등호 조건 \(\frac{1}{2}(\theta - \theta_{old})^T F (\theta - \theta_{old}) = \delta\) 에서 \(\lambda\) 를 결정: \(\frac{1}{2\lambda^2} g^T F^{-1} g = \delta \implies \lambda = \sqrt{\frac{g^T F^{-1} g}{2\delta}}\)
  • 최종 업데이트: \(\theta = \theta_{old} + \sqrt{\frac{2\delta}{g^T F^{-1} g}}\; F^{-1} g\)
  • \(F^{-1} g\) 를 natural gradient 라고 함.     - 일반적인 gradient \(g\) 는 유클리드 공간에서의 steepest ascent 방향.   - natural gradient \(F^{-1} g\) 는 분포 공간(확률 다양체) 에서의 steepest ascent 방향.     - Why is natural gradient better? ─ 파라미터 공간에서의 작은 변화가 정책 분포에 큰 영향을 줄 수도, 작은 영향을 줄 수ㅇ도 있음. FIM은 파라미터 변화가 분포 변화에 미치는 영향을 측정하므로, \(F^{-1}\) 을 곱하면 분포 공간에서 균일한 스텝을 취하게 됨.     - \(\sqrt{\frac{2\delta}{g^T F^{-1} g}}\) 는 step size. trust region 크기 \(\delta\) 에 의해 결정됨.

Conjugate Gradient (CG) 로 \(F^{-1}g\) 근사

  • 문제: \(F\) 는 $$ \theta \times \theta \(행렬. 뉴럴넷의 파라미터가 수만~수백만 개이면\)F$$ 를 명시적으로 저장하거나 역행렬을 구하는 것이 불가능.
  • 해결: \(Fx = g\) 를 iterative하게 풀어 \(x \approx F^{-1}g\) 를 구함.   - \(F\) 자체를 저장하지 않고, Fisher-vector product \(Fv\) 만 계산할 수 있으면 됨. \(F = \nabla^2_\theta \;\mathbb{D}_{KL} (\theta_{old}, \theta) \big|_{\theta=\theta_{old}}\)
Fisher-Vector Product

\(Fv = \nabla_\theta \left[(\nabla_\theta \mathbb{D}_{KL})^T v\right]\)

  • \(\nabla_\theta \mathbb{D}_{KL}\) 를 계산 (1st backprop)
  • 이 gradient 벡터와 \(v\) 의 내적을 \(\theta\) 에 대해 다시 미분 (2nd backprop)
  • 즉, 두 번의 자동미분으로 \(Fv\) 를 계산. 행렬 \(F\) 를 저장할 필요 없음.     - JAX에서의 구현jax.jvp 또는 jax.grad 를 두 번 적용. 구체적으로:         ```python

        def fvp(params, v, states):

            def kl_fn(p):

                return avg_kl(p, params_old, states)

            g = jax.grad(kl_fn)(params)

            # g^T v 를 다시 미분

            return jax.grad(lambda p: jnp.dot(jax.grad(kl_fn)(p), v))(params)

        ```         또는 더 효율적으로 jax.jvp + jax.vjp 조합 사용.

CG Algorithm
  • \(Fx = g\) 를 풀기 위한 CG:     1. \(x_0 = 0\), \(r_0 = g\), \(p_0 = g\)     2. for \(k = 0, 1, \cdots, K-1\):         - \(\alpha_k = \frac{r_k^T r_k}{p_k^T F p_k}\)         - \(x_{k+1} = x_k + \alpha_k p_k\)         - \(r_{k+1} = r_k - \alpha_k F p_k\)         - \(\beta_{k+1} = \frac{r_{k+1}^T r_{k+1}}{r_k^T r_k}\)         - \(p_{k+1} = r_{k+1} + \beta_{k+1} p_k\)     3. return \(x_K \approx F^{-1} g\)
  • 실용적으로 \(K \approx 10\) 이면 충분. 매 iteration에서 \(Fp_k\) (Fisher-vector product)만 필요.
  • Why CG works here ─ \(F\) 는 positive semi-definite (FIM이므로). CG는 PSD 행렬에 대한 선형계를 효율적으로 풀 수 있고, 행렬 곱 연산만 필요.     - damping: 수치 안정성을 위해 \(F \leftarrow F + \lambda I\) (\(\lambda \approx 0.1\)) 를 더해줌. 이러면 \(F\) 가 strictly positive definite가 됨.
  • CG로 search direction \(s = F^{-1}g\) 를 구한 뒤, KL constraint를 실제로 만족하는 step size를 찾아야 함.     - 2차 근사는 \(\theta_{old}\) 근방에서만 정확. 큰 스텝에서는 실제 KL이 \(\delta\) 를 초과할 수 있음.
  • 절차:     1. 이론적 최대 스텝: \(\beta_{max} = \sqrt{\frac{2\delta}{s^T F s}}\)     2. \(\beta = \beta_{max}\) 에서 시작     3. \(\theta_{new} = \theta_{old} + \beta \cdot s\) 로 업데이트     4. 다음 조건을 확인:         - \(\bar{\mathbb{D}}_{KL}(\theta_{old}, \theta_{new}) \leq \delta\) (KL constraint 만족)         - \(L_{\theta_{old}}(\theta_{new}) \geq L_{\theta_{old}}(\theta_{old})\) (surrogate objective 개선)     1. 만족하지 않으면 \(\beta \leftarrow c \cdot \beta\) (\(c \approx 0.5\)) 로 축소. 반복.
  • Why not just use the theoretical step? ─ 이론적 스텝은 \(L\) 과 \(\mathbb{D}_{KL}\) 모두 근사에 기반. 실제 \(\mathbb{D}_{KL}\) 이 \(\delta\) 를 초과하거나 \(L\) 이 감소할 수 있음. Line search로 이를 보정.

Sample-Based Estimation of the Objective and Constraint

  • 위의 최적화 문제에서 expectation \(\mathbb{E}_s\), \(\mathbb{E}_a\) 를 정확히 계산할 수 없으므로, trajectory 샘플로 추정.

Importance Sampling

  • \(\theta_{old}\) 로 trajectory를 수집했으므로, \(a \sim \pi_{\theta_{old}}(\cdot \mid s)\) 에서 샘플링됨.
  • \(\pi_\theta\) 하의 기대값을 importance sampling으로 변환: \(\sum_a \pi_\theta(a \mid s) A_{\theta_{old}}(s, a) = \mathbb{E}_{a \sim q} \left[\frac{\pi_\theta(a \mid s)}{q(a \mid s)} A_{\theta_{old}}(s, a)\right]\)

여기서 \(q = \pi_{\theta_{old}}\).

  • 샘플 기반 surrogate objective:
\[\hat{L}_{\theta_{old}}(\theta) = \frac{1}{|D|} \sum_{(s, a) \in D} \frac{\pi_\theta(a \mid s)}{\pi_{\theta_{old}}(a \mid s)} \hat{A}_{\theta_{old}}(s, a)\]
  • 샘플 기반 KL constraint:
\[\hat{\bar{\mathbb{D}}}_{KL} = \frac{1}{|D_s|} \sum_{s \in D_s} \mathbb{D}_{KL}\left(\pi_{\theta_{old}}(\cdot \mid s) \;\|\; \pi_\theta(\cdot \mid s)\right)\]
  • \(D_s\) 는 수집된 state들의 집합.

    - KL divergence는 closed-form (Gaussian policy의 경우):

    $$\mathbb{D}_{KL}(\mathcal{N}_1 | \mathcal{N}_2) = \frac{1}{2}\left[\log\frac{ \Sigma_2 }{ \Sigma_1 } - d + \text{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2-\mu_1)^T \Sigma_2^{-1} (\mu_2-\mu_1)\right]$$

    - 대각 공분산 \(\Sigma = \text{diag}(\sigma_1^2, \ldots, \sigma_d^2)\) 이면 더 단순해짐.

Advantage 추정: Generalized Advantage Estimation (GAE)

  • Single-path의 MC return \(\hat{Q}(s_t, a_t) = \sum_{l=0}^{T-t} \gamma^l r(s_{t+l})\) 는 unbiased이지만 high variance.

  • TD residual을 사용한 advantage 추정이 더 실용적.

TD Residual

\(\delta_t^V = r(s_t) + \gamma V(s_{t+1}) - V(s_t)\)

  • \(V\) 가 정확하면 \(\mathbb{E}[\delta_t^V] = A_\pi(s_t, a_t)\). 즉, TD residual은 advantage의 unbiased estimator.
  • 하지만 \(V\) 는 근사이므로 bias가 존재.
    GAE(\(\lambda\))

    \(\hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}^V\)

  • \(\lambda = 0\): \(\hat{A}_t = \delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)\) ─ low variance, high bias (1-step TD)
  • \(\lambda = 1\): \(\hat{A}_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V = \sum_{l=0}^{\infty} \gamma^l r_{t+l} - V(s_t)\) ─ high variance, low bias (MC return - baseline)
  • \(\lambda \in (0, 1)\): bias-variance tradeoff. 실용적으로 \(\lambda \approx 0.97\) 이 자주 쓰임.     - Insight ─ TRPO 원논문(2015)에서는 single-path MC return을 사용. GAE(Schulman et al., 2016)는 후속 연구로, TRPO와 결합하여 성능을 크게 향상시킴. 실질적으로 TRPO 구현에서는 거의 항상 GAE를 사용.

Policy Gradient의 샘플 추정

  • Policy gradient \(g\): \(g = \nabla_\theta \hat{L}_{\theta_{old}}(\theta) \bigg|_{\theta = \theta_{old}} = \frac{1}{|D|} \sum_{(s, a) \in D} \nabla_\theta \log \pi_\theta(a \mid s) \bigg|_{\theta = \theta_{old}} \hat{A}_{\theta_{old}}(s, a)\)
  • \(\theta = \theta_{old}\) 에서 importance ratio \(\frac{\pi_\theta}{\pi_{\theta_{old}}} = 1\) 이 되므로, 위처럼 단순화됨.

Fisher-Vector Product의 샘플 추정

\(\hat{F}v = \frac{1}{|D_s|} \sum_{s \in D_s} \nabla_\theta\left[(\nabla_\theta \mathbb{D}_{KL})^T v\right]\)

  • 자동미분으로 계산. 배치 크기 $$ D_s $$ 에 비례하는 계산 비용.

Practical Algorithm

TRPO Full Procedure

for iteration = 1, 2, ... do
    1. Rn policy π_θ_old to collect trajectories D = {τ_1, ..., τ_N}
    2. Estimate avantages Â_t using GAE(γ, λ) with fitted value function V_φ
    3. Compute policy gradient: g = ∇_θ L̂(θ)|_{θ=θ_old}
    4. Use CG to compute: s ≈ F⁻¹g  (K ≈ 10 iterations)
    5. Compute max step: β = √(2δ / sᵀFs)
    6. Backtracking line search:
       θ_new = θ_old + c^j · β · s
       where j is the smallest integer such that:
         - KL(θ_old, θ_new) ≤ δ
         - L̂(θ_new) ≥ L̂(θ_old)
    7.Update value function V_φ by regression on collected returns
    8. θ_old ← θ_new
end for

주요 하이퍼파라미터

하이퍼파라미터 의미 전형적 값
\(\delta\) Trust region 크기 (KL constraint) \(0.01\)
\(\gamma\) Discount factor \(0.99\)
\(\lambda\) GAE parameter \(0.97\)
\(K\) CG iterations \(10\)
damping FIM regularization \(\lambda_{\text{damp}}\) \(0.1\)
backtrack coeff Line search 축소 비율 \(0.5\)
backtrack iters Line search 최대 반복 \(10\)

TRPO vs. Natural Policy Gradient (NPG)

| | NPG | TRPO | | ————- | ————————- | —————————————– | | 업데이트 | \(\theta + \alpha F^{-1}g\) | \(\theta + \beta F^{-1}g\) with line search | | Step size | 고정 \(\alpha\) | Adaptive (trust region \(\delta\) 기반) | | KL constraint | 없음 (implicit) | 명시적 (\(\bar{\mathbb{D}}_{KL} \leq \delta\)) | | 안정성 | \(\alpha\) 에 민감 | Line search로 보정, 더 robust |

  • TRPO는 NPG에 trust region constraint와 line search를 추가한 것으로 이해할 수 있음.
  • NPG에서 step size를 잘못 설정하면 catastrophic update가 발생할 수 있지만, TRPO는 이를 line search로 방지.