-
Notifications
You must be signed in to change notification settings - Fork 179
Open
Description
import os
import random
import string
import sys
from typing import Any, Mapping, Sequence, Union

import torch
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
# Flask application exposing the /generate and /output/<filename> endpoints.
app = Flask(__name__)
# NOTE(review): CORS(app) with no options uses flask-cors's permissive default
# (all origins allowed). If the server is publicly reachable (e.g. through the
# loophole tunnel started below), any web page can call /generate. Restrict
# origins if that is not intended.
CORS(app)
def random_strings_list(n, length=16):
    """Return ``n`` random alphanumeric strings, each ``length`` characters long.

    The strings serve as filename prefixes for generated images. They only need
    to be unlikely to collide, not unguessable, so ``random`` is used rather
    than ``secrets``.

    Args:
        n: How many strings to generate; ``0`` yields an empty list.
        length: Characters per string (default 16, the previous fixed value).

    Returns:
        A list of ``n`` strings drawn from ASCII letters and digits.
    """
    # Built once instead of on every iteration of the comprehension.
    alphabet = string.ascii_letters + string.digits
    return [''.join(random.choices(alphabet, k=length)) for _ in range(n)]
import threading


def tu():
    """Run the loophole tunnel client for local port 5000 (``loophole http 5000``).

    The original notebook cell used the IPython shell escape
    ``!/content/loophole http 5000``. That only parses inside IPython/Colab and
    is a SyntaxError in a regular ``.py`` file. ``subprocess.run`` runs the
    same command in both environments.
    """
    import subprocess

    try:
        # subprocess.run waits for the process to exit, which is why tu() is
        # started on a daemon thread rather than blocking the server startup.
        subprocess.run(["/content/loophole", "http", "5000"], check=False)
    except OSError as exc:
        # e.g. the binary is missing or not executable: report it instead of
        # dying with an unhandled exception inside the background thread.
        print(f"Could not start loophole tunnel: {exc}", file=sys.stderr)


threading.Thread(target=tu, daemon=True).start()
import shiro.utils

# Let's manipulate Comfy through its argument switches: select the Latent2RGB
# latent preview method (presumably so previews are produced while sampling).
from shiro.cli_args import args, LatentPreviewMethod

args.preview_method = LatentPreviewMethod.Latent2RGB


def f(value, total, preview):
    """Progress-bar hook: write the current preview image to /tmp/preview.jpg.

    ``value`` and ``total`` are part of the hook's call signature and are
    ignored here. Nothing is written when ``preview`` is falsy. Otherwise its
    second element must provide a ``save(path)`` method (presumably a PIL
    image; confirm against shiro.utils).
    """
    if not preview:
        return
    preview_image = preview[1]
    preview_image.save('/tmp/preview.jpg')


