@Disty0 (Contributor) commented Dec 7, 2025

What does this PR do?

Adds NewbieAI support to Diffusers.
Adds a pooled_projection_dim config to Lumina2Transformer2DModel and uses pooled projections from the Newbie codebase when it is set to something other than None.

Original NewbieAI model: https://huggingface.co/NewBie-AI/NewBie-image-Exp0.1
NewbieAI in Diffusers format: https://huggingface.co/Disty0/NewBie-image-Exp0.1-Diffusers

Known Issues:

  • JinaClip requires trust_remote_code=True and has to be loaded separately.
  • JinaClip doesn't work with CPU Offload.

Example code:

import torch
from diffusers import NewbiePipeline
from transformers import AutoModel

device = "cuda"
model_path = "Disty0/NewBie-image-Exp0.1-Diffusers"
text_encoder_2 = AutoModel.from_pretrained(model_path, subfolder="text_encoder_2", trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe = NewbiePipeline.from_pretrained(model_path, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16)
del text_encoder_2  # safe to drop the local reference; the pipeline holds its own

# Enable memory optimizations.
pipe.enable_model_cpu_offload(device=device)

prompt = """
  <character_1>
  <n>$character_1$</n>
  <gender>1girl</gender>
  <appearance>chibi, red_eyes, blue_hair, long_hair, hair_between_eyes, head_tilt, tareme, closed_mouth</appearance>
  <clothing>school_uniform, serafuku, white_sailor_collar, white_shirt, short_sleeves, red_neckerchief, bow, blue_skirt, miniskirt, pleated_skirt, blue_hat, mini_hat, thighhighs, grey_thighhighs, black_shoes, mary_janes</clothing>
  <expression>happy, smile</expression>
  <action>standing, holding, holding_briefcase</action>
  <position>center_left</position>
  </character_1>

  <character_2>
  <n>$character_2$</n>
  <gender>1girl</gender>
  <appearance>chibi, red_eyes, pink_hair, long_hair, very_long_hair, multi-tied_hair, open_mouth</appearance>
  <clothing>school_uniform, serafuku, white_sailor_collar, white_shirt, short_sleeves, red_neckerchief, bow, red_skirt, miniskirt, pleated_skirt, hair_bow, multiple_hair_bows, white_bow, ribbon_trim, ribbon-trimmed_bow, white_thighhighs, black_shoes, mary_janes, bow_legwear, bare_arms</clothing>
  <expression>happy, smile</expression>
  <action>standing, holding, holding_briefcase, waving</action>
  <position>center_right</position>
  </character_2>

  <general_tags>
  <count>2girls, multiple_girls</count>
  <style>anime_style, digital_art</style>
  <background>white_background, simple_background</background>
  <atmosphere>cheerful</atmosphere>
  <quality>high_resolution, detailed</quality>
  <objects>briefcase</objects>
  <other>alternate_costume</other>
  </general_tags>
"""

negative_prompt = "blurry, worst quality, low quality, deformed hands, bad anatomy, extra limbs, poorly drawn face, mutated, extra eyes, bad proportions"

# JinaClip doesn't work with CPU offload (see Known Issues), so move it to the device manually.
pipe.text_encoder_2 = pipe.text_encoder_2.to(device)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    height=1024,
    width=1024,
    guidance_scale=2.5,
    num_inference_steps=30,
    generator=torch.manual_seed(42),
).images[0]
display(image)  # assumes an IPython/Jupyter environment; use image.save("newbie.png") in a plain script

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Core library:

@Disty0 changed the title from "Add NewbieAI support" to "Add Newbie Image support" on Dec 7, 2025
@david6666666 commented

Any updates?

@woct0rdho commented

FYI: A simplified implementation of Jina CLIP v2 has been merged in ComfyUI, see Comfy-Org/ComfyUI#11415. Compared to the original XLM-RoBERTa, the main difference is that it uses RoPE (rotary position embeddings).

If you'd rather not require trust_remote_code=True to load the official Jina CLIP v2, you could implement it directly in Transformers.
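
For reference, here is a generic sketch of applying RoPE to attention queries and keys (the textbook NeoX-style formulation, not ComfyUI's or Jina's exact code; all names are illustrative):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate the pairs: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q: torch.Tensor, k: torch.Tensor, positions: torch.Tensor, dim: int, base: float = 10000.0):
    # Rotate query/key feature pairs by position-dependent angles so that
    # dot products depend only on relative position.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions.float()[:, None] * inv_freq[None, :]  # (seq_len, dim / 2)
    cos = torch.cat((angles, angles), dim=-1).cos()          # (seq_len, dim)
    sin = torch.cat((angles, angles), dim=-1).sin()
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# Example: 16 tokens, head dimension 64.
q, k = torch.randn(16, 64), torch.randn(16, 64)
q_rot, k_rot = apply_rope(q, k, torch.arange(16), dim=64)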

@vladmandic (Contributor) commented

gentle ping @sayakpaul @yiyixuxu

@yiyixuxu requested a review from dg845 on January 8, 2026.
@dg845 (Collaborator) commented on this diff hunk, Jan 9, 2026:

    self,
    hidden_size: int = 4096,
    cap_feat_dim: int = 2048,
    pooled_projection_dim: Optional[int] = None,

Instead of modifying transformer_lumina2.py directly, can you implement the modified transformer in a separate file transformer_newbie.py? You can copy over model classes as needed with the # Copied from mechanism:

class NewbieCombinedTimestepCaptionEmbedding(nn.Module):
    # Modified from Lumina2CombinedTimestepCaptionEmbedding to accept an additional `pooled_projection_dim`
    ...

# Since this class is unchanged, use the # Copied from comment
# Copied from diffusers.models.transformers.transformer_lumina2.Lumina2AttnProcessor2_0
class Lumina2AttnProcessor2_0:
    ...

The # Copied from comments will ensure that the code stays synced between the two files.
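
For illustration, a minimal sketch of what the modified embedding class could look like (hypothetical internals; only the optional pooled_projection_dim branch is the point here, and the real Lumina2 module also builds timestep embeddings):

from typing import Optional

import torch
from torch import nn


class NewbieCombinedTimestepCaptionEmbedding(nn.Module):
    # Sketch only: pooled_projection_dim=None keeps the original Lumina2
    # behavior; a non-None value adds a projection for pooled text embeddings.
    def __init__(self, hidden_size: int = 4096, cap_feat_dim: int = 2048, pooled_projection_dim: Optional[int] = None):
        super().__init__()
        self.caption_embedder = nn.Linear(cap_feat_dim, hidden_size)
        self.pooled_embedder = None
        if pooled_projection_dim is not None:
            self.pooled_embedder = nn.Linear(pooled_projection_dim, hidden_size)

    def forward(self, caption_feats: torch.Tensor, pooled_projections: Optional[torch.Tensor] = None) -> torch.Tensor:
        emb = self.caption_embedder(caption_feats)
        if self.pooled_embedder is not None and pooled_projections is not None:
            # Broadcast the pooled text embedding across the sequence dimension.
            emb = emb + self.pooled_embedder(pooled_projections).unsqueeze(1)
        return emb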

>>> device = "cuda"
>>> model_path = "Disty0/NewBie-image-Exp0.1-Diffusers"
>>> text_encoder_2 = AutoModel.from_pretrained(model_path, subfolder="text_encoder_2", trust_remote_code=True, torch_dtype=torch.bfloat16)
@dg845 (Collaborator) commented:

Can you implement the JinaClip text encoder model in the Newbie pipeline directory, for example at src/diffusers/pipelines/newbie/modeling_jina_clip.py? The model should inherit from diffusers.models.modeling_utils.ModelMixin and diffusers.configuration_utils.ConfigMixin. This will help ensure that text_encoder_2 works with e.g. CPU offloading and removes the dependency on trust_remote_code=True.
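
As a rough skeleton of that suggestion (class name, config fields, and layers are placeholders; only the ModelMixin/ConfigMixin structure is the point, and inside the diffusers repo these imports would be relative):

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class JinaClipTextModel(ModelMixin, ConfigMixin):
    # Inheriting from ModelMixin/ConfigMixin makes the model loadable via
    # from_pretrained(..., subfolder="text_encoder_2") and compatible with
    # device-placement utilities such as CPU offloading.
    @register_to_config
    def __init__(self, hidden_size: int = 1024, num_layers: int = 24):
        super().__init__()
        # Placeholder stack; the real model would implement Jina CLIP v2's
        # XLM-RoBERTa-with-RoPE text tower.
        self.layers = nn.ModuleList(nn.Identity() for _ in range(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states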

return timesteps, num_inference_steps


class NewbiePipeline(Lumina2Pipeline):
@dg845 (Collaborator) commented:

Can NewbiePipeline inherit directly from DiffusionPipeline (and Lumina2LoraLoaderMixin, if appropriate) rather than Lumina2Pipeline? You can copy over methods (with the # Copied from mechanism) and properties as needed:

class NewbiePipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
    ...
    # Copied from diffusers.pipelines.lumina2.Lumina2Pipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        ...
    ...

You don't need to copy over the VAE slicing/tiling methods, as these will be deprecated and users can always call the analogous methods directly on the VAE with e.g. pipe.vae.enable_slicing().
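
Concretely, the constructor of such a pipeline might look like the sketch below (the component list is hypothetical, inferred from the example code earlier in this thread; register_modules is the standard DiffusionPipeline registration hook):

from diffusers import DiffusionPipeline
from diffusers.loaders import Lumina2LoraLoaderMixin


class NewbiePipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
    def __init__(self, transformer, scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2):
        super().__init__()
        # register_modules exposes the components to from_pretrained,
        # save_pretrained, and the offloading hooks.
        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
        )

    # Copied from diffusers.pipelines.lumina2.Lumina2Pipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        ...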

@dg845 (Collaborator) left a review:

@Disty0, thanks for the PR and thanks for your patience! The main items from the review are as follows:

  1. Implement the DiT in a separate file transformer_newbie.py.
  2. Implement JinaClip in a file in the newbie/ pipeline directory.
  3. Have NewbiePipeline inherit directly from DiffusionPipeline.
