Skip to content

Commit

Permalink
Deploying to gh-pages from @ 1169fc5 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Oct 31, 2024
1 parent a4cc326 commit 597c98b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,24 @@ <h1>Source code for grl.generative_models.conditional_flow_model.independent_con
<span class="c1"># x.shape = (B*N, D)</span>

<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">TensorDict</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">TensorDict</span><span class="p">(</span>
<span class="p">{</span>
<span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">condition</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">condition</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>

<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">solver</span><span class="p">,</span> <span class="n">DPMSolver</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,24 @@ <h1>Source code for grl.generative_models.conditional_flow_model.optimal_transpo
<span class="c1"># x.shape = (B*N, D)</span>

<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">TensorDict</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">TensorDict</span><span class="p">(</span>
<span class="p">{</span>
<span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">condition</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">condition</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>

<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">solver</span><span class="p">,</span> <span class="n">DPMSolver</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,10 +547,16 @@ <h1>Source code for grl.generative_models.diffusion_model.diffusion_model</h1><d
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">TensorDict</span><span class="p">):</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">TensorDict</span><span class="p">(</span>
<span class="p">{</span>
<span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">condition</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">condition</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,26 +650,31 @@ <h1>Source code for grl.generative_models.diffusion_model.energy_conditional_dif
<span class="c1"># x.shape = (B*N, D)</span>

<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">TensorDict</span><span class="p">):</span>
<span class="n">repeated_condition</span> <span class="o">=</span> <span class="n">TensorDict</span><span class="p">(</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">treetensor</span><span class="o">.</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="c1"># condition.shape = (B*N, D)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">TensorDict</span><span class="p">):</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">TensorDict</span><span class="p">(</span>
<span class="p">{</span>
<span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="n">condition</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">condition</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="p">},</span>
<span class="n">batch_size</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">*</span><span class="n">condition</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">extra_batch_size</span><span class="p">])</span>
<span class="p">)</span>
<span class="p">),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">)</span> <span class="o">*</span> <span class="n">condition</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">condition</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">repeated_condition</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">repeated_condition</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span>
<span class="n">condition</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">extra_batch_size</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
<span class="p">)</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Not implemented&quot;</span><span class="p">)</span>

<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">solver</span><span class="p">,</span> <span class="n">DPMSolver</span><span class="p">):</span>

<span class="k">def</span> <span class="nf">noise_function_with_energy_guidance</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">condition</span><span class="p">):</span>
Expand Down

0 comments on commit 597c98b

Please sign in to comment.