File size: 8,776 Bytes
5c05919
 
 
 
 
 
f5262f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c05919
f5262f7
5c05919
 
 
 
 
 
 
f5262f7
5c05919
 
 
f5262f7
5c05919
 
 
 
1b6c2fc
5c05919
 
 
 
 
 
 
4b2e52f
5c05919
4b2e52f
 
5c05919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755c458
5c05919
f5262f7
755c458
 
33a8649
755c458
 
 
 
 
 
 
 
 
 
 
 
 
5c05919
755c458
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import kagglehub
import os
import polars as pl
import gradio as gr
import google.generativeai as genai

# === Predefined schema from user ===
# Maps each Meta-Kaggle table (CSV file) name to the list of its column names.
# This dict is the only schema information the app has: it is rendered into
# the LLM prompt text; nothing in this file reads the CSV files themselves.
# NOTE(review): presumably mirrors the published Meta-Kaggle CSV headers —
# verify against the dataset when Kaggle adds/renames tables or columns.
schema_dict = {
    "KernelTags": ['Id', 'KernelId', 'TagId'],
    "ModelVariations": ['Id', 'ModelId', 'CurrentVariationSlug', 'ModelFramework', 'CurrentModelVariationVersionId', 'LicenseName', 'BaseModelVariationId', 'CurrentDatasourceVersionId'],
    "KernelVersionCompetitionSources": ['Id', 'KernelVersionId', 'SourceCompetitionId'],
    "Datasets": ['Id', 'CreatorUserId', 'OwnerUserId', 'OwnerOrganizationId', 'CurrentDatasetVersionId', 'CurrentDatasourceVersionId', 'ForumId', 'Type', 'CreationDate', 'LastActivityDate', 'TotalViews', 'TotalDownloads', 'TotalVotes', 'TotalKernels', 'Medal', 'MedalAwardDate'],
    "KernelVersionKernelSources": ['Id', 'KernelVersionId', 'SourceKernelVersionId'],
    "KernelVotes": ['Id', 'UserId', 'KernelVersionId', 'VoteDate'],
    "Submissions": ['Id', 'SubmittedUserId', 'TeamId', 'SourceKernelVersionId', 'SubmissionDate', 'ScoreDate', 'IsAfterDeadline', 'IsSelected', 'PublicScoreLeaderboardDisplay', 'PublicScoreFullPrecision', 'PrivateScoreLeaderboardDisplay', 'PrivateScoreFullPrecision'],
    "KernelLanguages": ['Id', 'Name', 'DisplayName', 'IsNotebook'],
    "Users": ['Id', 'UserName', 'DisplayName', 'RegisterDate', 'PerformanceTier', 'Country', 'LocationSharingOptOut'],
    "ForumMessageVotes": ['Id', 'ForumMessageId', 'FromUserId', 'ToUserId', 'VoteDate'],
    "Competitions": ['Id', 'Slug', 'Title', 'Subtitle', 'HostSegmentTitle', 'ForumId', 'OrganizationId', 'EnabledDate', 'DeadlineDate', 'ProhibitNewEntrantsDeadlineDate', 'TeamMergerDeadlineDate', 'TeamModelDeadlineDate', 'ModelSubmissionDeadlineDate', 'FinalLeaderboardHasBeenVerified', 'HasKernels', 'OnlyAllowKernelSubmissions', 'HasLeaderboard', 'LeaderboardPercentage', 'ScoreTruncationNumDecimals', 'EvaluationAlgorithmAbbreviation', 'EvaluationAlgorithmName', 'EvaluationAlgorithmDescription', 'EvaluationAlgorithmIsMax', 'MaxDailySubmissions', 'NumScoredSubmissions', 'MaxTeamSize', 'BanTeamMergers', 'EnableTeamModels', 'RewardType', 'RewardQuantity', 'NumPrizes', 'UserRankMultiplier', 'CanQualifyTiers', 'TotalTeams', 'TotalCompetitors', 'TotalSubmissions', 'LicenseName', 'Overview', 'Rules', 'DatasetDescription', 'TotalCompressedBytes', 'TotalUncompressedBytes', 'ValidationSetName', 'ValidationSetValue', 'EnableSubmissionModelHashes', 'EnableSubmissionModelAttachments', 'HostName', 'CompetitionTypeId'],
    "DatasetTaskSubmissions": ['Id', 'DatasetTaskId', 'SubmittedUserId', 'CreationDate', 'KernelId', 'DatasetId', 'AcceptedDate'],
    "UserAchievements": ['Id', 'UserId', 'AchievementType', 'Tier', 'TierAchievementDate', 'Points', 'CurrentRanking', 'HighestRanking', 'TotalGold', 'TotalSilver', 'TotalBronze'],
    "UserOrganizations": ['Id', 'UserId', 'OrganizationId', 'JoinDate'],
    "Teams": ['Id', 'CompetitionId', 'TeamLeaderId', 'TeamName', 'ScoreFirstSubmittedDate', 'LastSubmissionDate', 'PublicLeaderboardSubmissionId', 'PrivateLeaderboardSubmissionId', 'IsBenchmark', 'Medal', 'MedalAwardDate', 'PublicLeaderboardRank', 'PrivateLeaderboardRank', 'WriteUpForumTopicId'],
    "UserFollowers": ['Id', 'UserId', 'FollowingUserId', 'CreationDate'],
    "CompetitionTags": ['Id', 'CompetitionId', 'TagId'],
    "Kernels": ['Id', 'AuthorUserId', 'CurrentKernelVersionId', 'ForkParentKernelVersionId', 'ForumTopicId', 'FirstKernelVersionId', 'CreationDate', 'EvaluationDate', 'MadePublicDate', 'IsProjectLanguageTemplate', 'CurrentUrlSlug', 'Medal', 'MedalAwardDate', 'TotalViews', 'TotalComments', 'TotalVotes'],
    "Organizations": ['Id', 'Name', 'Slug', 'CreationDate', 'Description'],
    "Datasources": ['Id', 'CreatorUserId', 'CreationDate', 'Type', 'CurrentDatasourceVersionId'],
    "ModelVersions": ['Id', 'ModelId', 'Title', 'Subtitle', 'ModelCard', 'CreationDate', 'OriginalPublishDate', 'CreatorUserId', 'ProvenanceSources'],
    "ForumTopics": ['Id', 'ForumId', 'KernelId', 'LastForumMessageId', 'FirstForumMessageId', 'CreationDate', 'LastCommentDate', 'Title', 'IsSticky', 'TotalViews', 'Score', 'TotalMessages', 'TotalReplies'],
    "DatasetVersions": ['Id', 'DatasetId', 'DatasourceVersionId', 'CreatorUserId', 'LicenseName', 'CreationDate', 'VersionNumber', 'Title', 'Slug', 'Subtitle', 'Description', 'VersionNotes', 'TotalCompressedBytes', 'TotalUncompressedBytes'],
    "ModelVotes": ['Id', 'UserId', 'ModelId', 'VoteDate'],
    "DatasetVotes": ['Id', 'UserId', 'DatasetVersionId', 'VoteDate'],
    "TeamMemberships": ['Id', 'TeamId', 'UserId', 'RequestDate'],
    "Forums": ['Id', 'ParentForumId', 'Title'],
    "KernelVersions": ['Id', 'ScriptId', 'ParentScriptVersionId', 'ScriptLanguageId', 'AuthorUserId', 'CreationDate', 'VersionNumber', 'Title', 'EvaluationDate', 'IsChange', 'TotalLines', 'LinesInsertedFromPrevious', 'LinesChangedFromPrevious', 'LinesUnchangedFromPrevious', 'LinesInsertedFromFork', 'LinesDeletedFromFork', 'LinesChangedFromFork', 'LinesUnchangedFromFork', 'TotalVotes', 'IsInternetEnabled', 'RunningTimeInMilliseconds', 'AcceleratorTypeId', 'DockerImage'],
    "ModelVariationVersions": ['Id', 'ModelVariationId', 'ModelVersionId', 'DatasourceVersionId', 'CreationDate', 'VariationOverview', 'VariationUsage', 'FineTunable', 'SourceUrl', 'SourceOrganizationName'],
    "ForumMessages": ['Id', 'ForumTopicId', 'PostUserId', 'PostDate', 'ReplyToForumMessageId', 'Message', 'RawMarkdown', 'Medal', 'MedalAwardDate'],
    "KernelVersionDatasetSources": ['Id', 'KernelVersionId', 'SourceDatasetVersionId'],
    "Episodes": ['Id', 'Type', 'CompetitionId', 'CreateTime', 'EndTime'],
    "EpisodeAgents": ['Id', 'EpisodeId', 'Index', 'Reward', 'State', 'SubmissionId', 'InitialConfidence', 'InitialScore', 'UpdatedConfidence', 'UpdatedScore'],
    "KernelAcceleratorTypes": ['Id', 'Label'],
    "KernelVersionModelSources": ['Id', 'KernelVersionId', 'SourceModelVariationVersionId', 'SourceModelVariationId'],
    "ForumMessageReactions": ['Id', 'ForumMessageId', 'FromUserId', 'ReactionType', 'ReactionDate'],
    "Tags": ['Id', 'ParentTagId', 'Name', 'Slug', 'FullPath', 'Description', 'DatasetCount', 'CompetitionCount', 'KernelCount'],
    "DatasetTasks": ['Id', 'DatasetId', 'OwnerUserId', 'CreationDate', 'Description', 'ForumId', 'Title', 'Subtitle', 'Deadline', 'TotalVotes'],
    "Models": ['Id', 'OwnerUserId', 'OwnerOrganizationId', 'CurrentModelVersionId', 'ForumId', 'CreationDate', 'TotalViews', 'TotalDownloads', 'TotalVotes', 'TotalKernels', 'CurrentSlug'],
    "DatasetTags": ['Id', 'DatasetId', 'TagId'],
    "ModelTags": ['Id', 'ModelId', 'TagId'],
}

