<DIV class="page" id="readability-page-1">
<div>
<center>
<a href="http://folinoid.com/">[Home]</a>
</center>
<center>
<b><a href="https://folinoid.com/">Emmanuel Bengio</a>, <a href="https://mj10.github.io/">Moksh Jain</a>, <a href="https://scholar.google.com/citations?user=TpuvCSwAAAAJ&hl=en">Maksym Korablyov</a>, <a href="https://www.cs.mcgill.ca/~dprecup/">Doina Precup</a>, <a href="https://yoshuabengio.org/">Yoshua Bengio</a></b>
</center><br>
<center>
<b><a href="https://arxiv.org/abs/2106.04399">arXiv preprint</a>, <a href="https://github.com/bengioe/gflownet">code</a></b><br> also see the <b><a href="https://arxiv.org/abs/2111.09266">GFlowNet Foundations</a></b> paper<br> and a more recent (and thorough) <a href="https://tinyurl.com/gflownet-tutorial">tutorial on the framework</a>.
</center>
<p><i>What follows is a high-level overview of this work; for more details, refer to our paper.</i> Given a reward <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(x) </annotation>
</semantics>
</math></span></span> and a deterministic episodic environment where episodes end with a “generate <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span>” action, how do we generate diverse and high-reward <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span>s?<br> We propose to use <i>Flow Networks</i> to model discrete <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> p </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo>∝</mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> p(x) \propto R(x) </annotation>
</semantics>
</math></span></span> from which we can sample sequentially (like episodic RL, rather than iteratively as MCMC methods would). We show that our method, <b>GFlowNet</b>, is well suited to a combinatorial domain, drug molecule synthesis, because, unlike reward-maximizing RL methods, it generates diverse <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span>s by design.<br>
<a name="s2" id="s2"></a>
</p>
<h3> Flow Networks </h3>
<p>A flow network is a directed graph with <i>sources</i> and <i>sinks</i>, and edges carrying some amount of flow between them through intermediate nodes; think of pipes carrying water. For our purposes, we define a flow network with a single source, the root or <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_0 </annotation>
</semantics>
</math></span></span>; the sinks of the network correspond to the terminal states. We'll assign to each sink <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span> an “out-flow” <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(x) </annotation>
</semantics>
</math></span></span>.</p>
<center>
<div id="can1_div">
<canvas id="can1" width="450px" height="225px"></canvas>
<p><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{0} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 1 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{1} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 2 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{2} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 3 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{3} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 3 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{3} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 5 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{5} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 5 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{5} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 7 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{7} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 8 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{8} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 8 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{8} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 10 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{10} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 11 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{11} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 11 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{11} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 13 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{13} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 13 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{13} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 15 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{15} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 16 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{16} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mn> 16 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_{16} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="normal">⊤</mi>
</mrow>
<annotation encoding="application/x-tex"> \top </annotation>
</semantics>
</math></span></span></span>
</p>
</div>
</center>
<p>Given the graph structure and the out-flow of the sinks, we wish to calculate a valid <i>flow</i> between nodes, i.e. how much water each pipe is carrying. In general there are infinitely many solutions, but this is not a problem here: any valid solution will do. For example, in the graph above, there is almost no flow between <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 7 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_7 </annotation>
</semantics>
</math></span></span> and <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 13 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{13} </annotation>
</semantics>
</math></span></span> that goes through <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 11 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{11} </annotation>
</semantics>
</math></span></span>; nearly all of it goes through <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 10 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{10} </annotation>
</semantics>
</math></span></span>, but the reverse routing would also be a valid flow.<br> Why is this useful? Such a construction corresponds to a generative model. If we start at the source and repeatedly follow an outgoing edge with probability proportional to the flow it carries, we'll end up in a terminal state (a sink) with probability <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> p </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo>∝</mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> p(x) \propto R(x) </annotation>
</semantics>
</math></span></span>. Moreover, the total flow through <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_0 </annotation>
</semantics>
</math></span></span>, the unique source, is <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mo>∑</mo>
<mi> x </mi>
</msub>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mi> Z </mi>
</mrow>
<annotation encoding="application/x-tex"> \sum_x R(x)=Z </annotation>
</semantics>
</math></span></span>, the partition function. If we assign to each intermediate node a <i>state</i> and to each edge an <i>action</i>, we recover a useful MDP.<br> Let <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mi> f </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s,a)=f(s,s') </annotation>
</semantics>
</math></span></span> be the flow between <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> s </mi>
</mrow>
<annotation encoding="application/x-tex"> s </annotation>
</semantics>
</math></span></span> and <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span>, where <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> T(s,a)=s' </annotation>
</semantics>
</math></span></span>, i.e. <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span> is the (deterministic) state reached from state <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> s </mi>
</mrow>
<annotation encoding="application/x-tex"> s </annotation>
</semantics>
</math></span></span> by taking action <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> a </mi>
</mrow>
<annotation encoding="application/x-tex"> a </annotation>
</semantics>
</math></span></span>. Let <span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<semantics>
<mtable rowspacing="0.24999999999999992em" columnalign="right" columnspacing="">
<mtr>
<mtd>
<mstyle scriptlevel="0" displaystyle="true">
<mrow>
<mi> π </mi>
<mo stretchy="false"> ( </mo>
<mi> a </mi>
<mi mathvariant="normal">∣</mi>
<mi> s </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mfrac>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<mrow>
<munder>
<mo>∑</mo>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</munder>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
</mfrac>
</mrow>
</mstyle>
</mtd>
</mtr>
</mtable>
<annotation encoding="application/x-tex"> \begin{aligned}\pi(a|s) = \frac{F(s,a)}{\sum_{a'}F(s,a')}\end{aligned} </annotation>
</semantics>
</math></span></span></span> Then following policy <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> π </mi>
</mrow>
<annotation encoding="application/x-tex"> \pi </annotation>
</semantics>
</math></span></span>, starting from <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_0 </annotation>
</semantics>
</math></span></span>, leads to terminal state <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span> with probability proportional to <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(x) </annotation>
</semantics>
</math></span></span> (see the paper for proofs and more rigorous explanations).<br>
<a name="s3" id="s3"></a>
</p>
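<p>To make this concrete, here is a minimal Python sketch, not taken from the GFlowNet codebase, on a hypothetical four-state DAG with made-up rewards. It builds one valid flow by splitting each state's total flow equally among the edges entering it (an arbitrary choice; any other split would also be valid), then follows π(a|s) = F(s,a) / Σ<sub>a'</sub> F(s,a') from s<sub>0</sub> and computes the exact probability of reaching each sink. The graph, rewards and helper names are all assumptions made for illustration; exact rational arithmetic (<code>fractions.Fraction</code>) is used so that the comparison with R(x)/Z is exact.</p>

```python
from collections import defaultdict
from fractions import Fraction

# Hypothetical toy DAG (not from the paper): state -> {action: next_state}.
# States with no outgoing actions are sinks, i.e. terminal states x.
T = {
    "s0": {"a": "s1", "b": "s2"},
    "s1": {"a": "x1", "b": "s3"},
    "s2": {"a": "s3", "b": "x2"},
    "s3": {"a": "x3"},
    "x1": {}, "x2": {}, "x3": {},
}
# Assumed rewards R(x) > 0 on the sinks.
R = {"x1": Fraction(1), "x2": Fraction(2), "x3": Fraction(3)}


def postorder(T, root="s0"):
    """Depth-first post-order: every state appears after all its children."""
    seen, order = set(), []

    def visit(s):
        if s in seen:
            return
        seen.add(s)
        for child in T[s].values():
            visit(child)
        order.append(s)

    visit(root)
    return order


def valid_flow(T, R, root="s0"):
    """Return edge flows F[(s, a)] and state flows F(s) for one valid flow.

    Each state's out-flow R(s) + sum_a F(s, a) is split equally among the
    (parent, action) pairs leading into it, so in-flow equals out-flow at
    every non-root state. Other splits would also give a valid flow.
    """
    parents = defaultdict(list)
    for s, acts in T.items():
        for a, child in acts.items():
            parents[child].append((s, a))

    F_edge, F_state = {}, {}
    for s in postorder(T, root):  # children are processed before parents
        F_state[s] = R.get(s, Fraction(0)) + sum(F_edge[(s, a)] for a in T[s])
        for p, a in parents[s]:
            F_edge[(p, a)] = F_state[s] / len(parents[s])
    return F_edge, F_state


def terminal_distribution(T, F_edge, root="s0"):
    """Exact probability of ending in each sink when following
    pi(a|s) = F(s, a) / sum_a' F(s, a') from the root."""
    reach = defaultdict(Fraction)  # Fraction() == 0
    reach[root] = Fraction(1)
    dist = {}
    for s in reversed(postorder(T, root)):  # parents before children
        if not T[s]:
            dist[s] = reach[s]
            continue
        total = sum(F_edge[(s, a)] for a in T[s])
        for a, child in T[s].items():
            reach[child] += reach[s] * F_edge[(s, a)] / total
    return dist


F_edge, F_state = valid_flow(T, R)
Z = F_state["s0"]
dist = terminal_distribution(T, F_edge)
print(Z)  # 6, i.e. 1 + 2 + 3
for x in sorted(dist):
    print(x, dist[x], R[x] / Z)  # the two columns agree
```

<p>The flow through the root comes out as Z = 6, the sum of the rewards, and each sink is reached with probability R(x)/Z. Enumerating the graph like this only works for tiny toy problems, which is exactly the limitation discussed next.</p>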
<h3> Approximating Flow Networks </h3>
<p>As you may suspect, there are only a few scenarios in which we can build the above graph explicitly. For drug-like molecules, it would have around <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mn> 1 </mn>
<msup>
<mn> 0 </mn>
<mn> 16 </mn>
</msup>
</mrow>
<annotation encoding="application/x-tex"> 10^{16} </annotation>
</semantics>
</math></span></span> nodes!<br> Instead, we resort to function approximation, just as deep RL does when estimating the (action-)value functions of MDPs.<br> Our goal here is to approximate the flow <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s,a) </annotation>
</semantics>
</math></span></span>. Earlier we called a flow <i>valid</i> if it correctly routed all the flow from the source to the sinks through the intermediate nodes. Let's be more precise. For a given node <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span>, let the in-flow <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s') </annotation>
</semantics>
</math></span></span> be the sum of incoming flows: <span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<semantics>
<mtable rowspacing="0.24999999999999992em" columnalign="right" columnspacing="">
<mtr>
<mtd>
<mstyle scriptlevel="0" displaystyle="true">
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<munder>
<mo>∑</mo>
<mrow>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo> : </mo>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
</munder>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
</mrow>
</mstyle>
</mtd>
</mtr>
</mtable>
<annotation encoding="application/x-tex"> \begin{aligned}F(s') = \sum_{s,a:T(s,a)=s'} F(s,a)\end{aligned} </annotation>
</semantics>
</math></span></span></span> Here the set <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mo stretchy="false"> { </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo> : </mo>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> } </mo>
</mrow>
<annotation encoding="application/x-tex"> \{s,a:T(s,a)=s'\} </annotation>
</semantics>
</math></span></span> is the set of state-action pairs that lead to <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em"> </mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span>. Now, let the out-flow be the sum of outgoing flows plus the reward, which is nonzero only if <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span> is terminal: <span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<semantics>
<mtable rowspacing="0.24999999999999992em" columnalign="right" columnspacing="">
<mtr>
<mtd>
<mstyle scriptlevel="0" displaystyle="true">
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo> + </mo>
<munder>
<mo>∑</mo>
<mrow>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo>∈</mo>
<mi mathvariant="script"> A </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
</munder>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo separator="true"> , </mo>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mi mathvariant="normal"> . </mi>
</mrow>
</mstyle>
</mtd>
</mtr>
</mtable>
<annotation encoding="application/x-tex"> \begin{aligned}F(s') = R(s') + \sum_{a'\in\mathcal{A}(s')} F(s',a').\end{aligned} </annotation>
</semantics>
</math></span></span></span> Note that we reused <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s') </annotation>
</semantics>
</math></span></span> for both quantities. This is because, for a valid flow, the in-flow equals the out-flow, i.e. the flow through <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
<annotation encoding="application/x-tex"> s' </annotation>
</semantics>
</math></span></span>, <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s') </annotation>
</semantics>
</math></span></span>. Here <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="script"> A </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> \mathcal{A}(s) </annotation>
</semantics>
</math></span></span> is the set of valid actions in state <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> s </mi>
</mrow>
<annotation encoding="application/x-tex"> s </annotation>
</semantics>
</math></span></span>, which is the empty set when <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> s </mi>
</mrow>
<annotation encoding="application/x-tex"> s </annotation>
</semantics>
</math></span></span> is a sink. <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(s) </annotation>
</semantics>
</math></span></span> is 0 unless <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> s </mi>
</mrow>
<annotation encoding="application/x-tex"> s </annotation>
</semantics>
</math></span></span> is a sink, in which case <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo stretchy="false"> ) </mo>
<mo> &gt; </mo>
<mn> 0 </mn>
</mrow>
<annotation encoding="application/x-tex"> R(s)&gt;0 </annotation>
</semantics>
</math></span></span>.<br> We can thus call the set of these equalities for all states <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo mathvariant="normal">≠</mo>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s'\neq s_0 </annotation>
</semantics>
</math></span></span> the <i>flow consistency equations</i>: <span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<semantics>
<mtable rowspacing="0.24999999999999992em" columnalign="right" columnspacing="">
<mtr>
<mtd>
<mstyle scriptlevel="0" displaystyle="true">
<mrow>
<munder>
<mo>∑</mo>
<mrow>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo> : </mo>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
</munder>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo> + </mo>
<munder>
<mo>∑</mo>
<mrow>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo>∈</mo>
<mi mathvariant="script"> A </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
</munder>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo separator="true"> , </mo>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mi mathvariant="normal"> . </mi>
</mrow>
</mstyle>
</mtd>
</mtr>
</mtable>
<annotation encoding="application/x-tex"> \begin{aligned}\sum_{s,a:T(s,a)=s'} F(s,a) = R(s') + \sum_{a'\in\mathcal{A}(s')} F(s',a').\end{aligned} </annotation>
</semantics>
</math></span></span></span></p>
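<p>The consistency equations can be checked mechanically. The sketch below, again on a hypothetical toy graph with assumed rewards (the helper name is ours, not from the paper), computes for every s' ≠ s<sub>0</sub> the residual between the left-hand side (in-flow) and the right-hand side (reward plus out-flow). The hand-written flow below is one valid flow for this graph, so every residual is zero; moving flow off a single edge breaks the equations at both of its endpoints.</p>

```python
from collections import defaultdict

# Hypothetical toy DAG (not from the paper): state -> {action: next_state}.
T = {
    "s0": {"a": "s1", "b": "s2"},
    "s1": {"a": "x1", "b": "s3"},
    "s2": {"a": "s3", "b": "x2"},
    "s3": {"a": "x3"},
    "x1": {}, "x2": {}, "x3": {},
}
R = {"x1": 1.0, "x2": 2.0, "x3": 3.0}  # assumed rewards on the sinks

# One valid flow for this graph, written out by hand.
F = {
    ("s0", "a"): 2.5, ("s0", "b"): 3.5,
    ("s1", "a"): 1.0, ("s1", "b"): 1.5,
    ("s2", "a"): 1.5, ("s2", "b"): 2.0,
    ("s3", "a"): 3.0,
}


def consistency_residuals(T, R, F, root="s0"):
    """For each s' != root, return
    sum_{(s,a): T(s,a)=s'} F(s,a) - (R(s') + sum_{a' in A(s')} F(s',a'))."""
    in_flow = defaultdict(float)
    for s, acts in T.items():
        for a, child in acts.items():
            in_flow[child] += F[(s, a)]
    residuals = {}
    for s_prime, acts in T.items():
        if s_prime == root:
            continue  # the source has no consistency equation
        out_flow = R.get(s_prime, 0.0) + sum(F[(s_prime, a)] for a in acts)
        residuals[s_prime] = in_flow[s_prime] - out_flow
    return residuals


print(consistency_residuals(T, R, F))  # every residual is 0.0

broken = dict(F)
broken[("s1", "b")] = 0.5  # remove one unit of flow from the edge s1 -> s3
bad = consistency_residuals(T, R, broken)
print({s: r for s, r in bad.items() if r != 0})  # {'s1': 1.0, 's3': -1.0}
```

<p>Here F is an explicit table; with function approximation, residuals of this kind become the quantity a learned F<sub>θ</sub> is trained to drive to zero.</p>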
<center>
<div id="can2_div">
<canvas id="can2" width="200px" height="135px"></canvas>
<p><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 1 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_1 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 7 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_7 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 3 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_3 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 4 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_4 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 2 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_2 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> a </mi>
<mn> 8 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> a_8 </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{0} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 1 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{1} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 2 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{2} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 3 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{3} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 4 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{4} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 5 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{5} </annotation>
</semantics>
</math></span></span></span><span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> s </mi>
<mn> 6 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> s_{6} </annotation>
</semantics>
</math></span></span></span>
</p>
</div>
</center>
<p>In the figure above, the set of parents <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mo stretchy="false"> { </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo> : </mo>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msub>
<mi> s </mi>
<mn> 3 </mn>
</msub>
<mo stretchy="false"> } </mo>
</mrow>
<annotation encoding="application/x-tex"> \{s,a:T(s,a)=s_3\} </annotation>
</semantics>
</math></span></span> is <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mo stretchy="false"> { </mo>
<mo stretchy="false"> ( </mo>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> a </mi>
<mn> 1 </mn>
</msub>
<mo stretchy="false"> ) </mo>
<mo separator="true"> , </mo>
<mo stretchy="false"> ( </mo>
<msub>
<mi> s </mi>
<mn> 1 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> a </mi>
<mn> 7 </mn>
</msub>
<mo stretchy="false"> ) </mo>
<mo separator="true"> , </mo>
<mo stretchy="false"> ( </mo>
<msub>
<mi> s </mi>
<mn> 2 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> a </mi>
<mn> 3 </mn>
</msub>
<mo stretchy="false"> ) </mo>
<mo stretchy="false"> } </mo>
</mrow>
<annotation encoding="application/x-tex"> \{(s_0, a_1), (s_1, a_7), (s_2, a_3)\} </annotation>
</semantics>
</math></span></span>, and <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi mathvariant="script"> A </mi>
<mo stretchy="false"> ( </mo>
<msub>
<mi> s </mi>
<mn> 3 </mn>
</msub>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mo stretchy="false"> { </mo>
<msub>
<mi> a </mi>
<mn> 2 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> a </mi>
<mn> 4 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> a </mi>
<mn> 8 </mn>
</msub>
<mo stretchy="false"> } </mo>
</mrow>
<annotation encoding="application/x-tex"> \mathcal{A}(s_3)=\{a_2,a_4,a_8\} </annotation>
</semantics>
</math></span></span>.<br> By now our RL senses should be tingling. We've defined a value function recursively, with two quantities that need to match.<br>
<a name="s4" id="s4"></a>
</p>
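<p>The parent set and the two sides of the flow equations can be checked mechanically on a small example. The sketch below uses a hypothetical toy DAG. Its states, actions, rewards and hand-picked edge flows are illustrative assumptions, not the graph from the figure above. It inverts a deterministic transition table <code>T</code> to obtain {(s, a) : T(s, a) = s′}, collects A(s) for each state, and verifies that, for every state except s<sub>0</sub>, the inflow equals R(s′) plus the outflow.</p>
<pre><code class="language-python">from collections import defaultdict

# Hypothetical deterministic transitions: T[(state, action)] = next_state.
# States x1, x2, x3 are terminal ("generate x") and are the only ones with reward.
T = {
    ("s0", "a1"): "s1",
    ("s0", "a2"): "s2",
    ("s1", "a3"): "s3",
    ("s2", "a4"): "s3",
    ("s1", "a5"): "x1",
    ("s3", "a6"): "x2",
    ("s3", "a7"): "x3",
}
R = {"x1": 1.0, "x2": 2.0, "x3": 3.0}

# Hand-picked edge flows F(s, a) that satisfy the flow equations for this DAG.
F = {
    ("s0", "a1"): 3.0, ("s0", "a2"): 3.0,
    ("s1", "a3"): 2.0, ("s1", "a5"): 1.0,
    ("s2", "a4"): 3.0,
    ("s3", "a6"): 2.0, ("s3", "a7"): 3.0,
}

def parents_and_actions(T):
    """Return parents[s'] = [(s, a) : T(s, a) = s'] and actions[s] = A(s)."""
    parents, actions = defaultdict(list), defaultdict(list)
    for (s, a), s_next in T.items():
        parents[s_next].append((s, a))
        actions[s].append(a)
    return parents, actions

def flow_mismatch(F, T, R):
    """Inflow minus (R(s') + outflow) for every state that has parents (all but s0)."""
    parents, actions = parents_and_actions(T)
    mismatch = {}
    for s_next in sorted(parents):
        inflow = sum(F[sa] for sa in parents[s_next])
        outflow = R.get(s_next, 0.0) + sum(F[(s_next, a)] for a in actions[s_next])
        mismatch[s_next] = inflow - outflow
    return mismatch

parents, _ = parents_and_actions(T)
print(parents["s3"])             # [('s1', 'a3'), ('s2', 'a4')]
print(flow_mismatch(F, T, R))    # every entry is 0.0
Z = sum(f for (s, a), f in F.items() if s == "s0")
print(Z, sum(R.values()))        # 6.0 6.0
</code></pre>
<p>Here s3 has two parents, which is why the inflow is a sum over parents rather than a single predecessor term, and the flow out of s<sub>0</sub> equals the sum of all rewards (Z = 6), consistent with F(s<sub>0</sub>) = Z = Σ<sub>x</sub> R(x).</p>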
<h4> A TD-Like Objective </h4>
<p>Just as one can cast the Bellman equations into TD objectives, we can cast the flow consistency equations into an objective. We look for an <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> F </mi>
<mi> θ </mi>
</msub>
</mrow>
<annotation encoding="application/x-tex"> F_\theta </annotation>
</semantics>
</math></span></span> that minimizes the squared difference between the two sides of the equations, but we add a few bells and whistles: <span><span><span><math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<semantics>
<mtable rowspacing="0.24999999999999992em" columnalign="right" columnspacing="">
<mtr>
<mtd>
<mstyle scriptlevel="0" displaystyle="true">
<mrow>
<msub>
<mi mathvariant="script"> L </mi>
<mrow>
<mi> θ </mi>
<mo separator="true"> , </mo>
<mi> ϵ </mi>
</mrow>
</msub>
<mo stretchy="false"> ( </mo>
<mi> τ </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<munder>
<mo>∑</mo>
<mpadded lspace="-0.5width" width="0px">
<mrow>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo>∈</mo>
<mi> τ </mi>
<mo mathvariant="normal">≠</mo>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
</mrow>
</mpadded>
</munder>
<mtext> </mtext>
<msup>
<mrow>
<mo fence="true"> ( </mo>
<mi> log </mi>
<mo> </mo>
<mtext> </mtext>
<mrow>
<mo fence="true"> [ </mo>
<mi> ϵ </mi>
<mo> + </mo>
<munder>
<mo>∑</mo>
<mpadded lspace="-0.5width" width="0px">
<mrow>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo> : </mo>
<mi> T </mi>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
</mrow>
</mpadded>
</munder>
<mi> exp </mi>
<mo> </mo>
<msubsup>
<mi> F </mi>
<mi> θ </mi>
<mrow>
<mi> log </mi>
<mo> </mo>
</mrow>
</msubsup>
<mo stretchy="false"> ( </mo>
<mi> s </mi>
<mo separator="true"> , </mo>
<mi> a </mi>
<mo stretchy="false"> ) </mo>
<mo fence="true"> ] </mo>
</mrow>
<mo> − </mo>
<mi> log </mi>
<mo> </mo>
<mtext> </mtext>
<mrow>
<mo fence="true"> [ </mo>
<mi> ϵ </mi>
<mo> + </mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo> + </mo>
<munder>
<mo>∑</mo>
<mpadded lspace="-0.5width" width="0px">
<mrow>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo>∈</mo>
<mi mathvariant="script"> A </mi>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
</mrow>
</mpadded>
</munder>
<mi> exp </mi>
<mo> </mo>
<msubsup>
<mi> F </mi>
<mi> θ </mi>
<mrow>
<mi> log </mi>
<mo> </mo>
</mrow>
</msubsup>
<mo stretchy="false"> ( </mo>
<msup>
<mi> s </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo separator="true"> , </mo>
<msup>
<mi> a </mi>
<mo mathvariant="normal" lspace="0em" rspace="0em">′</mo>
</msup>
<mo stretchy="false"> ) </mo>
<mo fence="true"> ] </mo>
</mrow>
<mo fence="true"> ) </mo>
</mrow>
<mn> 2 </mn>
</msup>
<mi mathvariant="normal"> . </mi>
</mrow>
</mstyle>
</mtd>
</mtr>
</mtable>
<annotation encoding="application/x-tex"> \begin{aligned}\mathcal{L}_{\theta,\epsilon}(\tau) = \sum_{\mathclap{s'\in\tau\neq s_0}}\,\left(\log\! \left[\epsilon+{\sum_{\mathclap{s,a:T(s,a)=s'}}} \exp F^{\log}_\theta(s,a)\right]- \log\! \left[\epsilon + R(s') + \sum_{\mathclap{a'\in{\cal A}(s')}} \exp F^{\log}_\theta(s',a')\right]\right)^2.\end{aligned} </annotation>
</semantics>
</math></span></span></span> First, we match the <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> log </mi>
<mo> </mo>
</mrow>
<annotation encoding="application/x-tex"> \log </annotation>
</semantics>
</math></span></span> of each side. This matters because the flow of an intermediate node becomes exponentially larger the closer it is to the root (remember that <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> F </mi>
<mo stretchy="false"> ( </mo>
<msub>
<mi> s </mi>
<mn> 0 </mn>
</msub>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<mi> Z </mi>
<mo> = </mo>
<msub>
<mo>∑</mo>
<mi> x </mi>
</msub>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> F(s_0) = Z = \sum_x R(x) </annotation>
</semantics>
</math></span></span>), but we care equally about all nodes. Second, we predict <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msubsup>
<mi> F </mi>
<mi> θ </mi>
<mrow>
<mi> log </mi>
<mo> </mo>
</mrow>
</msubsup>
<mo> ≈ </mo>
<mi> log </mi>
<mo> </mo>
<mi> F </mi>
</mrow>
<annotation encoding="application/x-tex"> F^{\log}_\theta\approx\log F </annotation>
</semantics>
</math></span></span> for the same reasons. Finally, we add an <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> ϵ </mi>
</mrow>
<annotation encoding="application/x-tex"> \epsilon </annotation>
</semantics>
</math></span></span> value inside the <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> log </mi>
<mo> </mo>
</mrow>
<annotation encoding="application/x-tex"> \log </annotation>
</semantics>
</math></span></span>; this doesn't change the minima of the objective, but gives more gradient weight to large values and less to small values.<br> We show in the paper that a minimizer of this objective achieves our desideratum, which is to have <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> p </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo>∝</mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> p(x)\propto R(x) </annotation>
</semantics>
</math></span></span> when sampling from <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> π </mi>
<mo stretchy="false"> ( </mo>
<mi> a </mi>
<mi mathvariant="normal">∣</mi>
<mi> s </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> \pi(a|s) </annotation>
</semantics>
</math></span></span> as defined above.<br>
<a name="s5" id="s5"></a>
</p>
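<p>A minimal sketch of this objective for a single trajectory τ, in plain Python on a hypothetical toy DAG. In practice F<sup>log</sup><sub>θ</sub> would be a neural network; here a dictionary of log-flows stands in for its outputs, and the value <code>eps=1e-3</code> is an arbitrary assumption, not a setting taken from the paper. The <code>policy</code> helper assumes the sampling policy is proportional to the outgoing edge flows, π(a|s) = F(s, a) / Σ<sub>a′</sub> F(s, a′), written as a softmax over log-flows.</p>
<pre><code class="language-python">import math
from collections import defaultdict

# Hypothetical toy DAG: T[(s, a)] = s'; rewards only at the terminal states x1..x3.
T = {
    ("s0", "a1"): "s1", ("s0", "a2"): "s2",
    ("s1", "a3"): "s3", ("s1", "a5"): "x1",
    ("s2", "a4"): "s3",
    ("s3", "a6"): "x2", ("s3", "a7"): "x3",
}
R = {"x1": 1.0, "x2": 2.0, "x3": 3.0}
PARENTS, ACTIONS = defaultdict(list), defaultdict(list)
for (s, a), s_next in T.items():
    PARENTS[s_next].append((s, a))
    ACTIONS[s].append(a)

def flow_matching_loss(log_F, trajectory, eps=1e-3):
    """Sum over s' in the trajectory (s' != s0) of
    (log[eps + sum_parents exp F_log] - log[eps + R(s') + sum_children exp F_log])**2."""
    loss = 0.0
    for s_next in trajectory[1:]:  # trajectory[0] is s0, which has no parents
        inflow = eps + sum(math.exp(log_F[sa]) for sa in PARENTS[s_next])
        outflow = eps + R.get(s_next, 0.0) + sum(
            math.exp(log_F[(s_next, a)]) for a in ACTIONS[s_next])
        loss += (math.log(inflow) - math.log(outflow)) ** 2
    return loss

def policy(log_F, s):
    """pi(a|s) = F(s,a) / sum_a' F(s,a'), computed as a softmax over predicted log-flows."""
    logits = [log_F[(s, a)] for a in ACTIONS[s]]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(v - top) for v in logits]
    total = sum(weights)
    return {a: w / total for a, w in zip(ACTIONS[s], weights)}

# Log-flows of a consistent flow vs. an untrained predictor that outputs 0 everywhere.
F_true = {("s0", "a1"): 3.0, ("s0", "a2"): 3.0, ("s1", "a3"): 2.0, ("s1", "a5"): 1.0,
          ("s2", "a4"): 3.0, ("s3", "a6"): 2.0, ("s3", "a7"): 3.0}
log_F_good = {sa: math.log(f) for sa, f in F_true.items()}
log_F_init = {sa: 0.0 for sa in T}

tau = ["s0", "s1", "s3", "x2"]
print(flow_matching_loss(log_F_good, tau))  # ~0, up to floating-point rounding
print(flow_matching_loss(log_F_init, tau))  # clearly positive
print(policy(log_F_good, "s3"))             # a6: ~0.4, a7: ~0.6
</code></pre>
<p>With flows that satisfy the flow equations, the loss is zero up to rounding whatever the value of <code>eps</code>, illustrating that ϵ leaves the exact minimizer unchanged, while the untrained predictor is penalized at every state of τ where inflow and outflow disagree.</p>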
<h3> GFlowNet as Amortized Sampling with an OOD Potential </h3>
<p>It is interesting to compare GFlowNet with Markov chain Monte Carlo (MCMC) methods. MCMC methods can be used to sample from a distribution for which there is no analytical sampling formula but for which an energy function or unnormalized probability function is available. In our context, this unnormalized probability function is our reward function <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo> = </mo>
<msup>
<mi> e </mi>
<mrow>
<mo>−</mo>
<mi> e </mi>
<mi> n </mi>
<mi> e </mi>
<mi> r </mi>
<mi> g </mi>
<mi> y </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
</msup>
</mrow>
<annotation encoding="application/x-tex"> R(x)=e^{-energy(x)} </annotation>
</semantics>
</math></span></span>.<br> Like MCMC methods, GFlowNet can turn a given energy function into samples, but it does so in an amortized way, converting the cost of many very expensive MCMC trajectories (to obtain each sample) into the cost of training a generative model (in our case a generative policy which sequentially builds up <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> x </mi>
</mrow>
<annotation encoding="application/x-tex"> x </annotation>
</semantics>
</math></span></span>). Sampling from the generative model is then very cheap (e.g. adding one component at a time to a molecule) compared to running an MCMC chain. But the most important gain may not be computational; it may instead lie in the ability to discover new modes of the reward function.<br> MCMC methods are iterative, making many small noisy steps; they can converge to the neighborhood of a mode and, with some probability, jump from one mode to a nearby one. However, if two modes are far from each other, MCMC can require <i>exponential</i> time to mix between them. If, in addition, the modes occupy a tiny volume of the state space, the chances of initializing a chain near one of the unknown modes are also tiny, and the MCMC approach becomes unsatisfactory. Whereas such a situation seems hopeless for MCMC, GFlowNet has the potential to discover new modes and jump to them directly, provided there is structure relating them to the modes it already knows and its inductive biases and training procedure make it possible to generalize there.<br> GFlowNet does not need to know exactly where the modes are: it is sufficient to make guesses that occasionally work well. As with MCMC methods, once a point in the region of a new mode is discovered, further training of GFlowNet will sculpt that mode and zoom in on its peak.<br> Note that we can raise <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(x) </annotation>
</semantics>
</math></span></span> to some power <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> β </mi>
</mrow>
<annotation encoding="application/x-tex"> \beta </annotation>
</semantics>
</math></span></span>, a coefficient which acts like an inverse temperature, and <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<msup>
<mo stretchy="false"> ) </mo>
<mi> β </mi>
</msup>
<mo> = </mo>
<msup>
<mi> e </mi>
<mrow>
<mo>−</mo>
<mi> β </mi>
<mtext> </mtext>
<mi> e </mi>
<mi> n </mi>
<mi> e </mi>
<mi> r </mi>
<mi> g </mi>
<mi> y </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
</msup>
</mrow>
<annotation encoding="application/x-tex"> R(x)^\beta = e^{-\beta\; energy(x)} </annotation>
</semantics>
</math></span></span>, making it possible to focus more or less on the highest modes (versus spreading probability mass more uniformly).<br>
<a name="s6" id="s6"></a>
</p>
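<p>To see what β does, the sketch below normalizes R(x)<sup>β</sup> = e<sup>−β energy(x)</sup> exhaustively over five hypothetical candidates. Exhaustive normalization is only possible for tiny, enumerable spaces (avoiding it is the point of training a sampler), so this illustrates only the target distribution, not how GFlowNet reaches it.</p>
<pre><code class="language-python">import math

def tempered_distribution(energies, beta):
    """Normalize R(x)**beta = exp(-beta * energy(x)) over a small, enumerable set of x."""
    logits = [-beta * e for e in energies]
    top = max(logits)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(v - top) for v in logits]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical energies of five candidates; lower energy means higher reward R = exp(-energy).
energies = [0.0, 0.5, 1.0, 2.0, 4.0]
for beta in (0.0, 1.0, 4.0):
    probs = tempered_distribution(energies, beta)
    print(beta, [round(p, 3) for p in probs])  # beta = 0.0 prints the uniform [0.2, ...]
</code></pre>
<p>β = 0 gives the uniform distribution, and increasing β moves probability mass onto the lowest-energy (highest-reward) candidate, so larger β behaves like a lower temperature.</p>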
<h3> Generating molecule graphs </h3>
<p>The motivation for this work is to be able to generate diverse molecules from a proxy reward <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
</mrow>
<annotation encoding="application/x-tex"> R </annotation>
</semantics>
</math></span></span> that is imprecise because it comes from biochemical simulations with high uncertainty. As such, we do not care about the maximizer, as RL methods would, but rather about a set of “good enough” candidates to send to a true biochemical assay.<br> Another motivation is diversity: by fitting a distribution proportional to the reward rather than trying to maximize the expected reward, we are likely to find more modes than a method that turns greedy once it has found a good enough mode, which is what we have repeatedly observed RL methods such as PPO do.<br> Here we generate molecule graphs via a sequence of additive edits, i.e. we progressively build the graph by adding new leaf nodes to it. We also create molecules block-by-block rather than atom-by-atom.<br> We find experimentally that we get molecules that are both good and diverse. We compare against PPO and MARS (an MCMC-based method).<br> Figure 3 shows that the distribution we fit behaves sensibly: if we change the reward by exponentiating it as <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msup>
<mi> R </mi>
<mi> β </mi>
</msup>
</mrow>
<annotation encoding="application/x-tex"> R^\beta </annotation>
</semantics>
</math></span></span> with <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> β </mi>
<mo> &gt; </mo>
<mn> 1 </mn>
</mrow>
<annotation encoding="application/x-tex"> \beta&gt;1 </annotation>
</semantics>
</math></span></span>, the distribution of rewards of the sampled molecules shifts to the right.<br> Figure 4 shows the top-<span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> k </mi>
</mrow>
<annotation encoding="application/x-tex"> k </annotation>
</semantics>
</math></span></span> found as a function of the number of episodes.</p>
<center>
<img src="http://fakehost/test/gfn_fig34.png" width="650px">
</center>
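<p>Building molecules by adding leaves means the same graph can be reached through different orderings of edits, so a state can have several parents, which is what the flow equations sum over. The sketch below illustrates this under strong simplifying assumptions: a molecule is a rooted tree of abstract blocks with hypothetical labels <code>"A"</code>, <code>"B"</code>, <code>"C"</code>, the first block is the root, and there are no attachment-point or chemical-validity rules and no stop action. Trees are canonicalized as nested tuples with sorted children so that different build orders map to the same state.</p>
<pre><code class="language-python"># Hypothetical simplification: a "molecule" is a rooted tree of abstract blocks.
# A state is canonicalized as (label, sorted tuple of child subtrees), so that two
# different build orders that yield the same tree map to the same state.

def make(label, children=()):
    return (label, tuple(sorted(children)))

def add_leaf_everywhere(tree, block):
    """All trees obtained by attaching a new leaf `block` to any existing node."""
    label, kids = tree
    results = [make(label, kids + (make(block),))]  # attach to this node
    for i, kid in enumerate(kids):
        for new_kid in add_leaf_everywhere(kid, block):  # attach deeper in subtree i
            results.append(make(label, kids[:i] + (new_kid,) + kids[i + 1:]))
    return results

def child_states(tree, blocks):
    """States reachable with one additive edit (the forward actions)."""
    return {t for b in blocks for t in add_leaf_everywhere(tree, b)}

def parent_states(tree):
    """States from which `tree` is reachable in one edit: remove any non-root leaf."""
    label, kids = tree
    results = set()
    for i, kid in enumerate(kids):
        rest = kids[:i] + kids[i + 1:]
        if not kid[1]:  # kid has no children, so it is a removable leaf
            results.add(make(label, rest))
        else:
            for smaller in parent_states(kid):
                results.add(make(label, rest + (smaller,)))
    return results

BLOCKS = ["A", "B", "C"]  # hypothetical block vocabulary
root = make("A")
s = make("A", [make("B"), make("C")])   # A with two leaf blocks B and C
print(len(parent_states(s)))            # 2: B was added last, or C was added last
print(s in child_states(make("A", [make("B")]), BLOCKS))  # True
</code></pre>
<p>A real block-based environment would additionally have to decide where on a block a new block may attach and when to end the episode; those details are left out of this sketch.</p>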
<p> Finally, Figure 5 uses a biochemical measure of diversity to estimate the number of distinct modes found, and shows that GFlowNet finds much more varied candidates.</p>
<center>
<img src="http://fakehost/test/gfn_fig5.png" width="650px">
</center><br>
<h4> Active Learning experiments </h4>
<p>The above experiments assume access to a reward <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
</mrow>
<annotation encoding="application/x-tex"> R </annotation>
</semantics>
</math></span></span> that is cheap to evaluate. In fact, that reward is a neural network <i>proxy</i> trained on a large dataset of molecules. This setup isn't quite what we would get when interacting with biochemical assays, where we'd have access to far less data. To emulate such a setting, we consider our oracle to be a <i>docking simulation</i> (which is relatively expensive to run, ~30 CPU seconds).<br> In this setting, there is a limited budget for calls to the true oracle <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> O </mi>
</mrow>
<annotation encoding="application/x-tex"> O </annotation>
</semantics>
</math></span></span>. We use a proxy <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> M </mi>
</mrow>
<annotation encoding="application/x-tex"> M </annotation>
</semantics>
</math></span></span> initialized by training on a limited dataset of <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo separator="true"> , </mo>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> (x, R(x)) </annotation>
</semantics>
</math></span></span> pairs <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> D </mi>
<mn> 0 </mn>
</msub>
</mrow>
<annotation encoding="application/x-tex"> D_0 </annotation>
</semantics>
</math></span></span>, where <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
<mo stretchy="false"> ( </mo>
<mi> x </mi>
<mo stretchy="false"> ) </mo>
</mrow>
<annotation encoding="application/x-tex"> R(x) </annotation>
</semantics>
</math></span></span> is the true reward from the oracle. The generative model (<span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> π </mi>
<mi> θ </mi>
</msub>
</mrow>
<annotation encoding="application/x-tex"> \pi_{\theta} </annotation>
</semantics>
</math></span></span>) is then trained to fit <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> R </mi>
</mrow>
<annotation encoding="application/x-tex"> R </annotation>
</semantics>
</math></span></span> as predicted by the proxy <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> M </mi>
</mrow>
<annotation encoding="application/x-tex"> M </annotation>
</semantics>
</math></span></span>. We then sample a batch <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> B </mi>
<mo> = </mo>
<mo stretchy="false"> { </mo>
<msub>
<mi> x </mi>
<mn> 1 </mn>
</msub>
<mo separator="true"> , </mo>
<msub>
<mi> x </mi>
<mn> 2 </mn>
</msub>
<mo separator="true"> , </mo>
<mo>…</mo>
<msub>
<mi> x </mi>
<mi> k </mi>
</msub>
<mo stretchy="false"> } </mo>
</mrow>
<annotation encoding="application/x-tex"> B=\{x_1, x_2, \dots x_k\} </annotation>
</semantics>
</math></span></span> where <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<msub>
<mi> x </mi>
<mi> i </mi>
</msub>
<mo>∼</mo>
<msub>
<mi> π </mi>
<mi> θ </mi>
</msub>
</mrow>
<annotation encoding="application/x-tex"> x_i\sim \pi_{\theta} </annotation>
</semantics>
</math></span></span>, which is evaluated with the oracle <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> O </mi>
</mrow>
<annotation encoding="application/x-tex"> O </annotation>
</semantics>
</math></span></span>. The proxy <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> M </mi>
</mrow>
<annotation encoding="application/x-tex"> M </annotation>
</semantics>
</math></span></span> is updated with this newly acquired and labeled batch, and the process is repeated for <span><span><math xmlns="http://www.w3.org/1998/Math/MathML">
<semantics>
<mrow>
<mi> N </mi>
</mrow>
<annotation encoding="application/x-tex"> N </annotation>
</semantics>
</math></span></span> iterations.<br> Applying this procedure to the molecule setting, we again find that we can generate better molecules. This showcases the importance of having these diverse candidates.</p>
<center>
<img src="http://fakehost/test/gfn_fig7.png" width="325px">
</center>
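<p>The active learning loop described above can be sketched as follows. This is a toy under loudly stated assumptions: the oracle is a cheap synthetic two-mode function over the integers 0 to 99, the proxy M is a 1-nearest-neighbour regressor, and instead of training a GFlowNet, the hypothetical <code>train_sampler</code> builds the exact distribution p(x) ∝ M(x)<sup>β</sup> over the small pool (with a hypothetical β = 4), i.e. what a perfectly trained sampler would produce. Only the control flow mirrors the text: fit M on D<sub>0</sub>, fit the sampler to M, draw a batch of k, label it with O, update M, and repeat N times.</p>
<pre><code class="language-python">import math
import random

def active_learning(oracle, fit_proxy, train_sampler, D0, k, N, rng):
    """Fit proxy M on D, train the sampler on M, draw a batch of k candidates,
    label them with the oracle O, add them to D, and repeat N times."""
    D = list(D0)                              # (x, R(x)) pairs with true oracle rewards
    M = fit_proxy(D)
    for _ in range(N):
        sample = train_sampler(M)             # stand-in for training pi_theta on M
        B = [sample(rng) for _ in range(k)]   # x_i ~ pi_theta
        D.extend((x, oracle(x)) for x in B)   # the only (expensive) oracle calls
        M = fit_proxy(D)                      # update the proxy with the new batch
    return D

# ---- Hypothetical toy stand-ins (not the paper's models) ----
POOL = list(range(100))

def oracle(x):
    # Toy two-mode reward over integers 0..99; strictly positive.
    return math.exp(-((x - 70) / 6) ** 2) + 0.5 * math.exp(-((x - 20) / 4) ** 2) + 0.01

def fit_proxy(D):
    # 1-nearest-neighbour regressor on the labeled pairs.
    def M(x):
        _, r = min(D, key=lambda pair: abs(pair[0] - x))
        return r
    return M

def train_sampler(M, beta=4.0):
    # Exact sampler with p(x) proportional to M(x)**beta over the finite pool:
    # the distribution a perfectly trained sampler fit to M would produce.
    weights = [M(x) ** beta for x in POOL]
    return lambda rng: rng.choices(POOL, weights=weights, k=1)[0]

rng = random.Random(0)
D0 = [(x, oracle(x)) for x in rng.sample(POOL, 5)]
D = active_learning(oracle, fit_proxy, train_sampler, D0, k=10, N=5, rng=rng)
best_x, best_r = max(D, key=lambda pair: pair[1])
print(len(D), best_x, round(best_r, 3))
</code></pre>
<p>With k = 10 and N = 5, the oracle is queried exactly 50 times beyond the initial D<sub>0</sub>, so the expensive budget is fixed up front while the cheap proxy absorbs all other queries from the sampler.</p>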
<p> For more figures, experiments and explanations, check out <a href="https://arxiv.org/abs/2106.04399">the paper</a>, or reach out to us!<br>
</p>
</div>
</DIV>