|
import kagglehub |
|
import os |
|
import polars as pl |
|
import gradio as gr |
|
import google.generativeai as genai |
|
|
|
|
|
# Column listing for every CSV table in the Meta-Kaggle dataset, keyed by
# table name (the CSV file stem). This is the only schema knowledge the app
# has: it is rendered into `schema_description` and embedded in the LLM
# prompt below. NOTE: insertion order matters — it fixes the order of the
# "### Table" sections in the generated prompt.
schema_dict = {
    "KernelTags": ['Id', 'KernelId', 'TagId'],
    "ModelVariations": ['Id', 'ModelId', 'CurrentVariationSlug', 'ModelFramework', 'CurrentModelVariationVersionId', 'LicenseName', 'BaseModelVariationId', 'CurrentDatasourceVersionId'],
    "KernelVersionCompetitionSources": ['Id', 'KernelVersionId', 'SourceCompetitionId'],
    "Datasets": ['Id', 'CreatorUserId', 'OwnerUserId', 'OwnerOrganizationId', 'CurrentDatasetVersionId', 'CurrentDatasourceVersionId', 'ForumId', 'Type', 'CreationDate', 'LastActivityDate', 'TotalViews', 'TotalDownloads', 'TotalVotes', 'TotalKernels', 'Medal', 'MedalAwardDate'],
    "KernelVersionKernelSources": ['Id', 'KernelVersionId', 'SourceKernelVersionId'],
    "KernelVotes": ['Id', 'UserId', 'KernelVersionId', 'VoteDate'],
    "Submissions": ['Id', 'SubmittedUserId', 'TeamId', 'SourceKernelVersionId', 'SubmissionDate', 'ScoreDate', 'IsAfterDeadline', 'IsSelected', 'PublicScoreLeaderboardDisplay', 'PublicScoreFullPrecision', 'PrivateScoreLeaderboardDisplay', 'PrivateScoreFullPrecision'],
    "KernelLanguages": ['Id', 'Name', 'DisplayName', 'IsNotebook'],
    "Users": ['Id', 'UserName', 'DisplayName', 'RegisterDate', 'PerformanceTier', 'Country', 'LocationSharingOptOut'],
    "ForumMessageVotes": ['Id', 'ForumMessageId', 'FromUserId', 'ToUserId', 'VoteDate'],
    "Competitions": ['Id', 'Slug', 'Title', 'Subtitle', 'HostSegmentTitle', 'ForumId', 'OrganizationId', 'EnabledDate', 'DeadlineDate', 'ProhibitNewEntrantsDeadlineDate', 'TeamMergerDeadlineDate', 'TeamModelDeadlineDate', 'ModelSubmissionDeadlineDate', 'FinalLeaderboardHasBeenVerified', 'HasKernels', 'OnlyAllowKernelSubmissions', 'HasLeaderboard', 'LeaderboardPercentage', 'ScoreTruncationNumDecimals', 'EvaluationAlgorithmAbbreviation', 'EvaluationAlgorithmName', 'EvaluationAlgorithmDescription', 'EvaluationAlgorithmIsMax', 'MaxDailySubmissions', 'NumScoredSubmissions', 'MaxTeamSize', 'BanTeamMergers', 'EnableTeamModels', 'RewardType', 'RewardQuantity', 'NumPrizes', 'UserRankMultiplier', 'CanQualifyTiers', 'TotalTeams', 'TotalCompetitors', 'TotalSubmissions', 'LicenseName', 'Overview', 'Rules', 'DatasetDescription', 'TotalCompressedBytes', 'TotalUncompressedBytes', 'ValidationSetName', 'ValidationSetValue', 'EnableSubmissionModelHashes', 'EnableSubmissionModelAttachments', 'HostName', 'CompetitionTypeId'],
    "DatasetTaskSubmissions": ['Id', 'DatasetTaskId', 'SubmittedUserId', 'CreationDate', 'KernelId', 'DatasetId', 'AcceptedDate'],
    "UserAchievements": ['Id', 'UserId', 'AchievementType', 'Tier', 'TierAchievementDate', 'Points', 'CurrentRanking', 'HighestRanking', 'TotalGold', 'TotalSilver', 'TotalBronze'],
    "UserOrganizations": ['Id', 'UserId', 'OrganizationId', 'JoinDate'],
    "Teams": ['Id', 'CompetitionId', 'TeamLeaderId', 'TeamName', 'ScoreFirstSubmittedDate', 'LastSubmissionDate', 'PublicLeaderboardSubmissionId', 'PrivateLeaderboardSubmissionId', 'IsBenchmark', 'Medal', 'MedalAwardDate', 'PublicLeaderboardRank', 'PrivateLeaderboardRank', 'WriteUpForumTopicId'],
    "UserFollowers": ['Id', 'UserId', 'FollowingUserId', 'CreationDate'],
    "CompetitionTags": ['Id', 'CompetitionId', 'TagId'],
    "Kernels": ['Id', 'AuthorUserId', 'CurrentKernelVersionId', 'ForkParentKernelVersionId', 'ForumTopicId', 'FirstKernelVersionId', 'CreationDate', 'EvaluationDate', 'MadePublicDate', 'IsProjectLanguageTemplate', 'CurrentUrlSlug', 'Medal', 'MedalAwardDate', 'TotalViews', 'TotalComments', 'TotalVotes'],
    "Organizations": ['Id', 'Name', 'Slug', 'CreationDate', 'Description'],
    "Datasources": ['Id', 'CreatorUserId', 'CreationDate', 'Type', 'CurrentDatasourceVersionId'],
    "ModelVersions": ['Id', 'ModelId', 'Title', 'Subtitle', 'ModelCard', 'CreationDate', 'OriginalPublishDate', 'CreatorUserId', 'ProvenanceSources'],
    "ForumTopics": ['Id', 'ForumId', 'KernelId', 'LastForumMessageId', 'FirstForumMessageId', 'CreationDate', 'LastCommentDate', 'Title', 'IsSticky', 'TotalViews', 'Score', 'TotalMessages', 'TotalReplies'],
    "DatasetVersions": ['Id', 'DatasetId', 'DatasourceVersionId', 'CreatorUserId', 'LicenseName', 'CreationDate', 'VersionNumber', 'Title', 'Slug', 'Subtitle', 'Description', 'VersionNotes', 'TotalCompressedBytes', 'TotalUncompressedBytes'],
    "ModelVotes": ['Id', 'UserId', 'ModelId', 'VoteDate'],
    "DatasetVotes": ['Id', 'UserId', 'DatasetVersionId', 'VoteDate'],
    "TeamMemberships": ['Id', 'TeamId', 'UserId', 'RequestDate'],
    "Forums": ['Id', 'ParentForumId', 'Title'],
    "KernelVersions": ['Id', 'ScriptId', 'ParentScriptVersionId', 'ScriptLanguageId', 'AuthorUserId', 'CreationDate', 'VersionNumber', 'Title', 'EvaluationDate', 'IsChange', 'TotalLines', 'LinesInsertedFromPrevious', 'LinesChangedFromPrevious', 'LinesUnchangedFromPrevious', 'LinesInsertedFromFork', 'LinesDeletedFromFork', 'LinesChangedFromFork', 'LinesUnchangedFromFork', 'TotalVotes', 'IsInternetEnabled', 'RunningTimeInMilliseconds', 'AcceleratorTypeId', 'DockerImage'],
    "ModelVariationVersions": ['Id', 'ModelVariationId', 'ModelVersionId', 'DatasourceVersionId', 'CreationDate', 'VariationOverview', 'VariationUsage', 'FineTunable', 'SourceUrl', 'SourceOrganizationName'],
    "ForumMessages": ['Id', 'ForumTopicId', 'PostUserId', 'PostDate', 'ReplyToForumMessageId', 'Message', 'RawMarkdown', 'Medal', 'MedalAwardDate'],
    "KernelVersionDatasetSources": ['Id', 'KernelVersionId', 'SourceDatasetVersionId'],
    "Episodes": ['Id', 'Type', 'CompetitionId', 'CreateTime', 'EndTime'],
    "EpisodeAgents": ['Id', 'EpisodeId', 'Index', 'Reward', 'State', 'SubmissionId', 'InitialConfidence', 'InitialScore', 'UpdatedConfidence', 'UpdatedScore'],
    "KernelAcceleratorTypes": ['Id', 'Label'],
    "KernelVersionModelSources": ['Id', 'KernelVersionId', 'SourceModelVariationVersionId', 'SourceModelVariationId'],
    "ForumMessageReactions": ['Id', 'ForumMessageId', 'FromUserId', 'ReactionType', 'ReactionDate'],
    "Tags": ['Id', 'ParentTagId', 'Name', 'Slug', 'FullPath', 'Description', 'DatasetCount', 'CompetitionCount', 'KernelCount'],
    "DatasetTasks": ['Id', 'DatasetId', 'OwnerUserId', 'CreationDate', 'Description', 'ForumId', 'Title', 'Subtitle', 'Deadline', 'TotalVotes'],
    "Models": ['Id', 'OwnerUserId', 'OwnerOrganizationId', 'CurrentModelVersionId', 'ForumId', 'CreationDate', 'TotalViews', 'TotalDownloads', 'TotalVotes', 'TotalKernels', 'CurrentSlug'],
    "DatasetTags": ['Id', 'DatasetId', 'TagId'],
    "ModelTags": ['Id', 'ModelId', 'TagId'],
}
|
|
|
|
|
# NOTE(review): declared but never used anywhere in this file — presumably a
# placeholder for mapping table names to downloaded CSV paths (kagglehub is
# imported but also unused). Confirm before removing.
file_map = {}
|
|
|
|
|
# Render the schema as Markdown: one "### TableName" heading followed by a
# comma-separated column list per CSV, sections separated by a blank line.
schema_description = "\n\n".join(
    f"### {table_name}\n{', '.join(column_names)}"
    for table_name, column_names in schema_dict.items()
)
|
|
|
|
|
# System-style instructions prepended to every user question in guide_user().
# The full schema listing is baked in so the model can point at concrete
# files/columns/join keys without any tool access. The instructions forbid
# code generation: the app only gives navigation guidance.
context_prompt = f"""
You are a helpful assistant that helps users understand which parts of the Meta-Kaggle dataset they need for their analysis.

Below is the dataset schema (CSV files):

{schema_description}

When the user asks a question, respond with:
1. Which csv file(s) are needed
2. Which column(s) are relevant
3. If needed, describe any join keys (e.g., "CompetitionId", "UserId")
Do not generate or run any code. Just guide the user on what parts of the dataset are needed.
"""
|
|
|
|
|
# One-time Gemini client configuration at import time. On any failure
# (missing key, client error) we print the problem and leave `model` as
# None so the UI can report a friendly error instead of the app crashing
# at startup.
model = None
try:
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        raise ValueError("GOOGLE_API_KEY environment variable not set")
    genai.configure(api_key=GOOGLE_API_KEY)
    model = genai.GenerativeModel("gemini-1.5-flash")
except Exception as setup_error:
    # Broad catch is deliberate: this is a top-level best-effort boundary.
    print(f"Error setting up Gemini API: {setup_error}")
|
|
|
|
|
def guide_user(prompt):
    """Return guidance on which Meta-Kaggle CSVs/columns answer `prompt`.

    Sends the schema-laden `context_prompt` plus the user's question to the
    configured Gemini model and returns the model's plain-text answer.

    Returns a human-readable "Error: ..." string instead of raising when the
    API client is unconfigured, the question is blank, or the API call fails,
    so the Gradio handler never surfaces a raw traceback.
    """
    if model is None:
        return "Error: Gemini API not properly configured."
    # Guard against blank questions so we don't waste an API round-trip.
    if not prompt or not prompt.strip():
        return "Error: please enter a question."
    full_prompt = context_prompt + f"\n\nUser question: {prompt}\n\nAnswer:"
    try:
        result = model.generate_content(full_prompt)
        # `result.text` can itself raise (e.g. blocked or empty candidates),
        # so it stays inside the try block.
        return result.text.strip()
    except Exception as exc:
        # Network/quota/safety failures become a readable message in the UI.
        return f"Error: Gemini API request failed: {exc}"
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# A single-column layout: title, a question box, a read-only guidance box,
# and a button wired to guide_user().
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Meta-Kaggle Dataset Navigator")
        gr.Markdown("Ask which CSV files and columns you need for your analysis!")

        with gr.Row():
            question_box = gr.Textbox(
                label="Your Question",
                placeholder="E.g., Which files and columns do I need to analyze competition rankings?",
                lines=2,
            )
        with gr.Row():
            guidance_box = gr.Textbox(label="Guidance", lines=10, interactive=False)

        with gr.Row():
            ask_button = gr.Button("Get Guidance")
        # Clicking sends the question through guide_user and shows the answer.
        ask_button.click(fn=guide_user, inputs=question_box, outputs=guidance_box)

demo.launch()
|
|