# No CSV files are opened by this app, so the table -> file mapping stays
# empty; the name is kept so existing references to it keep working.
file_map = {}

# === Render the schema as prompt text ===
# Each table becomes a "### <TableName>" heading followed by one line of
# comma-separated column names; tables are separated by a blank line.
schema_description = "\n\n".join(
    "### " + table_name + "\n" + ", ".join(column_names)
    for table_name, column_names in schema_dict.items()
)


# Instruction preamble placed in front of every user question (guide_user
# appends "User question: ..." to it). It embeds the rendered schema and asks
# the model only to name the needed CSV files, columns and join keys — not to
# write or run code. The literal's exact text/whitespace is what gets sent.
context_prompt = f"""
You are a helpful assistant that helps users understand which parts of the Meta-Kaggle dataset they need for their analysis.

Below is the dataset schema (CSV files):

{schema_description}

When the user asks a question, respond with:
1. Which csv file(s) are needed
2. Which column(s) are relevant
3. If needed, describe any join keys (e.g., "CompetitionId", "UserId")
Do not generate or run any code. Just guide the user on what parts of the dataset are needed.
"""

# === Gemini setup ===
# Model name is configurable so a retired/renamed Gemini model can be swapped
# without editing code; an unset or empty GEMINI_MODEL keeps the original
# default.
GEMINI_MODEL = os.getenv("GEMINI_MODEL") or "gemini-1.5-flash"

