# Spaces: Sleeping
from fastapi import FastAPI | |
import os | |
import pymupdf # PyMuPDF | |
from pptx import Presentation | |
from sentence_transformers import SentenceTransformer | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
from PIL import Image | |
import chromadb | |
import numpy as np | |
from sklearn.decomposition import PCA | |
# Web application object.
app = FastAPI()

# Vector store: a persistent Chroma client and the single collection that
# receives both text and image entries.
client = chromadb.PersistentClient(path="/data/chroma_db")
collection = client.get_or_create_collection(name="knowledge_base")

# Documents that are ingested.
pdf_file = "Sutures and Suturing techniques.pdf"
pptx_file = "impalnt 1.pptx"

# Embedding models: a sentence-transformer for text and CLIP for images.
# The CLIP model and its processor are loaded from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
text_model = SentenceTransformer("all-MiniLM-L6-v2")
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Folder that receives images extracted from the documents; created up front
# so the extractors can write into it directly.
IMAGE_FOLDER = "/data/extracted_images"
os.makedirs(IMAGE_FOLDER, exist_ok=True)
# Extract Text from PDF
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, joined by single spaces.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The whitespace-stripped text, or ``None`` when the document yields no
        text or cannot be opened/parsed (the error is printed, not raised).
    """
    try:
        # The context manager closes the document on every exit path, so the
        # file handle is released even when a page fails to parse.
        with pymupdf.open(pdf_path) as doc:
            text = " ".join(page.get_text() for page in doc)
    except Exception as e:  # best-effort: a bad file must not abort ingestion
        print(f"Error extracting text from PDF: {e}")
        return None
    # Whitespace-only output counts as "no text" as well.
    return text.strip() or None
# Extract Text from PPTX
def extract_text_from_pptx(pptx_path):
    """Collect the text of every shape that exposes a ``text`` attribute.

    Slides are visited in order and the shape texts are joined by single
    spaces.

    Args:
        pptx_path: Path of the presentation to read.

    Returns:
        The stripped text, ``None`` if nothing was collected, or ``None``
        after printing the error when the presentation cannot be processed.
    """
    try:
        deck = Presentation(pptx_path)
        fragments = []
        for slide in deck.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    fragments.append(shape.text)
        joined = " ".join(fragments)
        if not joined:
            return None
        return joined.strip()
    except Exception as e:
        print(f"Error extracting text from PPTX: {e}")
        return None
# Extract Images from PDF
def extract_images_from_pdf(pdf_path):
    """Write every image embedded in a PDF to ``IMAGE_FOLDER``.

    Files are named ``pdf_image_<page>_<index>.<ext>`` where ``<page>`` is the
    zero-based page number and ``<index>`` the position of the image in that
    page's image list.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The list of written file paths, or an empty list when anything fails
        (the error is printed, not raised).
    """
    try:
        images = []
        # The context manager guarantees the document is closed on every exit
        # path, including a failure half-way through extraction.
        with pymupdf.open(pdf_path) as doc:
            for i, page in enumerate(doc):
                for img_index, img in enumerate(page.get_images(full=True)):
                    xref = img[0]  # first field of each entry is the xref
                    image = doc.extract_image(xref)
                    img_path = f"{IMAGE_FOLDER}/pdf_image_{i}_{img_index}.{image['ext']}"
                    with open(img_path, "wb") as f:
                        f.write(image["image"])
                    images.append(img_path)
        return images
    except Exception as e:  # best-effort: a bad file must not abort ingestion
        print(f"Error extracting images from PDF: {e}")
        return []
# Extract Images from PPTX
def extract_images_from_pptx(pptx_path):
    """Write every picture shape of a presentation to ``IMAGE_FOLDER``.

    Files are named ``pptx_image_<slide>_<index>.<ext>`` where ``<slide>`` is
    the zero-based slide number and ``<index>`` counts the pictures on that
    slide. The per-slide index keeps several pictures on one slide from
    overwriting each other and from producing duplicate paths.

    Args:
        pptx_path: Path of the presentation to read.

    Returns:
        The list of written file paths, or an empty list when anything fails
        (the error is printed, not raised).
    """
    picture_shape_type = 13  # numeric value of python-pptx MSO_SHAPE_TYPE.PICTURE
    try:
        images = []
        prs = Presentation(pptx_path)
        for i, slide in enumerate(prs.slides):
            picture_index = 0
            for shape in slide.shapes:
                if shape.shape_type != picture_shape_type:
                    continue
                picture = shape.image
                img_path = f"{IMAGE_FOLDER}/pptx_image_{i}_{picture_index}.{picture.ext}"
                with open(img_path, "wb") as f:
                    f.write(picture.blob)
                images.append(img_path)
                picture_index += 1
        return images
    except Exception as e:  # best-effort: a bad file must not abort ingestion
        print(f"Error extracting images from PPTX: {e}")
        return []
# Convert Text to Embeddings
def get_text_embedding(text):
    """Encode ``text`` with the module-level text model.

    Args:
        text: The input handed to ``text_model.encode``.

    Returns:
        The encoder output converted with ``.tolist()``.
    """
    vector = text_model.encode(text)
    return vector.tolist()
# Extract Image Embeddings
def get_image_embedding(image_path):
    """Compute the CLIP image-feature vector for one image file.

    Args:
        image_path: Path of the image to embed.

    Returns:
        The flattened feature vector as a list of floats, or ``None`` when
        the image cannot be opened or embedded (the error is printed).
    """
    try:
        # Keep the file open only while the processor reads the pixels; the
        # context manager releases the handle afterwards.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            image_embedding = model.get_image_features(**inputs).numpy().flatten()
        return image_embedding.tolist()
    except Exception as e:  # best-effort: one bad image must not abort ingestion
        print(f"Error generating image embedding: {e}")
        return None
# Reduce Embedding Dimensions (If Needed)
def reduce_embedding_dim(embeddings):
    """Project a batch of embeddings to at most 384 dimensions with PCA.

    The number of components is ``min(n_samples, n_features, 384)`` because
    PCA cannot produce more components than either dimension of its input.
    NOTE(review): 384 presumably matches the text model's vector width;
    confirm. With fewer than 384 samples the output is narrower than 384.

    Args:
        embeddings: A 2-D array-like of shape (n_samples, n_features).

    Returns:
        The transformed embeddings as nested lists. If the input is not a
        valid 2-D matrix or PCA fails, the error is printed and the input is
        returned unchanged (converted to lists where possible).
    """
    matrix = None
    try:
        matrix = np.array(embeddings)
        n_components = min(matrix.shape[0], matrix.shape[1], 384)  # Ensure valid PCA size
        pca = PCA(n_components=n_components)
        return pca.fit_transform(matrix).tolist()
    except Exception as e:
        print(f"Error in PCA transformation: {e}")
        if matrix is not None:
            return matrix.tolist()
        # Array conversion itself failed (e.g. ragged rows), so the input is
        # still whatever the caller passed; convert row by row instead of
        # calling .tolist() on something that may be a plain list.
        return [row.tolist() if isinstance(row, np.ndarray) else row for row in embeddings]
# Store Data in ChromaDB
def store_data(texts, image_paths):
    """Embed texts and images and add them to the Chroma collection.

    Texts are stored under ids ``text_<i>`` and images under ``image_<j>``,
    where the index is the item's position in the corresponding input list.
    Falsy texts and images whose embedding could not be computed are skipped.

    NOTE(review): text and image vectors go into the same collection; verify
    that their dimensions agree for the expected number of images.

    Args:
        texts: Iterable of text documents (falsy entries are ignored).
        image_paths: Iterable of image file paths.
    """
    for i, text in enumerate(texts):
        if text:
            collection.add(ids=[f"text_{i}"], embeddings=[get_text_embedding(text)], documents=[text])

    # Embed each image exactly once and keep its position and path next to the
    # vector, so a failed image cannot shift the pairing of paths and vectors.
    embedded = []
    for j, img_path in enumerate(image_paths):
        embedding = get_image_embedding(img_path)
        if embedding is not None:
            embedded.append((j, img_path, embedding))

    if embedded:
        matrix = np.array([embedding for _, _, embedding in embedded])
        # PCA needs more than one sample; a single vector is stored as-is.
        if matrix.shape[0] > 1:
            vectors = reduce_embedding_dim(matrix)
        else:
            vectors = matrix.tolist()
        for (j, img_path, _), vector in zip(embedded, vectors):
            collection.add(ids=[f"image_{j}"], embeddings=[vector], documents=[img_path])
    print("Data stored successfully!")
# Process and Store from Files
def process_and_store(pdf_path=None, pptx_path=None):
    """Extract text and images from the given documents and store them.

    Each source is optional. The PDF is handled before the presentation, and
    for each source the text is extracted before the images. Everything that
    was collected is then handed to ``store_data`` in one call.

    Args:
        pdf_path: Path of a PDF file, or a falsy value to skip it.
        pptx_path: Path of a PPTX file, or a falsy value to skip it.
    """
    collected_texts = []
    collected_images = []
    sources = (
        (pdf_path, extract_text_from_pdf, extract_images_from_pdf),
        (pptx_path, extract_text_from_pptx, extract_images_from_pptx),
    )
    for path, read_text, read_images in sources:
        if not path:
            continue
        content = read_text(path)
        if content:
            collected_texts.append(content)
        collected_images += read_images(path)
    store_data(collected_texts, collected_images)
# Run Data Processing
# Executed when the module is loaded: ingest the configured PDF and PPTX.
process_and_store(pdf_file, pptx_file)
# FastAPI Endpoints
@app.get("/")
def greet_json():
    """Return the static greeting payload for the root path.

    Only one definition is kept: an earlier function of the same name was
    shadowed by this one and could never be called. The handler is registered
    on ``app`` so the application actually serves it.
    """
    return {"Hello": "Redmind!"}
# NOTE(review): the route path is chosen to match the function name;
# confirm it against existing clients.
@app.get("/search")
def search(query: str):
    """Return the stored documents closest to a text query.

    The query is embedded with ``get_text_embedding`` and the collection is
    asked for the 5 nearest entries.

    Args:
        query: Free-text search string.

    Returns:
        ``{"results": <documents>}`` with the ``documents`` field of the
        query result (an empty list when absent), or ``{"error": <message>}``
        when embedding or querying fails.
    """
    try:
        query_embedding = get_text_embedding(query)
        results = collection.query(query_embeddings=[query_embedding], n_results=5)
        return {"results": results.get("documents", [])}
    except Exception as e:  # report the failure in the payload instead of a 500
        return {"error": str(e)}