shiro.utils.set_progress_bar_global_hook(f)
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is taken only when the direct lookup raises ``KeyError``
    (presumably for node outputs shaped as a dict with a ``"result"`` entry).
    Any other exception, such as ``IndexError``, propagates unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and each of its ancestors for an entry called ``name``.

    Args:
        name: File or directory name to look for.
        path: Directory to start from; defaults to the current working directory.
            Relative paths are accepted.

    Returns:
        ``os.path.join(<dir>, name)`` for the first (closest) directory that
        contains ``name``, or ``None`` once the filesystem root has been
        searched without a match.
    """
    # If no path is given, use the current working directory
    current = os.getcwd() if path is None else path
    while True:
        # Check if the current directory contains the name
        if name in os.listdir(current):
            path_name = os.path.join(current, name)
            print(f"{name} found: {path_name}")
            return path_name
        # Climb via the absolute path: os.path.dirname("relative") is "", and
        # os.listdir("") raises FileNotFoundError.
        absolute = os.path.abspath(current)
        parent_directory = os.path.dirname(absolute)
        # At the root the parent equals the directory itself, so stop searching
        if parent_directory == absolute:
            return None
        current = parent_directory
def add_shiroui_directory_to_sys_path() -> None:
    """Append the nearest ``ShiroUI`` directory (found via ``find_path``) to ``sys.path``.

    Does nothing when no such entry exists or when it is not a directory.
    """
    located = find_path("ShiroUI")
    if located is None or not os.path.isdir(located):
        return
    sys.path.append(located)
    print(f"'{located}' added to sys.path")


add_shiroui_directory_to_sys_path()
# Event loop created by the first successful import_custom_nodes() call;
# None until initialisation has completed.
_custom_nodes_loop = None


def import_custom_nodes() -> None:
    """Initialise ShiroUI's extra/custom nodes, doing the heavy work only once.

    The first call creates a new asyncio event loop and sets it as the current
    thread's loop. It then constructs a ``server.PromptServer`` on that loop and
    an ``execution.PromptQueue`` for it, and finally calls
    ``init_extra_nodes()``.

    ``main()`` calls this on every /generate request. Repeating the full setup
    would create (and never close) another event loop and PromptServer each
    time and re-run node initialisation. Later calls therefore only set the
    stored loop as the current thread's loop, because Flask may serve each
    request on a different thread.
    """
    global _custom_nodes_loop
    import asyncio

    if _custom_nodes_loop is not None:
        asyncio.set_event_loop(_custom_nodes_loop)
        return

    import execution
    from nodes import init_extra_nodes
    import server

    # Creating a new event loop and setting it as the default loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Creating an instance of PromptServer with the loop
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    # Initializing custom nodes
    init_extra_nodes()
    # Recorded only after success, so a failed initialisation is retried.
    _custom_nodes_loop = loop
from nodes import (
CLIPTextEncode,
CheckpointLoaderSimple,
VAEDecode,
NODE_CLASS_MAPPINGS,
SaveImage,
LoraLoader,
EmptyLatentImage,
)
def main(prompt, cf, batch_siz, seed=None):
    """Run the fixed SD1 text-to-image graph once and save the images.

    Graph:
    - Load ``kk.safetensors`` and apply the ``flat`` LoRA.
    - Encode ``prompt`` as the positive conditioning and ``""`` as the negative.
    - Sample a 512x512 latent with AlignYourSteps sigmas (25 steps) and the
      Euler sampler.
    - Decode with the checkpoint's VAE and save via ``SaveImage`` under a
      random filename prefix.

    Args:
        prompt: Positive text prompt.
        cf: CFG scale passed to ``SamplerCustom``.
        batch_siz: Latent batch size, i.e. number of images generated.
        seed: Optional noise seed. ``None`` (the default, and what the
            /generate route passes) draws a random seed for this call.

    Returns:
        The filename prefix the images were saved under. The same value is
        also published as ``lis[0]`` in the module-level ``lis`` list, which
        existing callers read.
    """
    import_custom_nodes()
    # Kept as module state for backward compatibility: generate() reads lis[0].
    global lis
    lis = random_strings_list(1)
    prefix = lis[0]

    if seed is None:
        # randint is inclusive at both ends. 64-bit noise seeds top out at
        # 2**64 - 1 (torch.manual_seed rejects larger values), so the old
        # upper bound of 2**64 was one past the valid range.
        seed = random.randint(1, 2**64 - 1)

    with torch.inference_mode():
        checkpoint = CheckpointLoaderSimple().load_checkpoint(
            ckpt_name="kk.safetensors"
        )

        # LoRA chain: "flat" at strength 1, then two further applications at
        # strength 0 (presumably no-ops; kept to preserve the original graph).
        loraloader = LoraLoader()
        lora_full = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=1,
            strength_clip=1,
            model=get_value_at_index(checkpoint, 0),
            clip=get_value_at_index(checkpoint, 1),
        )
        lora_zero_a = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(lora_full, 0),
            clip=get_value_at_index(lora_full, 1),
        )
        lora_zero_b = loraloader.load_lora(
            lora_name="flat.safetensors",
            strength_model=0,
            strength_clip=0,
            model=get_value_at_index(lora_zero_a, 0),
            clip=get_value_at_index(lora_zero_a, 1),
        )

        clip = get_value_at_index(lora_zero_b, 1)
        text_encoder = CLIPTextEncode()
        positive = text_encoder.encode(text=prompt, clip=clip)
        negative = text_encoder.encode(text="", clip=clip)

        sigmas = NODE_CLASS_MAPPINGS["AlignYourStepsScheduler"]().get_sigmas(
            model_type="SD1", steps=25, denoise=1
        )
        sampler = NODE_CLASS_MAPPINGS["KSamplerSelect"]().get_sampler(
            sampler_name="euler"
        )
        latent = EmptyLatentImage().generate(
            width=512, height=512, batch_size=batch_siz
        )

        # NOTE(review): sampling uses the checkpoint's model, not the LoRA-patched
        # model from the chain above, so strength_model=1 on the "flat" LoRA
        # never reaches the sampler (only its CLIP patch is used). Kept as-is to
        # avoid changing outputs; confirm whether this wiring is intended.
        samples = NODE_CLASS_MAPPINGS["SamplerCustom"]().sample(
            add_noise=True,
            noise_seed=seed,
            cfg=cf,
            model=get_value_at_index(checkpoint, 0),
            positive=get_value_at_index(positive, 0),
            negative=get_value_at_index(negative, 0),
            sampler=get_value_at_index(sampler, 0),
            sigmas=get_value_at_index(sigmas, 0),
            latent_image=get_value_at_index(latent, 0),
        )
        decoded = VAEDecode().decode(
            samples=get_value_at_index(samples, 0),
            vae=get_value_at_index(checkpoint, 2),
        )
        SaveImage().save_images(
            filename_prefix=prefix, images=get_value_at_index(decoded, 0)
        )

    return prefix
def _parse_generate_payload(data):
    """Validate a /generate JSON body; return ``(params, None)`` or ``(None, error_message)``."""
    if not isinstance(data, dict):
        return None, "Request body must be a JSON object"
    prompt = data.get('prompt', '')
    cfg = data.get('cfg', 1)
    batch_size = data.get('batch_size', 1)
    seed = data.get('seed', 0)
    if not isinstance(prompt, str):
        return None, "'prompt' must be a string"
    # bool is a subclass of int; reject it so JSON true/false is not read as 1/0.
    if isinstance(cfg, bool) or not isinstance(cfg, (int, float)):
        return None, "'cfg' must be a number"
    if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size < 1:
        return None, "'batch_size' must be a positive integer"
    return {
        "prompt": prompt,
        "cfg": cfg,
        "batch_size": batch_size,
        "seed": seed,
    }, None


@app.route('/generate', methods=['POST'])
def generate():
    """Generate images for a JSON request and return their URL paths.

    Body: a JSON object with optional keys ``prompt`` (str, default ``''``),
    ``cfg`` (number, default 1), ``batch_size`` (positive int, default 1) and
    ``seed`` (default 0; only logged, since ``main`` is not given it here).

    Returns:
        200 with a sorted JSON list of ``output/<file>`` paths that
        ``get_image`` serves; 400 for an invalid body or a missing output
        directory; 404 when no image file matches the generated prefix.
    """
    # silent=True: a missing or malformed JSON body yields None -> 400 below.
    params, error = _parse_generate_payload(request.get_json(silent=True))
    if error is not None:
        return jsonify({"error": error}), 400

    # Module-level copy of the last request's parameters, kept for compatibility.
    global response
    response = params
    print(response)

    main(params["prompt"], params["cfg"], params["batch_size"])
    # NOTE(review): lis is module-global state written by main(); concurrent
    # requests on Flask's threaded dev server can overwrite each other's prefix.
    query = lis[0]
    directory = "/content/ShiroUI/output"
    if not query or not os.path.isdir(directory):
        return jsonify({"error": "Invalid query or directory"}), 400

    image_exts = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')
    # URL paths always use "/", and sorting gives a stable (batch) order
    # instead of os.listdir's arbitrary one.
    matched_images = sorted(
        f"output/{name}"
        for name in os.listdir(directory)
        if query in name and name.lower().endswith(image_exts)
    )
    if not matched_images:
        return jsonify({"error": "No matching images found"}), 404
    return jsonify(matched_images)
@app.route('/output/<path:filename>', methods=['GET'])
def get_image(filename):
    """Serve ``filename`` from ShiroUI's output directory.

    Flask's ``send_from_directory`` refuses paths that would resolve outside
    the base directory (per its documentation), so the client-supplied
    ``filename`` cannot be used to read other files.
    """
    return send_from_directory("/content/ShiroUI/output", filename)
if __name__ == '__main__':
    # Flask's development server with its default address and port written out.
    # Port 5000 is also the port the loophole tunnel command targets.
    app.run(host="127.0.0.1", port=5000)
Metadata
Metadata
Assignees
Labels
No labels