Models API

embodied_gen.models.texture_model

build_texture_gen_pipe

build_texture_gen_pipe(base_ckpt_dir: str, controlnet_ckpt: str = None, ip_adapt_scale: float = 0, device: str = 'cuda') -> DiffusionPipeline

Build and initialize the Kolors + ControlNet (optional IP-Adapter) texture generation pipeline.

Loads Kolors tokenizer, text encoder (ChatGLM), VAE, UNet, scheduler and (optionally) a ControlNet checkpoint plus IP-Adapter vision encoder. If controlnet_ckpt is not provided, the default multi-view texture ControlNet weights are downloaded automatically from the hub. When ip_adapt_scale > 0 an IP-Adapter vision encoder and its weights are also loaded and activated.

Parameters:

Name Type Description Default
base_ckpt_dir str

Root directory where Kolors (and optionally Kolors-IP-Adapter-Plus) weights are or will be stored. Required subfolders: Kolors/{text_encoder,vae,unet,scheduler}.

required
controlnet_ckpt str

Directory containing a ControlNet checkpoint (safetensors). If None, downloads the default texture_gen_mv_v1 snapshot.

None
ip_adapt_scale float

Strength (>=0) of IP-Adapter conditioning. Set >0 to enable IP-Adapter; typical values: 0.4-0.8. Default: 0 (disabled).

0
device str

Target device to move the pipeline to (e.g. "cuda", "cuda:0", "cpu"). Default: "cuda".

'cuda'

Returns:

Name Type Description
DiffusionPipeline DiffusionPipeline

A configured StableDiffusionXLControlNetImg2ImgPipeline ready for multi-view texture generation (with optional IP-Adapter support).

Example

Initialize pipeline with IP-Adapter enabled.

from embodied_gen.models.texture_model import build_texture_gen_pipe
ip_adapt_scale = 0.7
PIPELINE = build_texture_gen_pipe(
    base_ckpt_dir="./weights",
    ip_adapt_scale=ip_adapt_scale,
    device="cuda",
)
PIPELINE.set_ip_adapter_scale([ip_adapt_scale])

Initialize pipeline without IP-Adapter.

from embodied_gen.models.texture_model import build_texture_gen_pipe
PIPELINE = build_texture_gen_pipe(
    base_ckpt_dir="./weights",
    ip_adapt_scale=0,
    device="cuda",
)

Source code in embodied_gen/models/texture_model.py
def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: str = None,
    ip_adapt_scale: float = 0,
    device: str = "cuda",
) -> DiffusionPipeline:
    """Build and initialize the Kolors + ControlNet (optional IP-Adapter) texture generation pipeline.

    Loads Kolors tokenizer, text encoder (ChatGLM), VAE, UNet, scheduler and (optionally)
    a ControlNet checkpoint plus IP-Adapter vision encoder. If ``controlnet_ckpt`` is
    not provided, the default multi-view texture ControlNet weights are downloaded
    automatically from the hub. When ``ip_adapt_scale > 0`` an IP-Adapter vision
    encoder and its weights are also loaded and activated.

    Args:
        base_ckpt_dir (str):
            Root directory where Kolors (and optionally Kolors-IP-Adapter-Plus) weights
            are or will be stored. Required subfolders: ``Kolors/{text_encoder,vae,unet,scheduler}``.
        controlnet_ckpt (str, optional):
            Directory containing a ControlNet checkpoint (safetensors). If ``None``,
            downloads the default ``texture_gen_mv_v1`` snapshot.
        ip_adapt_scale (float, optional):
            Strength (>=0) of IP-Adapter conditioning. Set >0 to enable IP-Adapter;
            typical values: 0.4-0.8. Default: 0 (disabled).
        device (str, optional):
            Target device to move the pipeline to (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``).
            Default: ``"cuda"``.

    Returns:
        DiffusionPipeline: A configured
        ``StableDiffusionXLControlNetImg2ImgPipeline`` ready for multi-view texture
        generation (with optional IP-Adapter support).

    Example:
        Initialize pipeline with IP-Adapter enabled.
        ```python
        from embodied_gen.models.texture_model import build_texture_gen_pipe
        ip_adapt_scale = 0.7
        PIPELINE = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            ip_adapt_scale=ip_adapt_scale,
            device="cuda",
        )
        PIPELINE.set_ip_adapter_scale([ip_adapt_scale])
        ```
        Initialize pipeline without IP-Adapter.
        ```python
        from embodied_gen.models.texture_model import build_texture_gen_pipe
        PIPELINE = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            ip_adapt_scale=0,
            device="cuda",
        )
        ```
    """

    download_kolors_weights(f"{base_ckpt_dir}/Kolors")
    logger.info(f"Load Kolors weights...")
    tokenizer = ChatGLMTokenizer.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder"
    )
    text_encoder = ChatGLMModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
    ).half()
    vae = AutoencoderKL.from_pretrained(
        f"{base_ckpt_dir}/Kolors/vae", revision=None
    ).half()
    unet = UNet2DConditionModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/unet", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(
        f"{base_ckpt_dir}/Kolors/scheduler"
    )

    if controlnet_ckpt is None:
        suffix = "texture_gen_mv_v1"  # "geo_cond_mv"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        controlnet_ckpt = os.path.join(model_path, suffix)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_ckpt, use_safetensors=True
    ).half()

    # IP-Adapter model
    image_encoder = None
    clip_image_processor = None
    if ip_adapt_scale > 0:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
            # ignore_mismatched_sizes=True,
        ).to(dtype=torch.float16)
        ip_img_size = 336
        clip_image_processor = CLIPImageProcessor(
            size=ip_img_size, crop_size=ip_img_size
        )

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if ip_adapt_scale > 0:
        if hasattr(pipe.unet, "encoder_hid_proj"):
            pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
        pipe.load_ip_adapter(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
            subfolder="",
            weight_name=["ip_adapter_plus_general.bin"],
        )
        pipe.set_ip_adapter_scale([ip_adapt_scale])

    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()

    return pipe

embodied_gen.models.gs_model

GaussianOperator dataclass

GaussianOperator(_opacities: Tensor, _means: Tensor, _scales: Tensor, _quats: Tensor, _rgbs: Optional[Tensor] = None, _features_dc: Optional[Tensor] = None, _features_rest: Optional[Tensor] = None, sh_degree: Optional[int] = 0, device: str = 'cuda')

Bases: GaussianBase

Gaussian Splatting operator.

Supports transformation, scaling, color computation, and rasterization-based rendering.

Inherits

GaussianBase: Base class with Gaussian params (means, scales, etc.)

Functionality includes:

- Applying instance poses to transform Gaussian means and quaternions.
- Scaling Gaussians to a real-world size.
- Computing colors using spherical harmonics.
- Rendering images via differentiable rasterization.
- Exporting transformed and rescaled models to .ply format.

get_gaussians
get_gaussians(c2w: Tensor = None, instance_pose: Tensor = None, apply_activate: bool = False) -> GaussianBase

Get Gaussian data under the given instance_pose.

Source code in embodied_gen/models/gs_model.py
def get_gaussians(
    self,
    c2w: torch.Tensor = None,
    instance_pose: torch.Tensor = None,
    apply_activate: bool = False,
) -> "GaussianBase":
    """Get Gaussian data under the given instance_pose."""
    if c2w is None:
        c2w = torch.eye(4).to(self.device)

    if instance_pose is not None:
        # compute the transformed gs means and quats
        world_means, world_quats = self._compute_transform(
            self._means, self._quats, instance_pose.float().to(self.device)
        )
    else:
        world_means, world_quats = self._means, self._quats

    # get colors of gaussians
    if self._features_rest is not None:
        colors = torch.cat(
            (self._features_dc[:, None, :], self._features_rest), dim=1
        )
    else:
        colors = self._features_dc[:, None, :]

    if self.sh_degree > 0:
        viewdirs = world_means.detach() - c2w[..., :3, 3]  # (N, 3)
        viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
        rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
        rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
    else:
        rgbs = torch.sigmoid(colors[:, 0, :])

    gs_dict = dict(
        _means=world_means,
        _opacities=(
            torch.sigmoid(self._opacities)
            if apply_activate
            else self._opacities
        ),
        _rgbs=rgbs,
        _scales=(
            torch.exp(self._scales) if apply_activate else self._scales
        ),
        _quats=self.quat_norm(world_quats),
        _features_dc=self._features_dc,
        _features_rest=self._features_rest,
        sh_degree=self.sh_degree,
        device=self.device,
    )

    return GaussianOperator(**gs_dict)
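
A minimal usage sketch for get_gaussians, built on a hand-made toy GaussianOperator (tensor shapes here are assumptions; real instances are normally loaded from a trained Gaussian Splatting asset rather than constructed by hand):

import torch

from embodied_gen.models.gs_model import GaussianOperator

device = "cuda" if torch.cuda.is_available() else "cpu"
n = 1024
gs = GaussianOperator(
    _opacities=torch.zeros(n, 1).to(device),
    _means=torch.randn(n, 3).to(device),
    _scales=torch.zeros(n, 3).to(device),
    _quats=torch.tensor([[1.0, 0.0, 0.0, 0.0]]).repeat(n, 1).to(device),
    _features_dc=torch.rand(n, 3).to(device),
    sh_degree=0,
    device=device,
)

# Apply a 4x4 object-to-world pose and get activated opacities/scales/colors.
instance_pose = torch.eye(4).to(device)
transformed = gs.get_gaussians(instance_pose=instance_pose, apply_activate=True)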

embodied_gen.models.layout

LayoutDesigner

LayoutDesigner(gpt_client: GPTclient, system_prompt: str, verbose: bool = False)

Bases: object

A class for querying GPT-based scene layout reasoning and formatting responses.

Attributes:

Name Type Description
prompt str

The system prompt for GPT.

verbose bool

Whether to log responses.

gpt_client GPTclient

The GPT client instance.

Methods:

Name Description
query

Query GPT with a prompt and parameters.

format_response

Parse and clean JSON response.

format_response_repair

Repair and parse JSON response.

save_output

Save output to file.

__call__

Query and process output.

Source code in embodied_gen/models/layout.py
def __init__(
    self,
    gpt_client: GPTclient,
    system_prompt: str,
    verbose: bool = False,
) -> None:
    self.prompt = system_prompt.strip()
    self.verbose = verbose
    self.gpt_client = gpt_client
__call__
__call__(prompt: str, save_path: str = None, params: dict = None) -> dict | str

Query GPT and process the output.

Parameters:

Name Type Description Default
prompt str

User prompt.

required
save_path str

Path to save output.

None
params dict

GPT parameters.

None

Returns:

Type Description
dict | str

dict | str: Output data.

Source code in embodied_gen/models/layout.py
def __call__(
    self, prompt: str, save_path: str = None, params: dict = None
) -> dict | str:
    """Query GPT and process the output.

    Args:
        prompt (str): User prompt.
        save_path (str, optional): Path to save output.
        params (dict, optional): GPT parameters.

    Returns:
        dict | str: Output data.
    """
    response = self.query(prompt, params=params)
    output = self.format_response_repair(response)
    if save_path:
        self.save_output(output, save_path)

    return output
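
A minimal end-to-end sketch for LayoutDesigner, assuming an already-configured GPTclient instance (its construction is not shown on this page) and an illustrative system prompt:

from embodied_gen.models.layout import LayoutDesigner

designer = LayoutDesigner(
    gpt_client=gpt_client,  # an existing, already-configured GPTclient
    system_prompt="You are a scene layout planner. Respond with JSON only.",
    verbose=True,
)
# Query, repair/parse the JSON reply, and optionally persist it to disk.
layout = designer(
    "Put the apples on the table on the plate",
    save_path="outputs/layout.json",
)
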
format_response
format_response(response: str) -> dict

Format and parse GPT response as JSON.

Parameters:

Name Type Description Default
response str

Raw GPT response.

required

Returns:

Name Type Description
dict dict

Parsed JSON output.

Raises:

Type Description
JSONDecodeError

If parsing fails.

Source code in embodied_gen/models/layout.py
def format_response(self, response: str) -> dict:
    """Format and parse GPT response as JSON.

    Args:
        response (str): Raw GPT response.

    Returns:
        dict: Parsed JSON output.

    Raises:
        json.JSONDecodeError: If parsing fails.
    """
    cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
    try:
        output = json.loads(cleaned)
    except json.JSONDecodeError as e:
        # json.JSONDecodeError requires (msg, doc, pos); re-raise with context.
        raise json.JSONDecodeError(
            f"Failed to parse JSON response: {response}", e.doc, e.pos
        ) from e

    return output
format_response_repair
format_response_repair(response: str) -> dict

Repair and parse possibly broken JSON response.

Parameters:

Name Type Description Default
response str

Raw GPT response.

required

Returns:

Name Type Description
dict dict

Parsed JSON output.

Source code in embodied_gen/models/layout.py
def format_response_repair(self, response: str) -> dict:
    """Repair and parse possibly broken JSON response.

    Args:
        response (str): Raw GPT response.

    Returns:
        dict: Parsed JSON output.
    """
    return json_repair.loads(response)
query
query(prompt: str, params: dict = None) -> str

Query GPT with the system prompt and user prompt.

Parameters:

Name Type Description Default
prompt str

User prompt.

required
params dict

GPT parameters.

None

Returns:

Name Type Description
str str

GPT response.

Source code in embodied_gen/models/layout.py
def query(self, prompt: str, params: dict = None) -> str:
    """Query GPT with the system prompt and user prompt.

    Args:
        prompt (str): User prompt.
        params (dict, optional): GPT parameters.

    Returns:
        str: GPT response.
    """
    full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""

    response = self.gpt_client.query(
        text_prompt=full_prompt,
        params=params,
    )

    if self.verbose:
        logger.info(f"Response: {response}")

    return response
save_output
save_output(output: dict, save_path: str) -> None

Save output dictionary to a file.

Parameters:

Name Type Description Default
output dict

Output data.

required
save_path str

Path to save the file.

required
Source code in embodied_gen/models/layout.py
def save_output(self, output: dict, save_path: str) -> None:
    """Save output dictionary to a file.

    Args:
        output (dict): Output data.
        save_path (str): Path to save the file.
    """
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(output, f, indent=4)

build_scene_layout

build_scene_layout(task_desc: str, output_path: str = None, gpt_params: dict = None) -> LayoutInfo

Build a 3D scene layout from a natural language task description.

This function uses GPT-based reasoning to generate a structured scene layout, including object hierarchy, spatial relations, and style descriptions.

Parameters:

Name Type Description Default
task_desc str

Natural language description of the robotic task.

required
output_path str

Path to save the visualized scene tree.

None
gpt_params dict

Parameters for GPT queries.

None

Returns:

Name Type Description
LayoutInfo LayoutInfo

Structured layout information for the scene.

Example
from embodied_gen.models.layout import build_scene_layout
layout_info = build_scene_layout(
    task_desc="Put the apples on the table on the plate",
    output_path="outputs/scene_tree.jpg",
)
print(layout_info)
Source code in embodied_gen/models/layout.py
def build_scene_layout(
    task_desc: str, output_path: str = None, gpt_params: dict = None
) -> LayoutInfo:
    """Build a 3D scene layout from a natural language task description.

    This function uses GPT-based reasoning to generate a structured scene layout,
    including object hierarchy, spatial relations, and style descriptions.

    Args:
        task_desc (str): Natural language description of the robotic task.
        output_path (str, optional): Path to save the visualized scene tree.
        gpt_params (dict, optional): Parameters for GPT queries.

    Returns:
        LayoutInfo: Structured layout information for the scene.

    Example:
        ```py
        from embodied_gen.models.layout import build_scene_layout
        layout_info = build_scene_layout(
            task_desc="Put the apples on the table on the plate",
            output_path="outputs/scene_tree.jpg",
        )
        print(layout_info)
        ```
    """
    layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
    layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
    object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
    obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
    objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=gpt_params)
    layout_info = LayoutInfo(
        layout_tree, layout_relation, objs_desc, object_mapping
    )

    if output_path is not None:
        visualizer = SceneTreeVisualizer(layout_info)
        visualizer.render(save_path=output_path)
        logger.info(f"Scene hierarchy tree saved to {output_path}")

    return layout_info

embodied_gen.models.text_model

build_text2img_ip_pipeline

build_text2img_ip_pipeline(ckpt_dir: str, ref_scale: float, device: str = 'cuda') -> StableDiffusionXLPipelineIP

Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.

Parameters:

Name Type Description Default
ckpt_dir str

Directory containing model checkpoints.

required
ref_scale float

Reference scale for IP-Adapter.

required
device str

Device for inference.

'cuda'

Returns:

Name Type Description
StableDiffusionXLPipelineIP StableDiffusionXLPipeline

Configured pipeline.

Example
from embodied_gen.models.text_model import build_text2img_ip_pipeline
pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
Source code in embodied_gen/models/text_model.py
def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation.

    Args:
        ckpt_dir (str): Directory containing model checkpoints.
        ref_scale (float): Reference scale for IP-Adapter.
        device (str, optional): Device for inference.

    Returns:
        StableDiffusionXLPipelineIP: Configured pipeline.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_ip_pipeline
        pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
        ```
    """
    download_kolors_weights(ckpt_dir)

    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=torch.float16)
    clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

    pipe = StableDiffusionXLPipelineIP(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if hasattr(pipe.unet, "encoder_hid_proj"):
        pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

    pipe.load_ip_adapter(
        f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipe.set_ip_adapter_scale([ref_scale])

    pipe = pipe.to(device)
    pipe.image_encoder = pipe.image_encoder.to(device)
    pipe.enable_model_cpu_offload()
    # pipe.enable_xformers_memory_efficient_attention()
    # pipe.enable_vae_slicing()

    return pipe

build_text2img_pipeline

build_text2img_pipeline(ckpt_dir: str, device: str = 'cuda') -> StableDiffusionXLPipeline

Builds a Stable Diffusion XL pipeline for text-to-image generation.

Parameters:

Name Type Description Default
ckpt_dir str

Directory containing model checkpoints.

required
device str

Device for inference.

'cuda'

Returns:

Name Type Description
StableDiffusionXLPipeline StableDiffusionXLPipeline

Configured pipeline.

Example
from embodied_gen.models.text_model import build_text2img_pipeline
pipe = build_text2img_pipeline("weights/Kolors")
Source code in embodied_gen/models/text_model.py
def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Builds a Stable Diffusion XL pipeline for text-to-image generation.

    Args:
        ckpt_dir (str): Directory containing model checkpoints.
        device (str, optional): Device for inference.

    Returns:
        StableDiffusionXLPipeline: Configured pipeline.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_pipeline
        pipe = build_text2img_pipeline("weights/Kolors")
        ```
    """
    download_kolors_weights(ckpt_dir)

    text_encoder = ChatGLMModel.from_pretrained(
        f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
    ).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
    vae = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    pipe = pipe.to(device)
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()

    return pipe

download_kolors_weights

download_kolors_weights(local_dir: str = 'weights/Kolors') -> None

Downloads Kolors model weights from HuggingFace.

Parameters:

Name Type Description Default
local_dir str

Local directory to store weights.

'weights/Kolors'
Source code in embodied_gen/models/text_model.py
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
    """Downloads Kolors model weights from HuggingFace.

    Args:
        local_dir (str, optional): Local directory to store weights.
    """
    logger.info(f"Download kolors weights from huggingface...")
    os.makedirs(local_dir, exist_ok=True)
    subprocess.run(
        [
            "huggingface-cli",
            "download",
            "--resume-download",
            "Kwai-Kolors/Kolors",
            "--local-dir",
            local_dir,
        ],
        check=True,
    )

    ip_adapter_path = f"{local_dir}/../Kolors-IP-Adapter-Plus"
    subprocess.run(
        [
            "huggingface-cli",
            "download",
            "--resume-download",
            "Kwai-Kolors/Kolors-IP-Adapter-Plus",
            "--local-dir",
            ip_adapter_path,
        ],
        check=True,
    )
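
Typical usage is a one-off call before building any of the pipelines above; the target directory below is the library default:

from embodied_gen.models.text_model import download_kolors_weights

# Downloads Kwai-Kolors/Kolors into weights/Kolors and the IP-Adapter-Plus
# weights into the sibling weights/Kolors-IP-Adapter-Plus directory.
download_kolors_weights("weights/Kolors")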

text2img_gen

text2img_gen(prompt: str, n_sample: int, guidance_scale: float, pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP, ip_image: Image | str = None, image_wh: tuple[int, int] = [1024, 1024], infer_step: int = 50, ip_image_size: int = 512, seed: int = None) -> list[Image.Image]

Generates images from text prompts using a Stable Diffusion XL pipeline.

Parameters:

Name Type Description Default
prompt str

Text prompt for image generation.

required
n_sample int

Number of images to generate.

required
guidance_scale float

Guidance scale for diffusion.

required
pipeline StableDiffusionXLPipeline | StableDiffusionXLPipelineIP

Pipeline instance.

required
ip_image Image | str

Reference image for IP-Adapter.

None
image_wh tuple[int, int]

Output image size (width, height).

[1024, 1024]
infer_step int

Number of inference steps.

50
ip_image_size int

Size for IP-Adapter image.

512
seed int

Random seed.

None

Returns:

Type Description
list[Image]

list[Image.Image]: List of generated images.

Example
from embodied_gen.models.text_model import text2img_gen
images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
images[0].save("banana.png")
Source code in embodied_gen/models/text_model.py
def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str = None,
    image_wh: tuple[int, int] = [1024, 1024],
    infer_step: int = 50,
    ip_image_size: int = 512,
    seed: int = None,
) -> list[Image.Image]:
    """Generates images from text prompts using a Stable Diffusion XL pipeline.

    Args:
        prompt (str): Text prompt for image generation.
        n_sample (int): Number of images to generate.
        guidance_scale (float): Guidance scale for diffusion.
        pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance.
        ip_image (Image.Image | str, optional): Reference image for IP-Adapter.
        image_wh (tuple[int, int], optional): Output image size (width, height).
        infer_step (int, optional): Number of inference steps.
        ip_image_size (int, optional): Size for IP-Adapter image.
        seed (int, optional): Random seed.

    Returns:
        list[Image.Image]: List of generated images.

    Example:
        ```py
        from embodied_gen.models.text_model import text2img_gen
        images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5)
        images[0].save("banana.png")
        ```
    """
    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
    logger.info(f"Processing prompt: {prompt}")

    generator = None
    if seed is not None:
        generator = torch.Generator(pipeline.device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
        generator=generator,
    )
    if ip_image is not None:
        if isinstance(ip_image, str):
            ip_image = Image.open(ip_image)
        ip_image = ip_image.resize((ip_image_size, ip_image_size))
        kwargs.update(ip_adapter_image=[ip_image])

    return pipeline(**kwargs).images
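
A sketch combining build_text2img_ip_pipeline with text2img_gen to condition generation on a reference image ("reference.png" is a placeholder path):

from embodied_gen.models.text_model import (
    build_text2img_ip_pipeline,
    text2img_gen,
)

pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
images = text2img_gen(
    prompt="banana",
    n_sample=2,
    guidance_scale=7.5,
    pipeline=pipe,
    ip_image="reference.png",
    seed=0,
)
images[0].save("banana_ref.png")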

embodied_gen.models.sr_model

ImageRealESRGAN

ImageRealESRGAN(outscale: int, model_path: str = None)

A wrapper for Real-ESRGAN-based image super-resolution.

This class uses the RealESRGAN model to perform image upscaling, typically by a factor of 4.

Attributes:

Name Type Description
outscale int

The output image scale factor (e.g., 2, 4).

model_path str

Path to the pre-trained model weights.

Example
from embodied_gen.models.sr_model import ImageRealESRGAN
from PIL import Image

sr_model = ImageRealESRGAN(outscale=4)
img = Image.open("input.png")
upscaled = sr_model(img)
upscaled.save("output.png")

Initializes the RealESRGAN upscaler.

Parameters:

Name Type Description Default
outscale int

Output scale factor.

required
model_path str

Path to model weights.

None
Source code in embodied_gen/models/sr_model.py
def __init__(self, outscale: int, model_path: str = None) -> None:
    """Initializes the RealESRGAN upscaler.

    Args:
        outscale (int): Output scale factor.
        model_path (str, optional): Path to model weights.
    """
    # monkey patch to support torchvision>=0.16
    import torchvision
    from packaging import version

    if version.parse(torchvision.__version__) > version.parse("0.16"):
        import sys
        import types

        import torchvision.transforms.functional as TF

        functional_tensor = types.ModuleType(
            "torchvision.transforms.functional_tensor"
        )
        functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
        sys.modules["torchvision.transforms.functional_tensor"] = (
            functional_tensor
        )

    self.outscale = outscale
    self.upsampler = None

    if model_path is None:
        suffix = "super_resolution"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        model_path = os.path.join(
            model_path, suffix, "RealESRGAN_x4plus.pth"
        )

    self.model_path = model_path
__call__
__call__(image: Union[Image, ndarray]) -> Image.Image

Performs super-resolution on the input image.

Parameters:

Name Type Description Default
image Union[Image, ndarray]

Input image.

required

Returns:

Type Description
Image

Image.Image: Upscaled image.

Source code in embodied_gen/models/sr_model.py
@spaces.GPU
def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
    """Performs super-resolution on the input image.

    Args:
        image (Union[Image.Image, np.ndarray]): Input image.

    Returns:
        Image.Image: Upscaled image.
    """
    self._lazy_init()

    if isinstance(image, Image.Image):
        image = np.array(image)

    with torch.no_grad():
        output, _ = self.upsampler.enhance(image, outscale=self.outscale)

    return Image.fromarray(output)

ImageStableSR

ImageStableSR(model_path: str = 'stabilityai/stable-diffusion-x4-upscaler', device='cuda')

Super-resolution image upscaler using Stable Diffusion x4 upscaling model.

This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality image super-resolution.

Parameters:

Name Type Description Default
model_path str

Path or HuggingFace repo for the model.

'stabilityai/stable-diffusion-x4-upscaler'
device str

Device for inference.

'cuda'
Example
from embodied_gen.models.sr_model import ImageStableSR
from PIL import Image

sr_model = ImageStableSR()
img = Image.open("input.png")
upscaled = sr_model(img)
upscaled.save("output.png")

Initializes the Stable Diffusion x4 upscaler.

Parameters:

Name Type Description Default
model_path str

Model path or repo.

'stabilityai/stable-diffusion-x4-upscaler'
device str

Device for inference.

'cuda'
Source code in embodied_gen/models/sr_model.py
def __init__(
    self,
    model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
    device="cuda",
) -> None:
    """Initializes the Stable Diffusion x4 upscaler.

    Args:
        model_path (str, optional): Model path or repo.
        device (str, optional): Device for inference.
    """
    from diffusers import StableDiffusionUpscalePipeline

    self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
    ).to(device)
    self.up_pipeline_x4.set_progress_bar_config(disable=True)
    self.up_pipeline_x4.enable_model_cpu_offload()
__call__
__call__(image: Union[Image, ndarray], prompt: str = '', infer_step: int = 20) -> Image.Image

Performs super-resolution on the input image.

Parameters:

Name Type Description Default
image Union[Image, ndarray]

Input image.

required
prompt str

Text prompt for upscaling.

''
infer_step int

Number of inference steps.

20

Returns:

Type Description
Image

Image.Image: Upscaled image.

Source code in embodied_gen/models/sr_model.py
@spaces.GPU
def __call__(
    self,
    image: Union[Image.Image, np.ndarray],
    prompt: str = "",
    infer_step: int = 20,
) -> Image.Image:
    """Performs super-resolution on the input image.

    Args:
        image (Union[Image.Image, np.ndarray]): Input image.
        prompt (str, optional): Text prompt for upscaling.
        infer_step (int, optional): Number of inference steps.

    Returns:
        Image.Image: Upscaled image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = image.convert("RGB")

    with torch.no_grad():
        upscaled_image = self.up_pipeline_x4(
            image=image,
            prompt=[prompt],
            num_inference_steps=infer_step,
        ).images[0]

    return upscaled_image

embodied_gen.models.segment_model

BMGG14Remover

BMGG14Remover()

Bases: object

Removes background using the RMBG-1.4 segmentation model.

Example
from embodied_gen.models.segment_model import BMGG14Remover
remover = BMGG14Remover()
result = remover("input.jpg", "output.png")

Initializes the BMGG14Remover.

Source code in embodied_gen/models/segment_model.py
def __init__(self) -> None:
    """Initializes the BMGG14Remover."""
    self.model = pipeline(
        "image-segmentation",
        model="briaai/RMBG-1.4",
        trust_remote_code=True,
    )
__call__
__call__(image: Union[str, Image, ndarray], save_path: str = None)

Removes background from an image.

Parameters:

Name Type Description Default
image Union[str, Image, ndarray]

Input image.

required
save_path str

Path to save the output image.

None

Returns:

Type Description

Image.Image: Image with background removed.

Source code in embodied_gen/models/segment_model.py
def __call__(
    self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
):
    """Removes background from an image.

    Args:
        image (Union[str, Image.Image, np.ndarray]): Input image.
        save_path (str, optional): Path to save the output image.

    Returns:
        Image.Image: Image with background removed.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = resize_pil(image)
    output_image = self.model(image)

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        output_image.save(save_path)

    return output_image

RembgRemover

RembgRemover()

Bases: object

Removes background from images using the rembg library.

Example
from embodied_gen.models.segment_model import RembgRemover
remover = RembgRemover()
result = remover("input.jpg", "output.png")

Initializes the RembgRemover.

Source code in embodied_gen/models/segment_model.py
def __init__(self):
    """Initializes the RembgRemover."""
    self.rembg_session = rembg.new_session("u2net")
__call__
__call__(image: Union[str, Image, ndarray], save_path: str = None) -> Image.Image

Removes background from an image.

Parameters:

Name Type Description Default
image Union[str, Image, ndarray]

Input image.

required
save_path str

Path to save the output image.

None

Returns:

Type Description
Image

Image.Image: Image with background removed (RGBA).

Source code in embodied_gen/models/segment_model.py
def __call__(
    self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
    """Removes background from an image.

    Args:
        image (Union[str, Image.Image, np.ndarray]): Input image.
        save_path (str, optional): Path to save the output image.

    Returns:
        Image.Image: Image with background removed (RGBA).
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = resize_pil(image)
    output_image = rembg.remove(image, session=self.rembg_session)

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        output_image.save(save_path)

    return output_image

SAMPredictor

SAMPredictor(checkpoint: str = None, model_type: str = 'vit_h', binary_thresh: float = 0.1, device: str = 'cuda')

Bases: object

Loads SAM models and predicts segmentation masks from user points.

Parameters:

Name Type Description Default
checkpoint str

Path to model checkpoint.

None
model_type str

SAM model type.

'vit_h'
binary_thresh float

Threshold for binary mask.

0.1
device str

Device for inference.

'cuda'
Source code in embodied_gen/models/segment_model.py
def __init__(
    self,
    checkpoint: str = None,
    model_type: str = "vit_h",
    binary_thresh: float = 0.1,
    device: str = "cuda",
):
    self.device = device
    self.model_type = model_type

    if checkpoint is None:
        suffix = "sam"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        checkpoint = os.path.join(
            model_path, suffix, "sam_vit_h_4b8939.pth"
        )

    self.predictor = self._load_sam_model(checkpoint)
    self.binary_thresh = binary_thresh
__call__
__call__(image: Union[str, Image, ndarray], selected_points: list[list[int]]) -> Image.Image

Segments image using selected points.

Parameters:

Name Type Description Default
image Union[str, Image, ndarray]

Input image.

required
selected_points list[list[int]]

List of points and labels.

required

Returns:

Type Description
Image

Image.Image: Segmented RGBA image.

Source code in embodied_gen/models/segment_model.py
def __call__(
    self,
    image: Union[str, Image.Image, np.ndarray],
    selected_points: list[list[int]],
) -> Image.Image:
    """Segments image using selected points.

    Args:
        image (Union[str, Image.Image, np.ndarray]): Input image.
        selected_points (list[list[int]]): List of points and labels.

    Returns:
        Image.Image: Segmented RGBA image.
    """
    image = self.preprocess_image(image)
    self.predictor.set_image(image)
    masks = self.generate_masks(image, selected_points)

    return self.get_segmented_image(image, masks)
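
A sketch of point-prompted segmentation with SAMPredictor. Each entry in selected_points pairs a pixel coordinate with a label (1 = foreground, 0 = background), a format inferred from generate_masks below; the coordinates here are placeholders:

from embodied_gen.models.segment_model import SAMPredictor

predictor = SAMPredictor(model_type="vit_h")  # checkpoint auto-downloaded if None
seg_image = predictor("input.jpg", selected_points=[([320, 240], 1)])
seg_image.save("output_rgba.png")
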
generate_masks
generate_masks(image: ndarray, selected_points: list[list[int]]) -> list[tuple[np.ndarray, str]]

Generates segmentation masks from selected points.

Parameters:

Name Type Description Default
image ndarray

Input image array.

required
selected_points list[list[int]]

List of points and labels.

required

Returns:

Type Description
list[tuple[ndarray, str]]

list[tuple[np.ndarray, str]]: List of masks and names.

Source code in embodied_gen/models/segment_model.py
def generate_masks(
    self,
    image: np.ndarray,
    selected_points: list[list[int]],
) -> list[tuple[np.ndarray, str]]:
    """Generates segmentation masks from selected points.

    Args:
        image (np.ndarray): Input image array.
        selected_points (list[list[int]]): List of points and labels.

    Returns:
        list[tuple[np.ndarray, str]]: List of masks and names.
    """
    if len(selected_points) == 0:
        return []

    points = (
        torch.Tensor([p for p, _ in selected_points])
        .to(self.predictor.device)
        .unsqueeze(1)
    )

    labels = (
        torch.Tensor([int(l) for _, l in selected_points])
        .to(self.predictor.device)
        .unsqueeze(1)
    )

    transformed_points = self.predictor.transform.apply_coords_torch(
        points, image.shape[:2]
    )

    masks, scores, _ = self.predictor.predict_torch(
        point_coords=transformed_points,
        point_labels=labels,
        multimask_output=True,
    )
    valid_mask = masks[:, torch.argmax(scores, dim=1)]
    masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
    masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
    if len(masks_neg) == 0:
        masks_neg = np.zeros_like(masks_pos)
    if len(masks_pos) == 0:
        masks_pos = np.zeros_like(masks_neg)
    masks_neg = masks_neg.max(axis=0, keepdims=True)
    masks_pos = masks_pos.max(axis=0, keepdims=True)
    valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)

    binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)

    return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
get_segmented_image
get_segmented_image(image: ndarray, masks: list[tuple[ndarray, str]]) -> Image.Image

Combines masks and returns segmented image with alpha channel.

Parameters:

Name Type Description Default
image ndarray

Input image array.

required
masks list[tuple[ndarray, str]]

List of masks.

required

Returns:

Type Description
Image

Image.Image: Segmented RGBA image.

Source code in embodied_gen/models/segment_model.py
def get_segmented_image(
    self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
) -> Image.Image:
    """Combines masks and returns segmented image with alpha channel.

    Args:
        image (np.ndarray): Input image array.
        masks (list[tuple[np.ndarray, str]]): List of masks.

    Returns:
        Image.Image: Segmented RGBA image.
    """
    seg_image = Image.fromarray(image, mode="RGB")
    alpha_channel = np.zeros(
        (seg_image.height, seg_image.width), dtype=np.uint8
    )
    for mask, _ in masks:
        # Use the maximum to combine multiple masks
        alpha_channel = np.maximum(alpha_channel, mask)

    alpha_channel = np.clip(alpha_channel, 0, 1)
    alpha_channel = (alpha_channel * 255).astype(np.uint8)
    alpha_image = Image.fromarray(alpha_channel, mode="L")
    r, g, b = seg_image.split()
    seg_image = Image.merge("RGBA", (r, g, b, alpha_image))

    return seg_image
preprocess_image
preprocess_image(image: Image) -> np.ndarray

Preprocesses input image for SAM prediction.

Parameters:

Name Type Description Default
image Image

Input image.

required

Returns:

Type Description
ndarray

np.ndarray: Preprocessed image array.

Source code in embodied_gen/models/segment_model.py
def preprocess_image(self, image: Image.Image) -> np.ndarray:
    """Preprocesses input image for SAM prediction.

    Args:
        image (Image.Image): Input image.

    Returns:
        np.ndarray: Preprocessed image array.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")

    image = resize_pil(image)
    image = np.array(image.convert("RGB"))

    return image

SAMRemover

SAMRemover(checkpoint: str = None, model_type: str = 'vit_h', area_ratio: float = 15)

Bases: object

Loads SAM models and performs background removal on images.

Attributes:

Name Type Description
checkpoint str

Path to the model checkpoint.

model_type str

Type of the SAM model to load.

area_ratio float

Area ratio for filtering small connected components.

Example
from embodied_gen.models.segment_model import SAMRemover
remover = SAMRemover(model_type="vit_h")
result = remover("input.jpg", "output.png")
Source code in embodied_gen/models/segment_model.py
def __init__(
    self,
    checkpoint: str = None,
    model_type: str = "vit_h",
    area_ratio: float = 15,
):
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model_type = model_type
    self.area_ratio = area_ratio

    if checkpoint is None:
        suffix = "sam"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        checkpoint = os.path.join(
            model_path, suffix, "sam_vit_h_4b8939.pth"
        )

    self.mask_generator = self._load_sam_model(checkpoint)
__call__
__call__(image: Union[str, Image, ndarray], save_path: str = None) -> Image.Image

Removes the background from an image using the SAM model.

Parameters:

Name Type Description Default
image Union[str, Image, ndarray]

Input image.

required
save_path str

Path to save the output image.

None

Returns:

Type Description
Image

Image.Image: Image with background removed (RGBA).

Source code in embodied_gen/models/segment_model.py
def __call__(
    self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
    """Removes the background from an image using the SAM model.

    Args:
        image (Union[str, Image.Image, np.ndarray]): Input image.
        save_path (str, optional): Path to save the output image.

    Returns:
        Image.Image: Image with background removed (RGBA).
    """
    # Convert input to numpy array
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    image = resize_pil(image)
    image = np.array(image.convert("RGB"))

    # Generate masks
    masks = self.mask_generator.generate(image)
    masks = sorted(masks, key=lambda x: x["area"], reverse=True)

    if not masks:
        logger.warning(
            "Segmentation failed: No mask generated, return raw image."
        )
        output_image = Image.fromarray(image, mode="RGB")
    else:
        # Use the largest mask
        best_mask = masks[0]["segmentation"]
        mask = (best_mask * 255).astype(np.uint8)
        mask = filter_small_connected_components(
            mask, area_ratio=self.area_ratio
        )
        # Apply the mask to remove the background
        background_removed = cv2.bitwise_and(image, image, mask=mask)
        output_image = np.dstack((background_removed, mask))
        output_image = Image.fromarray(output_image, mode="RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        output_image.save(save_path)

    return output_image

get_segmented_image_by_agent

get_segmented_image_by_agent(image: Image, sam_remover: SAMRemover, rbg_remover: RembgRemover, seg_checker: ImageSegChecker = None, save_path: str = None, mode: Literal['loose', 'strict'] = 'loose') -> Image.Image

Segments an image using SAM and rembg, with quality checking.

Parameters:

Name Type Description Default
image Image

Input image.

required
sam_remover SAMRemover

SAM-based remover.

required
rbg_remover RembgRemover

rembg-based remover.

required
seg_checker ImageSegChecker

Quality checker.

None
save_path str

Path to save the output image.

None
mode Literal['loose', 'strict']

Segmentation mode.

'loose'

Returns:

Type Description
Image

Image.Image: Segmented RGBA image.

Source code in embodied_gen/models/segment_model.py
def get_segmented_image_by_agent(
    image: Image.Image,
    sam_remover: SAMRemover,
    rbg_remover: RembgRemover,
    seg_checker: ImageSegChecker = None,
    save_path: str = None,
    mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
    """Segments an image using SAM and rembg, with quality checking.

    Args:
        image (Image.Image): Input image.
        sam_remover (SAMRemover): SAM-based remover.
        rbg_remover (RembgRemover): rembg-based remover.
        seg_checker (ImageSegChecker, optional): Quality checker.
        save_path (str, optional): Path to save the output image.
        mode (Literal["loose", "strict"], optional): Segmentation mode.

    Returns:
        Image.Image: Segmented RGBA image.
    """

    def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
        if seg_checker is None:
            return True
        return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]

    out_sam = f"{save_path}_sam.png" if save_path else None
    out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
    out_rbg = f"{save_path}_rbg.png" if save_path else None

    seg_image = sam_remover(image, out_sam)
    seg_image = seg_image.convert("RGBA")
    _, _, _, alpha = seg_image.split()
    seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
    seg_image_rbg = rbg_remover(image, out_rbg)

    final_image = None
    if _is_valid_seg(image, seg_image):
        final_image = seg_image
    elif _is_valid_seg(image, seg_image_inv):
        final_image = seg_image_inv
    elif _is_valid_seg(image, seg_image_rbg):
        logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.")
        final_image = seg_image_rbg
    else:
        if mode == "strict":
            raise RuntimeError(
                f"Failed to segment by `SAM` or `rembg`, abort."
            )
        logger.warning("Failed to segment by SAM or rembg, use raw image.")
        final_image = image.convert("RGBA")

    if save_path:
        final_image.save(save_path)

    final_image = trellis_preprocess(final_image)

    return final_image
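
A sketch of agent-style segmentation that falls back from SAM to rembg; ImageSegChecker setup is omitted here (passing None skips quality checking), and the file paths are placeholders:

from PIL import Image

from embodied_gen.models.segment_model import (
    RembgRemover,
    SAMRemover,
    get_segmented_image_by_agent,
)

image = Image.open("input.jpg").convert("RGB")
result = get_segmented_image_by_agent(
    image=image,
    sam_remover=SAMRemover(model_type="vit_h"),
    rbg_remover=RembgRemover(),
    seg_checker=None,
    save_path="outputs/segmented.png",
    mode="loose",
)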

invert_rgba_pil

invert_rgba_pil(image: Image, mask: Image, save_path: str = None) -> Image.Image

Inverts the alpha channel of an RGBA image using a mask.

Parameters:

Name Type Description Default
image Image

Input RGB image.

required
mask Image

Mask image for alpha inversion.

required
save_path str

Path to save the output image.

None

Returns:

Type Description
Image

Image.Image: RGBA image with inverted alpha.

Source code in embodied_gen/models/segment_model.py
def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
    """Inverts the alpha channel of an RGBA image using a mask.

    Args:
        image (Image.Image): Input RGB image.
        mask (Image.Image): Mask image for alpha inversion.
        save_path (str, optional): Path to save the output image.

    Returns:
        Image.Image: RGBA image with inverted alpha.
    """
    mask = (255 - np.array(mask))[..., None]
    image_array = np.concatenate([np.array(image), mask], axis=-1)
    inverted_image = Image.fromarray(image_array, "RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        inverted_image.save(save_path)

    return inverted_image

embodied_gen.models.image_comm_model

BasePipelineLoader

BasePipelineLoader(device='cuda')

Bases: ABC

Abstract base class for loading Hugging Face image generation pipelines.

Attributes:

Name Type Description
device str

Device to load the pipeline on.

Methods:

Name Description
load

Loads and returns the pipeline.

Source code in embodied_gen/models/image_comm_model.py
def __init__(self, device="cuda"):
    self.device = device
load abstractmethod
load()

Load and return the pipeline instance.

Source code in embodied_gen/models/image_comm_model.py
@abstractmethod
def load(self):
    """Load and return the pipeline instance."""
    pass

BasePipelineRunner

BasePipelineRunner(pipe)

Bases: ABC

Abstract base class for running image generation pipelines.

Attributes:

Name Type Description
pipe

The loaded pipeline.

Methods:

Name Description
run

Runs the pipeline with a prompt.

Source code in embodied_gen/models/image_comm_model.py
def __init__(self, pipe):
    self.pipe = pipe
run abstractmethod
run(prompt: str, **kwargs) -> Image.Image

Run the pipeline with the given prompt.

Parameters:

Name Type Description Default
prompt str

Text prompt for image generation.

required
**kwargs

Additional pipeline arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py
@abstractmethod
def run(self, prompt: str, **kwargs) -> Image.Image:
    """Run the pipeline with the given prompt.

    Args:
        prompt (str): Text prompt for image generation.
        **kwargs: Additional pipeline arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    pass
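
A sketch of the loader/runner pattern shared by the concrete pairs documented below (ChromaLoader/ChromaRunner, CosmosLoader/CosmosRunner, FluxLoader/FluxRunner); the prompt and step count are illustrative:

from embodied_gen.models.image_comm_model import ChromaLoader, ChromaRunner

pipe = ChromaLoader(device="cuda").load()
runner = ChromaRunner(pipe)
# run() forwards to the underlying pipeline and returns its .images list.
images = runner.run(
    "a studio photo of a red apple on a white table",
    num_inference_steps=30,
)
images[0].save("apple.png")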

ChromaLoader

ChromaLoader(device='cuda')

Bases: BasePipelineLoader

Loader for Chroma pipeline.

Source code in embodied_gen/models/image_comm_model.py
def __init__(self, device="cuda"):
    self.device = device
load
load()

Load the Chroma pipeline.

Returns:

Name Type Description
ChromaPipeline

Loaded pipeline.

Source code in embodied_gen/models/image_comm_model.py
def load(self):
    """Load the Chroma pipeline.

    Returns:
        ChromaPipeline: Loaded pipeline.
    """
    return ChromaPipeline.from_pretrained(
        "lodestones/Chroma", torch_dtype=torch.bfloat16
    ).to(self.device)

ChromaRunner

ChromaRunner(pipe)

Bases: BasePipelineRunner

Runner for Chroma pipeline.

Source code in embodied_gen/models/image_comm_model.py
def __init__(self, pipe):
    self.pipe = pipe
run
run(prompt: str, negative_prompt=None, **kwargs) -> Image.Image

Generate images using Chroma pipeline.

Parameters:

Name Type Description Default
prompt str

Text prompt.

required
negative_prompt str

Negative prompt.

None
**kwargs

Additional arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
    """Generate images using Chroma pipeline.

    Args:
        prompt (str): Text prompt.
        negative_prompt (str, optional): Negative prompt.
        **kwargs: Additional arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    return self.pipe(
        prompt=prompt, negative_prompt=negative_prompt, **kwargs
    ).images

CosmosLoader

CosmosLoader(model_id='nvidia/Cosmos-Predict2-2B-Text2Image', local_dir='weights/cosmos2', device='cuda')

Bases: BasePipelineLoader

Loader for Cosmos2 text-to-image pipeline.

Source code in embodied_gen/models/image_comm_model.py
def __init__(
    self,
    model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
    local_dir="weights/cosmos2",
    device="cuda",
):
    super().__init__(device)
    self.model_id = model_id
    self.local_dir = local_dir
load
load()

Load the Cosmos2 text-to-image pipeline.

Returns:

Name Type Description
Cosmos2TextToImagePipeline

Loaded pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 162-195
def load(self):
    """Load the Cosmos2 text-to-image pipeline.

    Returns:
        Cosmos2TextToImagePipeline: Loaded pipeline.
    """
    self._patch()
    snapshot_download(
        repo_id=self.model_id,
        local_dir=self.local_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
    )

    config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
            "bnb_4bit_use_double_quant": True,
        },
        components_to_quantize=["text_encoder", "transformer", "unet"],
    )

    pipe = Cosmos2TextToImagePipeline.from_pretrained(
        self.model_id,
        torch_dtype=torch.bfloat16,
        quantization_config=config,
        use_safetensors=True,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(self.device)
    return pipe

CosmosRunner

CosmosRunner(pipe)

Bases: BasePipelineRunner

Runner for Cosmos2 text-to-image pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 69-70
def __init__(self, pipe):
    self.pipe = pipe
run
run(prompt: str, negative_prompt=None, **kwargs) -> Image.Image

Generate images using Cosmos2 pipeline.

Parameters:

Name Type Description Default
prompt str

Text prompt.

required
negative_prompt str

Negative prompt.

None
**kwargs

Additional arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py, lines 201-214
def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
    """Generate images using Cosmos2 pipeline.

    Args:
        prompt (str): Text prompt.
        negative_prompt (str, optional): Negative prompt.
        **kwargs: Additional arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    return self.pipe(
        prompt=prompt, negative_prompt=negative_prompt, **kwargs
    ).images
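Example

A minimal usage sketch for the 4-bit-quantized Cosmos2 text-to-image pipeline; the loader downloads nvidia/Cosmos-Predict2-2B-Text2Image into weights/cosmos2 on first use.

from embodied_gen.models.image_comm_model import CosmosLoader, CosmosRunner

runner = CosmosRunner(CosmosLoader(device="cuda").load())
images = runner.run(
    prompt="A kitchen counter with a red kettle and a cutting board",
    negative_prompt="distorted, oversaturated",
)
images[0].save("kettle.png")  # .images is a list of PIL images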

FluxLoader

FluxLoader(device='cuda')

Bases: BasePipelineLoader

Loader for Flux pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 50-51
def __init__(self, device="cuda"):
    self.device = device
load
load()

Load the Flux pipeline.

Returns:

Name Type Description
FluxPipeline

Loaded pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 260-273
def load(self):
    """Load the Flux pipeline.

    Returns:
        FluxPipeline: Loaded pipeline.
    """
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    return pipe.to(self.device)

FluxRunner

FluxRunner(pipe)

Bases: BasePipelineRunner

Runner for Flux pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 69-70
def __init__(self, pipe):
    self.pipe = pipe
run
run(prompt: str, **kwargs) -> Image.Image

Generate images using Flux pipeline.

Parameters:

Name Type Description Default
prompt str

Text prompt.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py, lines 279-289
def run(self, prompt: str, **kwargs) -> Image.Image:
    """Generate images using Flux pipeline.

    Args:
        prompt (str): Text prompt.
        **kwargs: Additional arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    return self.pipe(prompt=prompt, **kwargs).images
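Example

A minimal usage sketch for the FLUX.1-schnell pipeline via FluxLoader and FluxRunner; the few-step settings below are typical for the schnell variant, not requirements of this module.

from embodied_gen.models.image_comm_model import FluxLoader, FluxRunner

runner = FluxRunner(FluxLoader(device="cuda").load())
images = runner.run(
    prompt="A wooden toy robot on a desk",
    num_inference_steps=4,  # schnell is distilled for few-step sampling
    guidance_scale=0.0,     # schnell is typically run without guidance
)
images[0].save("toy_robot.png")  # .images is a list of PIL images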

KolorsLoader

KolorsLoader(device='cuda')

Bases: BasePipelineLoader

Loader for Kolors pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 50-51
def __init__(self, device="cuda"):
    self.device = device
load
load()

Load the Kolors pipeline.

Returns:

Name Type Description
KolorsPipeline

Loaded pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 221-237
def load(self):
    """Load the Kolors pipeline.

    Returns:
        KolorsPipeline: Loaded pipeline.
    """
    pipe = KolorsPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-diffusers",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(self.device)
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config, use_karras_sigmas=True
    )
    return pipe

KolorsRunner

KolorsRunner(pipe)

Bases: BasePipelineRunner

Runner for Kolors pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 69-70
def __init__(self, pipe):
    self.pipe = pipe
run
run(prompt: str, **kwargs) -> Image.Image

Generate images using Kolors pipeline.

Parameters:

Name Type Description Default
prompt str

Text prompt.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py, lines 243-253
def run(self, prompt: str, **kwargs) -> Image.Image:
    """Generate images using Kolors pipeline.

    Args:
        prompt (str): Text prompt.
        **kwargs: Additional arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    return self.pipe(prompt=prompt, **kwargs).images
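Example

A minimal usage sketch for the Kolors pipeline via KolorsLoader and KolorsRunner; extra keyword arguments are forwarded to the underlying KolorsPipeline.

from embodied_gen.models.image_comm_model import KolorsLoader, KolorsRunner

runner = KolorsRunner(KolorsLoader(device="cuda").load())
images = runner.run(
    prompt="A red apple on a white plate, soft daylight",
    num_inference_steps=30,  # forwarded via **kwargs to the pipeline
)
images[0].save("apple.png")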

SD35Loader

SD35Loader(device='cuda')

Bases: BasePipelineLoader

Loader for Stable Diffusion 3.5 medium pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 50-51
def __init__(self, device="cuda"):
    self.device = device
load
load()

Load the Stable Diffusion 3.5 medium pipeline.

Returns:

Name Type Description
StableDiffusion3Pipeline

Loaded pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 90-104
def load(self):
    """Load the Stable Diffusion 3.5 medium pipeline.

    Returns:
        StableDiffusion3Pipeline: Loaded pipeline.
    """
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium",
        torch_dtype=torch.float16,
    )
    pipe = pipe.to(self.device)
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_attention_slicing()
    return pipe

SD35Runner

SD35Runner(pipe)

Bases: BasePipelineRunner

Runner for Stable Diffusion 3.5 medium pipeline.

Source code in embodied_gen/models/image_comm_model.py, lines 69-70
def __init__(self, pipe):
    self.pipe = pipe
run
run(prompt: str, **kwargs) -> Image.Image

Generate images using Stable Diffusion 3.5 medium.

Parameters:

Name Type Description Default
prompt str

Text prompt.

required
**kwargs

Additional arguments.

{}

Returns:

Type Description
Image

Image.Image: Generated image(s).

Source code in embodied_gen/models/image_comm_model.py, lines 110-120
def run(self, prompt: str, **kwargs) -> Image.Image:
    """Generate images using Stable Diffusion 3.5 medium.

    Args:
        prompt (str): Text prompt.
        **kwargs: Additional arguments.

    Returns:
        Image.Image: Generated image(s).
    """
    return self.pipe(prompt=prompt, **kwargs).images
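Example

A minimal usage sketch for Stable Diffusion 3.5 medium via SD35Loader and SD35Runner.

from embodied_gen.models.image_comm_model import SD35Loader, SD35Runner

runner = SD35Runner(SD35Loader(device="cuda").load())
images = runner.run(
    prompt="A robot arm stacking wooden blocks on a table",
    num_inference_steps=28,  # forwarded via **kwargs to the pipeline
)
images[0].save("robot_arm.png")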

build_hf_image_pipeline

build_hf_image_pipeline(name: str, device='cuda') -> BasePipelineRunner

Build a Hugging Face image generation pipeline runner by name.

Parameters:

Name Type Description Default
name str

Name of the pipeline (e.g., "sd35", "cosmos").

required
device str

Device to load the pipeline on.

'cuda'

Returns:

Name Type Description
BasePipelineRunner BasePipelineRunner

Pipeline runner instance.

Example
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
runner = build_hf_image_pipeline("sd35")
images = runner.run(prompt="A robot holding a sign that says 'Hello'")
Source code in embodied_gen/models/image_comm_model.py, lines 335-357
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    """Build a Hugging Face image generation pipeline runner by name.

    Args:
        name (str): Name of the pipeline (e.g., "sd35", "cosmos").
        device (str): Device to load the pipeline on.

    Returns:
        BasePipelineRunner: Pipeline runner instance.

    Example:
        ```py
        from embodied_gen.models.image_comm_model import build_hf_image_pipeline
        runner = build_hf_image_pipeline("sd35")
        images = runner.run(prompt="A robot holding a sign that says 'Hello'")
        ```
    """
    if name not in PIPELINE_REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()

    return runner_cls(pipe)

embodied_gen.models.delight_model

DelightingModel

DelightingModel(model_path: str = None, num_infer_step: int = 50, mask_erosion_size: int = 3, image_guide_scale: float = 1.5, text_guide_scale: float = 1.0, device: str = 'cuda', seed: int = 0)

Bases: object

A model that removes lighting effects from images (delighting) in image space.

This class wraps the Hunyuan3D-Delight model from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0

Attributes:

Name Type Description
image_guide_scale float

Weight of image guidance in diffusion process.

text_guide_scale float

Weight of text (prompt) guidance in diffusion process.

num_infer_step int

Number of inference steps for diffusion model.

mask_erosion_size int

Size of erosion kernel for alpha mask cleanup.

device str

Device used for inference, e.g., 'cuda' or 'cpu'.

seed int

Random seed for diffusion model reproducibility.

model_path str

Filesystem path to pretrained model weights.

pipeline

Lazy-loaded diffusion pipeline instance.

Source code in embodied_gen/models/delight_model.py, lines 56-84
def __init__(
    self,
    model_path: str = None,
    num_infer_step: int = 50,
    mask_erosion_size: int = 3,
    image_guide_scale: float = 1.5,
    text_guide_scale: float = 1.0,
    device: str = "cuda",
    seed: int = 0,
) -> None:
    self.image_guide_scale = image_guide_scale
    self.text_guide_scale = text_guide_scale
    self.num_infer_step = num_infer_step
    self.mask_erosion_size = mask_erosion_size
    self.kernel = np.ones(
        (self.mask_erosion_size, self.mask_erosion_size), np.uint8
    )
    self.seed = seed
    self.device = device
    self.pipeline = None  # lazy load model adapt to @spaces.GPU

    if model_path is None:
        suffix = "hunyuan3d-delight-v2-0"
        model_path = snapshot_download(
            repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
        )
        model_path = os.path.join(model_path, suffix)

    self.model_path = model_path
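Example

A minimal instantiation sketch. With model_path=None the hunyuan3d-delight-v2-0 weights are fetched from the tencent/Hunyuan3D-2 repository; the diffusion pipeline itself is lazy-loaded on first use.

from embodied_gen.models.delight_model import DelightingModel

delight = DelightingModel(num_infer_step=50, device="cuda")
# `delight.pipeline` remains None until the model is first used (lazy loading).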