Skip to content

Memory Management ;( the runtime gets crashed after few prompts #109

@shiroanon

Description

@shiroanon
import random
import string
import sys
from typing import Sequence, Mapping, Any, Union
import torch
from flask import Flask, request, jsonify ,send_from_directory

from flask_cors import CORS
app = Flask(__name__)
CORS(app)
def random_strings_list(n):
    return [''.join(random.choices(string.ascii_letters + string.digits, k=16)) for _ in range(n)]


import threading
def tu():
  !/content/loophole http 5000
threading.Thread(target=tu , daemon=True).start()

import shiro.utils
# Lets manipulate comfy through args switches
from shiro.cli_args import args, LatentPreviewMethod
args.preview_method = LatentPreviewMethod.Latent2RGB
def f(value, total, preview):
		if preview:
			preview[1].save('/tmp/preview.jpg')
shiro.utils.set_progress_bar_global_hook(f)


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:

    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:

    # If no path is given, use the current working directory
    if path is None:
        path = os.getcwd()

    # Check if the current directory contains the name
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name

    # Get the parent directory
    parent_directory = os.path.dirname(path)

    # If the parent directory is the same as the current directory, we've reached the root and stop the search
    if parent_directory == path:
        return None

    # Recursively call the function with the parent directory
    return find_path(name, parent_directory)


def add_shiroui_directory_to_sys_path() -> None:
    shiroui_path = find_path("ShiroUI")
    if shiroui_path is not None and os.path.isdir(shiroui_path):
        sys.path.append(shiroui_path)
        print(f"'{shiroui_path}' added to sys.path")

add_shiroui_directory_to_sys_path()



def import_custom_nodes() -> None:
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Creating a new event loop and setting it as the default loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Creating an instance of PromptServer with the loop
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    # Initializing custom nodes
    init_extra_nodes()


from nodes import (
    CLIPTextEncode,
    CheckpointLoaderSimple,
    VAEDecode,
    NODE_CLASS_MAPPINGS,
    SaveImage,
    LoraLoader,
    EmptyLatentImage,
)


def main(prompt,cf,batch_siz):
    import_custom_nodes()
    global lis
    lis=random_strings_list(1)
    with torch.inference_mode():
        checkpointloadersimple = CheckpointLoaderSimple()
        checkpointloadersimple_1 = checkpointloadersimple.load_checkpoint(
            ckpt_name="kk.safetensors"
        )

        loraloader = LoraLoader()
        loraloader_10 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=1,
            strength_clip=1,
            model=get_value_at_index(checkpointloadersimple_1, 0),
            clip=get_value_at_index(checkpointloadersimple_1, 1),
        )

        loraloader_11 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(loraloader_10, 0),
            clip=get_value_at_index(loraloader_10, 1),
        )

        loraloader_12 = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(loraloader_11, 0),
            clip=get_value_at_index(loraloader_11, 1),
        )

        cliptextencode = CLIPTextEncode()
        cliptextencode_3 = cliptextencode.encode(
            text=prompt, clip=get_value_at_index(loraloader_12, 1)
        )

        cliptextencode_4 = cliptextencode.encode(
            text="", clip=get_value_at_index(loraloader_12, 1)
        )

        alignyourstepsscheduler = NODE_CLASS_MAPPINGS["AlignYourStepsScheduler"]()
        alignyourstepsscheduler_5 = alignyourstepsscheduler.get_sigmas(
            model_type="SD1", steps=25, denoise=1
        )

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        ksamplerselect_6 = ksamplerselect.get_sampler(sampler_name="euler")

        emptylatentimage = EmptyLatentImage()
        emptylatentimage_7 = emptylatentimage.generate(
            width=512, height=512, batch_size=batch_siz
        )

        samplercustom = NODE_CLASS_MAPPINGS["SamplerCustom"]()
        vaedecode = VAEDecode()
        saveimage = SaveImage()

        for q in range(1):
            samplercustom_2 = samplercustom.sample(
                add_noise=True,
                noise_seed=random.randint(1, 2**64),
                cfg=cf,
                model=get_value_at_index(checkpointloadersimple_1, 0),
                positive=get_value_at_index(cliptextencode_3, 0),
                negative=get_value_at_index(cliptextencode_4, 0),
                sampler=get_value_at_index(ksamplerselect_6, 0),
                sigmas=get_value_at_index(alignyourstepsscheduler_5, 0),
                latent_image=get_value_at_index(emptylatentimage_7, 0),
            )

            vaedecode_8 = vaedecode.decode(
                samples=get_value_at_index(samplercustom_2, 0),
                vae=get_value_at_index(checkpointloadersimple_1, 2),
            )

            saveimage_9 = saveimage.save_images(
                filename_prefix=lis[0], images=get_value_at_index(vaedecode_8, 0)
            )



@app.route('/generate', methods=['POST'])
def generate():
    data = request.json
    prompt = data.get('prompt', '')
    cfg = data.get('cfg', 1)
    batch_size = data.get('batch_size', 1)
    seed = data.get('seed', 0)
    global response

    response = {
        "prompt": prompt,
        "cfg": cfg,
        "batch_size": batch_size,
        "seed": seed
    }
    print(response)
    main(prompt,cfg,batch_size)

    query = lis[0]
    directory = "/content/ShiroUI/output"

    if not query or not directory or not os.path.isdir(directory):
        return jsonify({"error": "Invalid query or directory"}), 400

    matched_images = [
        os.path.join("output", f) for f in os.listdir(directory)
        if query in f and f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'))
    ]

    if not matched_images:
        return jsonify({"error": "No matching images found"}), 404

    return jsonify(matched_images)


@app.route('/output/<path:filename>', methods=['GET'])
def get_image(filename):
    directory = "/content/ShiroUI/output"
    return send_from_directory(directory, filename)

if __name__ == '__main__':
    app.run()



Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions