File size: 4,020 Bytes
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
 
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
81c68d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f32750
 
81c68d9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import io
import logging
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from starvector.data.util import process_and_rasterize_svg

USE_BOTH_MODELS = True  # Set to True to load both models; False loads only the 8b model

# Hugging Face repository id for each selectable model size.
MODEL_REPOS = {
    "8b": "starvector/starvector-8b-im2svg",
    "1b": "starvector/starvector-1b-im2svg",
}


def _load_model(repo_id):
    """Load a StarVector checkpoint in float16, move it to the GPU and set eval mode.

    Args:
        repo_id: Hugging Face repository id passed to ``from_pretrained``.

    Returns:
        A dict ``{"model": model, "processor": processor}``, where the processor
        is the one exposed by the loaded model as ``model.model.processor``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    return {"model": model, "processor": model.model.processor}


# Load models at startup. The 8b model is always loaded first; the 1b model
# is added only when USE_BOTH_MODELS is set.
models = {"8b": _load_model(MODEL_REPOS["8b"])}
if USE_BOTH_MODELS:
    models["1b"] = _load_model(MODEL_REPOS["1b"])

def _write_svg_file(svg):
    """Write ``svg`` to a new temporary ``.svg`` file and return its path.

    The file is created with ``delete=False`` because the Gradio ``File``
    output reads it after the event handler has returned. It is not
    removed here.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".svg", delete=False, encoding="utf-8"
    ) as handle:
        handle.write(svg)
    return handle.name


def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded raster image to SVG with the selected StarVector model.

    Args:
        image: Filesystem path of the uploaded image (the input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        model_choice: Key into ``models`` selecting the checkpoint. Defaults to
            ``"8b"`` so the handler still works when the UI does not render the
            model selector (``USE_BOTH_MODELS = False``) and passes only the image.

    Returns:
        A ``(preview, svg_path, status)`` tuple:
        - the rasterized preview produced by ``process_and_rasterize_svg``;
        - the path of a ``.svg`` file for the download component;
        - a status message.
        On failure the first two entries are ``None`` and the status carries
        the error text.
    """
    if image is None:
        return None, None, "Please upload an image first"

    try:
        selected_model = models[model_choice]["model"]
        selected_processor = models[model_choice]["processor"]

        # Context manager releases the file handle Image.open keeps open;
        # the processor reads the pixels before the block exits.
        with Image.open(image) as image_pil:
            image_tensor = selected_processor(image_pil, return_tensors="pt")["pixel_values"].cuda()

        # NOTE(review): squeeze(0) only removes dim 0 when its size is 1, so this
        # branch (taken when dim 0 is NOT 1) leaves the tensor unchanged. It is
        # kept to mirror the upstream usage example; confirm the intent.
        if image_tensor.shape[0] != 1:
            image_tensor = image_tensor.squeeze(0)

        batch = {"image": image_tensor}

        # Inference only: skip autograd bookkeeping to reduce GPU memory use.
        with torch.no_grad():
            raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # gr.File expects a filesystem path, not an in-memory buffer.
        svg_path = _write_svg_file(svg)

        return raster_image, svg_path, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: keep the full traceback in the logs and show a short message.
        logging.getLogger(__name__).exception("SVG conversion failed")
        return None, None, f"Error: {e}"

# Build the Gradio UI: upload controls on the left, conversion results on the right.
with gr.Blocks(title="Image to SVG Converter") as demo:
    gr.Markdown("# Image to SVG Converter")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")

    with gr.Row():
        # Left column: image upload, optional model selector, trigger button, examples.
        with gr.Column(scale=1):
            input_image = gr.Image(type="filepath", label="Upload Image")
            if USE_BOTH_MODELS:
                model_choice = gr.Radio(
                    choices=["8b", "1b"],
                    value="8b",
                    label="Select Model",
                    info="Choose between 8b (larger) and 1b (smaller) models",
                )
            convert_btn = gr.Button("Convert to SVG")
            gr.Examples(
                examples=[["assets/examples/sample-18.png"]],
                inputs=input_image,
            )

        # Right column: rasterized preview, downloadable SVG, status text.
        with gr.Column(scale=1):
            output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")

    # The model selector exists only when both checkpoints were loaded, so it is
    # wired in as a second input only in that configuration.
    click_inputs = [input_image, model_choice] if USE_BOTH_MODELS else [input_image]
    click_outputs = [output_preview, output_file, status]
    convert_btn.click(fn=convert_to_svg, inputs=click_inputs, outputs=click_outputs)

# Start the web server.
demo.launch()