Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
8d1a17c
re-add RL model code
natolambert Jul 19, 2022
84e94d7
match model forward api
natolambert Jul 19, 2022
f67b036
add register_to_config, pass training tests
natolambert Jul 26, 2022
e42d1c0
fix tests, update forward outputs
natolambert Oct 3, 2022
2dd514e
remove unused code, some comments
natolambert Oct 3, 2022
b4c6188
add to docs
natolambert Oct 3, 2022
c53bba9
remove extra embedding code
natolambert Oct 6, 2022
effcbdb
unify time embedding
natolambert Oct 7, 2022
7865231
remove conv1d output sequential
natolambert Oct 8, 2022
35b0a43
remove sequential from conv1dblock
natolambert Oct 8, 2022
9b1379d
style and deleting duplicated code
natolambert Oct 8, 2022
e97a610
clean files
natolambert Oct 8, 2022
8642560
remove unused variables
natolambert Oct 10, 2022
f58c915
clean variables
natolambert Oct 10, 2022
ad8376d
Merge branch 'main' into rl
natolambert Oct 10, 2022
3b08bea
add 1d resnet block structure for downsample
natolambert Oct 10, 2022
aae2a9a
rename as unet1d
natolambert Oct 10, 2022
dd872af
fix renaming
natolambert Oct 10, 2022
9b67bb7
rename files
natolambert Oct 12, 2022
db012eb
add get_block(...) api
natolambert Oct 12, 2022
4db6e0b
unify args for model1d like model2d
natolambert Oct 12, 2022
634a526
minor cleaning
natolambert Oct 12, 2022
aebf547
fix docs
natolambert Oct 12, 2022
305ecd8
improve 1d resnet blocks
natolambert Oct 12, 2022
42855b9
Merge branch 'main' into rl
natolambert Oct 12, 2022
95d3a1c
fix tests, remove permuts
natolambert Oct 12, 2022
6cbb73b
fix style
natolambert Oct 12, 2022
ffb7355
add output activation
natolambert Oct 18, 2022
a6314f6
rename flax blocks file
natolambert Oct 18, 2022
48a7414
Add Value Function and corresponding example script to Diffuser imple…
bglick13 Oct 21, 2022
3acddb5
update post merge of scripts
natolambert Oct 21, 2022
713e8f2
add mdiblock / outblock architecture
natolambert Oct 24, 2022
268ebdf
Pipeline cleanup (#947)
bglick13 Oct 24, 2022
daa05fb
Update src/diffusers/models/unet_1d_blocks.py
Oct 24, 2022
ea5f231
Update tests/test_models_unet.py
Oct 24, 2022
4f7a3a4
RL Cleanup v2 (#965)
bglick13 Oct 24, 2022
d90b8b1
fix quality in tests
natolambert Oct 24, 2022
ad8b6cf
fix quality style, split test file
natolambert Oct 24, 2022
e06a4a4
Merge branch 'main' into rl
natolambert Oct 24, 2022
99b2c81
fix checks / tests
natolambert Oct 24, 2022
de4b6e4
make timesteps closer to main
natolambert Oct 25, 2022
ef6ca1f
unify block API
natolambert Oct 25, 2022
6e3485c
Merge branch 'main' into rl
natolambert Oct 25, 2022
e6f1a83
unify forward api
natolambert Oct 25, 2022
c35a925
delete lines in examples
natolambert Oct 25, 2022
949b93a
style
natolambert Oct 25, 2022
2f6462b
examples style
natolambert Oct 25, 2022
a2dd559
all tests pass
natolambert Oct 26, 2022
39dff73
make style
natolambert Oct 26, 2022
d5eedff
make dance_diff test pass
natolambert Oct 26, 2022
faeacd5
Refactoring RL PR (#1200)
Nov 8, 2022
be25030
Merge branch 'main' into rl
natolambert Nov 8, 2022
72b7ee8
hotfix for tests
natolambert Nov 8, 2022
cf76a2d
quality
natolambert Nov 8, 2022
2290356
fix some tests
natolambert Nov 9, 2022
a061f7e
change defaults
natolambert Nov 9, 2022
0c58758
more mps test fixes
natolambert Nov 9, 2022
691ddee
unet1d defaults
natolambert Nov 9, 2022
4948ca7
do not default import experimental
natolambert Nov 9, 2022
ac88677
defaults for tests
natolambert Nov 9, 2022
ba204db
fix tests
natolambert Nov 9, 2022
915c41e
fix-copies
natolambert Nov 9, 2022
c901889
Merge branch 'main' into rl
natolambert Nov 14, 2022
becc803
fix
natolambert Nov 14, 2022
9b8e5ee
changes per Patrik's comments (#1285)
bglick13 Nov 14, 2022
3684a8c
fix renaming
natolambert Nov 14, 2022
ebdef16
skip more mps tests
natolambert Nov 14, 2022
a259aae
last test fix
natolambert Nov 14, 2022
1f7702c
Update examples/rl/README.md
Nov 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add register_to_config, pass training tests
  • Loading branch information
natolambert committed Oct 3, 2022
commit f67b036e862b34741282af0d9477c04326ea9cfb
3 changes: 2 additions & 1 deletion src/diffusers/models/unet_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D

from ..configuration_utils import ConfigMixin
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding

Expand Down Expand Up @@ -57,6 +57,7 @@ def forward(self, x):


class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
@register_to_config
def __init__(
self,
training_horizon=128,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,12 @@ def input_shape(self):
def output_shape(self):
return (4, 16, 14)

def test_ema_training(self):
pass

def test_training(self):
pass

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"training_horizon": 128,
Expand Down