KevinStephenson
Adding in weaviate code
b110593
raw
history blame
7.51 kB
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
package objects
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-openapi/strfmt"
"github.com/weaviate/weaviate/entities/additional"
"github.com/weaviate/weaviate/entities/filters"
"github.com/weaviate/weaviate/entities/models"
"github.com/weaviate/weaviate/entities/schema"
"github.com/weaviate/weaviate/entities/search"
)
// GetObject returns a single object identified by id, optionally scoped to
// class.
//
// The caller must be authorized for the "get" verb on "objects/<id>" or, when
// class is non-empty, on "objects/<class>/<id>". The connector lock is held
// and the GetObject metric is incremented for the duration of the call. When
// additional.Vector is set, usage dimensions are tracked for the result and
// the flag is forwarded to ObjectWithVector.
//
// Errors from authorization and from the repository lookup are returned
// unchanged; a failure to acquire the lock is reported as an internal error.
func (m *Manager) GetObject(ctx context.Context, principal *models.Principal,
	class string, id strfmt.UUID, additional additional.Properties,
	replProps *additional.ReplicationProperties, tenant string,
) (*models.Object, error) {
	// Build the authorization resource; the class segment is only present
	// when the caller supplied a class name.
	var resource string
	if class == "" {
		resource = fmt.Sprintf("objects/%s", id)
	} else {
		resource = fmt.Sprintf("objects/%s/%s", class, id)
	}

	if err := m.authorizer.Authorize(principal, "get", resource); err != nil {
		return nil, err
	}

	release, err := m.locks.LockConnector()
	if err != nil {
		return nil, NewErrInternal("could not acquire lock: %v", err)
	}
	defer release()

	m.metrics.GetObjectInc()
	defer m.metrics.GetObjectDec()

	found, err := m.getObjectFromRepo(ctx, class, id, additional, replProps, tenant)
	if err != nil {
		return nil, err
	}

	includeVector := additional.Vector
	if includeVector {
		m.trackUsageSingle(found)
	}

	return found.ObjectWithVector(includeVector), nil
}
// GetObjects lists objects across classes.
//
// The caller must be authorized for the "list" verb on "objects". The
// connector lock is held and the GetObject metric is incremented for the
// duration of the call. Paging (offset, limit), sorting (sort, order), the
// after cursor, the additional properties and the tenant are passed through
// to getObjectsFromRepo, whose result is returned as-is.
func (m *Manager) GetObjects(ctx context.Context, principal *models.Principal,
	offset *int64, limit *int64, sort *string, order *string, after *string,
	addl additional.Properties, tenant string,
) ([]*models.Object, error) {
	if err := m.authorizer.Authorize(principal, "list", "objects"); err != nil {
		return nil, err
	}

	release, lockErr := m.locks.LockConnector()
	if lockErr != nil {
		return nil, NewErrInternal("could not acquire lock: %v", lockErr)
	}
	defer release()

	m.metrics.GetObjectInc()
	defer m.metrics.GetObjectDec()

	objs, err := m.getObjectsFromRepo(ctx, offset, limit, sort, order, after, addl, tenant)
	return objs, err
}
// GetObjectsClass looks up the object with the given id (without a class
// hint, additional properties, replication properties or tenant) and returns
// the schema class matching that object's ClassName.
//
// The caller must be authorized for the "get" verb on "objects/<id>". The
// connector lock is held and the GetObject metric is incremented for the
// duration of the call.
func (m *Manager) GetObjectsClass(ctx context.Context, principal *models.Principal,
	id strfmt.UUID,
) (*models.Class, error) {
	resource := "objects/" + id.String()
	if err := m.authorizer.Authorize(principal, "get", resource); err != nil {
		return nil, err
	}

	release, err := m.locks.LockConnector()
	if err != nil {
		return nil, NewErrInternal("could not acquire lock: %v", err)
	}
	defer release()

	m.metrics.GetObjectInc()
	defer m.metrics.GetObjectDec()

	obj, err := m.getObjectFromRepo(ctx, "", id, additional.Properties{}, nil, "")
	if err != nil {
		return nil, err
	}

	current, err := m.schemaManager.GetSchema(principal)
	if err != nil {
		return nil, err
	}

	name := schema.ClassName(obj.ClassName)
	return current.GetClass(name), nil
}
// GetObjectClassFromName fetches the schema visible to principal and returns
// the class registered under className. The ctx argument is accepted but not
// used by this method.
func (m *Manager) GetObjectClassFromName(ctx context.Context, principal *models.Principal,
	className string,
) (*models.Class, error) {
	current, err := m.schemaManager.GetSchema(principal)
	if err != nil {
		return nil, err
	}

	name := schema.ClassName(className)
	return current.GetClass(name), nil
}
// getObjectFromRepo loads a single object from the vector repository and, if
// a modules provider is configured, lets it extend the result with the
// requested additional module properties.
//
// When class is non-empty the lookup is class-scoped (and honours repl);
// otherwise the object is looked up by id alone.
//
// Errors:
//   - a repository error that is, or wraps, an ErrMultiTenancy is returned as
//     an ErrMultiTenancy wrapping the original error;
//   - any other repository error is returned as an internal error;
//   - a nil result without an error is returned as a not-found error;
//   - a failure in the modules provider is returned as a plain error prefixed
//     with "get extend".
func (m *Manager) getObjectFromRepo(ctx context.Context, class string, id strfmt.UUID,
	adds additional.Properties, repl *additional.ReplicationProperties, tenant string,
) (res *search.Result, err error) {
	if class != "" {
		res, err = m.vectorRepo.Object(ctx, class, id, search.SelectProperties{}, adds, repl, tenant)
	} else {
		res, err = m.vectorRepo.ObjectByID(ctx, id, search.SelectProperties{}, adds, tenant)
	}
	if err != nil {
		// errors.As walks the whole error chain, so a multi-tenancy error
		// that was wrapped with extra context on its way up is still
		// classified correctly. A plain type switch only matches when the
		// outermost error has exactly that type.
		var mtErr ErrMultiTenancy
		if errors.As(err, &mtErr) {
			return nil, NewErrMultiTenancy(fmt.Errorf("repo: object by id: %w", err))
		}
		return nil, NewErrInternal("repo: object by id: %v", err)
	}

	if res == nil {
		return nil, NewErrNotFound("no object with id '%s'", id)
	}

	if m.modulesProvider != nil {
		res, err = m.modulesProvider.GetObjectAdditionalExtend(ctx, res, adds.ModuleParams)
		if err != nil {
			return nil, fmt.Errorf("get extend: %v", err)
		}
	}

	return res, nil
}
// getObjectsFromRepo runs a cross-class object listing against the vector
// repository.
//
// Steps, in order:
//  1. resolve offset/limit via localOffsetLimit (failure -> internal error);
//  2. reject any after cursor, since a cursor must be specific to one class;
//  3. search the repository with the resolved paging and the sort derived
//     from sort/order;
//  4. if a modules provider is configured, let it extend the results;
//  5. if additional.Vector is set, track usage dimensions for the results.
//
// The results are converted with ObjectsWithVector, which receives the
// additional.Vector flag.
func (m *Manager) getObjectsFromRepo(ctx context.Context,
	offset, limit *int64, sort, order *string, after *string,
	additional additional.Properties, tenant string,
) ([]*models.Object, error) {
	resolvedOffset, resolvedLimit, err := m.localOffsetLimit(offset, limit)
	if err != nil {
		return nil, NewErrInternal("list objects: %v", err)
	}

	if after != nil {
		return nil, NewErrInternal("list objects: after parameter not allowed, cursor must be specific to one class, set class query param")
	}

	sorting := m.getSort(sort, order)
	found, err := m.vectorRepo.ObjectSearch(ctx, resolvedOffset, resolvedLimit,
		nil, sorting, additional, tenant)
	if err != nil {
		return nil, NewErrInternal("list objects: %v", err)
	}

	if m.modulesProvider != nil {
		extended, extErr := m.modulesProvider.ListObjectsAdditionalExtend(ctx, found, additional.ModuleParams)
		if extErr != nil {
			return nil, NewErrInternal("list extend: %v", extErr)
		}
		found = extended
	}

	includeVector := additional.Vector
	if includeVector {
		m.trackUsageList(found)
	}

	return found.ObjectsWithVector(includeVector), nil
}
// getSort converts the comma-separated sort and order query parameters into
// sort clauses. It returns nil when sort is nil. Each comma-separated entry of
// *sort becomes one clause with a single-element path; its direction is taken
// from the entry at the same position in *order via getOrder, which falls
// back to "asc" when that entry is missing or unrecognized.
func (m *Manager) getSort(sort, order *string) []filters.Sort {
	if sort == nil {
		return nil
	}

	var directions []string
	if order != nil {
		directions = strings.Split(*order, ",")
	}

	fields := strings.Split(*sort, ",")
	clauses := make([]filters.Sort, len(fields))
	for idx, field := range fields {
		clauses[idx] = filters.Sort{
			Path:  []string{field},
			Order: m.getOrder(directions, idx),
		}
	}
	return clauses
}
// getOrder returns the sort direction stored at position i of order. Only the
// exact values "asc" and "desc" are recognized; a missing position or any
// other value yields "asc".
func (m *Manager) getOrder(order []string, i int) string {
	// "desc" is the only value that changes the outcome: every other case,
	// including an explicit "asc", resolves to "asc".
	if i < len(order) && order[i] == "desc" {
		return "desc"
	}
	return "asc"
}
// localOffsetOrZero returns the requested offset converted to int, or 0 when
// no offset was supplied.
func (m *Manager) localOffsetOrZero(paramOffset *int64) int {
	if paramOffset == nil {
		return 0
	}
	return int(*paramOffset)
}
// localLimitOrGlobalLimit returns the requested limit converted to int, or
// the configured QueryDefaults.Limit when no limit was supplied. The offset
// argument is not used by this method.
func (m *Manager) localLimitOrGlobalLimit(offset int64, paramMaxResults *int64) int {
	// Read the configured default up front, whether or not it ends up used.
	fallback := int64(m.config.Config.QueryDefaults.Limit)
	if paramMaxResults == nil {
		return int(fallback)
	}
	return int(*paramMaxResults)
}
// localOffsetLimit resolves the effective offset and limit for a listing
// request and enforces the configured QueryMaximumResults.
//
// The offset defaults to 0 and the limit to the configured default limit when
// the respective parameter is nil. If offset+limit exceeds
// QueryMaximumResults, or the sum of two positive values overflows, it
// returns (0, 0, error).
func (m *Manager) localOffsetLimit(paramOffset *int64, paramLimit *int64) (int, int, error) {
	offset := m.localOffsetOrZero(paramOffset)
	limit := m.localLimitOrGlobalLimit(int64(offset), paramLimit)

	// Go integer addition wraps silently. Two positive operands whose sum
	// comes out negative have overflowed; without this check such a request
	// would slip past the maximum-results guard below.
	total := int64(offset) + int64(limit)
	overflowed := offset > 0 && limit > 0 && total < 0

	if overflowed || total > m.config.Config.QueryMaximumResults {
		return 0, 0, errors.New("query maximum results exceeded")
	}
	return offset, limit, nil
}
// trackUsageSingle reports the dimensions of a single result to the metrics
// under the "get_rest" / "single_include_vector" labels. A nil result is
// ignored.
func (m *Manager) trackUsageSingle(res *search.Result) {
	if res != nil {
		m.metrics.AddUsageDimensions(res.ClassName, "get_rest", "single_include_vector", res.Dims)
	}
}
// trackUsageList reports usage dimensions for a list result under the
// "get_rest" / "list_include_vector" labels. Only the first element's class
// name and dimensions are reported; an empty list is ignored.
func (m *Manager) trackUsageList(res search.Results) {
	if len(res) > 0 {
		m.metrics.AddUsageDimensions(res[0].ClassName, "get_rest", "list_include_vector", res[0].Dims)
	}
}
// getCursor builds a cursor from the after and limit parameters. It returns
// nil when after is nil. When limit is nil the cursor's Limit is set to -1,
// which marks that no limit parameter was given.
func (m *Manager) getCursor(after *string, limit *int64) *filters.Cursor {
	if after == nil {
		return nil
	}

	// Start from the "no limit param was set" marker and override it only
	// when the caller provided an explicit limit.
	cursor := &filters.Cursor{After: *after, Limit: -1}
	if limit != nil {
		cursor.Limit = int(*limit)
	}
	return cursor
}