txt2im-models / app.py
dn6's picture
dn6 HF Staff
add application file
d35f312
raw
history blame
3.03 kB
import comet_ml
import uuid
import gradio as gr
class MyProject:
    """Gradio app that queries a hosted text-to-image model and logs to Comet.

    The app has two tabs: one to start/stop a Comet experiment from
    user-supplied credentials, and one to send a prompt to a model loaded
    through ``gr.Interface.load``. While an experiment is active, every
    generated image and its prompt are logged to it.

    NOTE(review): the experiment is stored on the instance, so it is shared by
    everything that uses the same ``MyProject`` object.
    """

    def __init__(self):
        # Active Comet experiment, or None while logging is disabled.
        self.experiment = None

    def start_experiment(
        self, comet_api_key: str, comet_workspace: str, comet_project_name: str
    ) -> str:
        """Start a Comet experiment and return a status message for the UI.

        Args:
            comet_api_key: Comet API key; when empty no experiment is started.
            comet_workspace: Comet workspace passed to ``comet_ml.Experiment``.
            comet_project_name: Comet project passed to ``comet_ml.Experiment``.

        Returns:
            A human-readable status string (success, missing key, or the
            error text if starting the experiment failed).
        """
        if not comet_api_key:
            return (
                "\n"
                "Please add your API key in order to log your predictions to a Comet Experiment.\n"
                "If you don't have a Comet account yet, you can sign up here:\n"
                "https://www.comet.ml/signup\n"
            )

        # Broad catch is deliberate: this is the UI boundary, and any failure
        # should surface as a message in the status box rather than a crash.
        try:
            # Close a previously started experiment instead of silently
            # dropping the reference to it.
            if self.experiment is not None:
                self.experiment.end()
                self.experiment = None

            self.experiment = comet_ml.Experiment(
                api_key=comet_api_key,
                workspace=comet_workspace,
                project_name=comet_project_name,
            )
            self.experiment.add_tags(["spaces"])
            return f"Started {self.experiment.name}. Happy logging!😊"
        except Exception as e:
            # The status Textbox expects text, not an exception object.
            return f"Failed to start experiment: {e}"

    def end_experiment(self) -> str:
        """End the active Comet experiment, if any, and return a status message.

        After a successful call the instance holds no experiment, so later
        predictions are no longer logged.
        """
        if self.experiment is None:
            return "No active experiment to end."

        finished = self.experiment
        finished.end()
        # Forget the experiment so predict() stops logging to a closed run.
        self.experiment = None
        return f"Ended {finished.name}"

    def start_comet_interface(self) -> "gr.Blocks":
        """Build the "Comet Settings" tab: credential fields plus start/end buttons."""
        demo = gr.Blocks()
        with demo:
            # credentials
            comet_api_key = gr.Textbox(label="Comet API Key")
            comet_workspace = gr.Textbox(label="Comet Workspace")
            comet_project_name = gr.Textbox(label="Comet Project Name")

            with gr.Row():
                start_button = gr.Button("Start Experiment", variant="primary")
                end_button = gr.Button("End Experiment", variant="secondary")

            # Shows the message returned by start_experiment / end_experiment.
            output = gr.Textbox(label="Status")

            start_button.click(
                self.start_experiment,
                inputs=[
                    comet_api_key,
                    comet_workspace,
                    comet_project_name,
                ],
                outputs=output,
            )
            end_button.click(self.end_experiment, inputs=None, outputs=output)

        return demo

    def predict(
        self,
        model,
        prompt,
    ):
        """Generate an image for ``prompt`` with ``model`` and log it if enabled.

        Args:
            model: Identifier handed to ``gr.Interface.load``.
            prompt: Text prompt passed to the loaded interface.

        Returns:
            Whatever the loaded interface returns for the prompt.
        """
        io = gr.Interface.load(model)
        image = io(prompt)

        if self.experiment is not None:
            # A shared random id ties the logged image to its logged prompt.
            image_id = uuid.uuid4().hex
            self.experiment.log_image(image, name=image_id)
            self.experiment.log_text(
                prompt, metadata={"image_id": image_id, "model": model}
            )

        return image

    def load_interface(self) -> "gr.Interface":
        """Build the "Diffusion Model" tab that wraps :meth:`predict`."""
        model = gr.Textbox(label="Model", value="spaces/valhalla/glide-text2im")
        prompt = gr.Textbox(label="Prompt")
        outputs = gr.Image(label="Image")

        interface = gr.Interface(self.predict, inputs=[model, prompt], outputs=outputs)
        return interface

    def launch(self) -> None:
        """Combine both tabs into one tabbed interface and launch it."""
        interface = gr.TabbedInterface(
            [self.start_comet_interface(), self.load_interface()],
            tab_names=["Comet Settings", "Diffusion Model"],
        )
        interface.launch()
# Script entry point: build the app and start serving it as soon as the file runs.
_project = MyProject()
_project.launch()