baconnier committed (verified)
Commit 4528f9c · 1 Parent(s): e4b9c9f

Update art_explorer.py

Files changed (1): art_explorer.py (+12 -7)
art_explorer.py CHANGED
@@ -1,5 +1,4 @@
 from typing import Optional, List, Dict
-import json
 import instructor
 from openai import OpenAI
 from prompts import SYSTEM_PROMPT, format_exploration_prompt, DEFAULT_RESPONSE
@@ -7,11 +6,16 @@ from models import ExplorationResponse
 
 class ExplorationPathGenerator:
     def __init__(self, api_key: str):
+        # Create a client without instructor
+        client = OpenAI(
+            base_url="https://api.groq.com/openai/v1",
+            api_key=api_key
+        )
+        # Create an instructor client specifically for GROQ
         self.client = instructor.patch(
-            OpenAI(
-                base_url="https://api.groq.com/openai/v1",
-                api_key=api_key
-            )
+            client=client,
+            mode="groq",
+            validate_responses=True
         )
 
     def generate_exploration_path(
@@ -36,8 +40,10 @@ class ExplorationPathGenerator:
             exploration_parameters=exploration_parameters
         )
 
+        # Use the GROQ-specific completion
         response = self.client.chat.completions.create(
             model="mixtral-8x7b-32768",
+            response_model=ExplorationResponse,
             messages=[
                 {
                     "role": "system",
@@ -49,8 +55,7 @@ class ExplorationPathGenerator:
                 }
             ],
             temperature=0.7,
-            max_tokens=2000,
-            response_model=ExplorationResponse
+            max_tokens=2000
         )
 
         return response.model_dump()
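A note on the new instructor.patch call: in released versions of the instructor library, patch expects an instructor.Mode enum rather than the string "groq", and validate_responses is not a documented keyword, so this hunk may fail at runtime. The library's documented entry point for Groq is instructor.from_groq. Below is a minimal sketch under that assumption; TopicSuggestion and the prompt are placeholders, and only the model name is taken from this file.

# Hedged sketch of instructor's documented Groq integration, not the
# exact call this commit makes; assumes `pip install instructor groq`.
import instructor
from groq import Groq
from pydantic import BaseModel

class TopicSuggestion(BaseModel):
    # Hypothetical stand-in for a structured response model
    title: str
    rationale: str

# from_groq wraps the native Groq SDK client; Mode.JSON has the model
# emit JSON that instructor validates against response_model
client = instructor.from_groq(
    Groq(api_key="YOUR_GROQ_API_KEY"),
    mode=instructor.Mode.JSON,
)

suggestion = client.chat.completions.create(
    model="mixtral-8x7b-32768",
    response_model=TopicSuggestion,
    messages=[{"role": "user", "content": "Suggest one art movement to explore."}],
)
print(suggestion.model_dump())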
 
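Since the point of this change is validated structured output, it is worth sketching how validation is actually expressed with instructor: the response_model pydantic class does the checking, and max_retries (a documented instructor parameter on create) re-asks the model when validation fails. The sketch below mirrors this file's OpenAI-compatible Groq endpoint; ExplorationResponse is stubbed out here because models.py is not part of this commit.

# Sketch only: ExplorationResponse is a stub; the real class lives in
# models.py, which this commit does not touch.
from typing import List
import instructor
from openai import OpenAI
from pydantic import BaseModel, Field

class ExplorationResponse(BaseModel):
    topic: str
    steps: List[str] = Field(min_length=1)  # reject empty exploration paths

client = instructor.patch(
    OpenAI(base_url="https://api.groq.com/openai/v1", api_key="YOUR_GROQ_API_KEY"),
    mode=instructor.Mode.JSON,  # a Mode enum, not the string "groq"
)

response = client.chat.completions.create(
    model="mixtral-8x7b-32768",
    response_model=ExplorationResponse,
    messages=[{"role": "user", "content": "Plan an exploration of Impressionism."}],
    temperature=0.7,
    max_tokens=2000,
    max_retries=2,  # instructor re-prompts when pydantic validation fails
)
print(response.model_dump())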