# Best-effort configuration: on any failure the app still starts with
# `model = None`, and guide_user reports the misconfiguration to the user
# instead of the whole script crashing at import time.
try:
    # Use environment variable for the API key (never hard-code secrets).
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        raise ValueError("GOOGLE_API_KEY environment variable not set")
    genai.configure(api_key=GOOGLE_API_KEY)
    model = genai.GenerativeModel(GEMINI_MODEL)
except Exception as e:
    print(f"Error setting up Gemini API: {e}")
    model = None

# === Analysis Guide Function ===
def guide_user(prompt):
    """Ask Gemini which Meta-Kaggle CSV files/columns a question needs.

    Args:
        prompt: The question text from the Gradio input textbox.

    Returns:
        The model's guidance with surrounding whitespace stripped, or a
        human-readable hint/error string. Failures are returned rather than
        raised so they show up in the "Guidance" box instead of as a
        generic Gradio error.
    """
    if model is None:
        return "Error: Gemini API not properly configured."
    # Don't spend an API call on an empty submission.
    if prompt is None or not prompt.strip():
        return "Please enter a question about the Meta-Kaggle dataset."
    full_prompt = context_prompt + f"\n\nUser question: {prompt}\n\nAnswer:"
    try:
        result = model.generate_content(full_prompt)
        # `.text` can itself raise ValueError when the response carries no
        # text part (e.g. a blocked response), so it stays inside the try.
        answer = result.text
    except Exception as e:  # UI boundary: report instead of crashing the handler
        return f"Error: Gemini request failed: {e}"
    return answer.strip()



# === Launch Gradio UI ===
# One-page app: a question box, a read-only answer box, and a button that
# sends the question through guide_user and shows its return value.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Meta-Kaggle Dataset Navigator")
        gr.Markdown("Ask which CSV files and columns you need for your analysis!")

        with gr.Row():
            input_text = gr.Textbox(
                label="Your Question",
                placeholder="E.g., Which files and columns do I need to analyze competition rankings?",
                lines=2,
            )
        with gr.Row():
            # Output only: users should not edit the model's answer.
            output_text = gr.Textbox(label="Guidance", lines=10, interactive=False)

        with gr.Row():
            submit_button = gr.Button("Get Guidance")
            submit_button.click(fn=guide_user, inputs=input_text, outputs=output_text)

# Start the (blocking) web server only when run as a script, so importing this
# module — e.g. to reuse `demo` or `guide_user` — has no such side effect.
if __name__ == "__main__":
    demo.launch()