Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save juarezjuniorgithub/843fe984736cd0dcda2380a546fc7c02 to your computer and use it in GitHub Desktop.
Save juarezjuniorgithub/843fe984736cd0dcda2380a546fc7c02 to your computer and use it in GitHub Desktop.
HibernateVectorDataTypeOracleDatabase.java
/*
Copyright (c) 2024, Oracle and/or its affiliates.
This software is dual-licensed to you under the Universal Permissive License
(UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License
2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose
either license.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.oracle.dev.jdbc;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Paths;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.stream.IntStream;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.cfg.Configuration;
import org.hibernate.query.Query;
/**
 * Sample program that stores {@code float[]} embeddings through Hibernate and
 * runs ordered distance queries (cosine, L1, L2) against them.
 *
 * <p>One {@link SessionFactory} is built in the constructor and shared by every
 * method of the instance. Individual operations open and close their own
 * {@link Session}; the factory itself is only closed at the end of
 * {@link #main(String[])}.
 */
public class HibernateVectorDataTypeOracleDatabase {

  /** Location of the Hibernate settings, relative to the working directory. */
  private static final String PROPERTIES_PATH = "src/main/resources/hibernate.properties";

  /** Number of random sample rows inserted by {@link #initializeDatabase()}. */
  private static final int SAMPLE_ROW_COUNT = 100;

  private final Configuration configuration;
  private final SessionFactory sessionFactory;

  /**
   * Inserts sample rows, runs the three distance searches, deletes the rows
   * again and finally closes the session factory.
   *
   * @param args unused
   * @throws SQLException declared for backward compatibility; nothing in this
   *                      class throws it directly
   */
  public static void main(String[] args) throws SQLException {
    HibernateVectorDataTypeOracleDatabase example = new HibernateVectorDataTypeOracleDatabase();
    try {
      example.initializeDatabase();
      example.executeVectorSearchExamples();
      example.cleanDatabase();
    } finally {
      // Release the factory even when one of the steps above fails.
      example.closeSessionFactory();
    }
  }

  /**
   * Builds the Hibernate configuration from the properties file plus the
   * annotated entity class, and creates the session factory.
   */
  public HibernateVectorDataTypeOracleDatabase() {
    Properties properties = loadProperties();
    configuration = new Configuration().addProperties(properties)
        .addAnnotatedClass(HibernateVectorDataTypeEntity.class);
    sessionFactory = configuration.buildSessionFactory();
  }

  /**
   * Prints all entities ordered by ascending {@code cosine_distance} between
   * their {@code embedding} and the given vector.
   *
   * @param searchVector the query vector bound to the {@code :vector} parameter
   */
  public void cosineSimilaritySearch(float[] searchVector) {
    distanceSearch("cosine_distance", "Cosine Similarity Results", searchVector);
  }

  /**
   * Prints all entities ordered by ascending {@code l1_distance} between their
   * {@code embedding} and the given vector.
   *
   * @param searchVector the query vector bound to the {@code :vector} parameter
   */
  public void l1DistanceSearch(float[] searchVector) {
    distanceSearch("l1_distance", "L1 Distance Results", searchVector);
  }

  /**
   * Prints all entities ordered by ascending {@code l2_distance} between their
   * {@code embedding} and the given vector.
   *
   * @param searchVector the query vector bound to the {@code :vector} parameter
   */
  public void l2DistanceSearch(float[] searchVector) {
    distanceSearch("l2_distance", "L2 Distance Results", searchVector);
  }

  /**
   * Loads a single entity by primary key.
   *
   * @param id the entity identifier
   * @return the entity, or {@code null} when no row has that identifier
   */
  public HibernateVectorDataTypeEntity readVector(long id) {
    try (Session session = sessionFactory.openSession()) {
      return session.get(HibernateVectorDataTypeEntity.class, id);
    }
  }

  /**
   * Merges the state of the given (possibly detached) entity into the database
   * inside its own transaction.
   *
   * @param entity the entity whose state should be written
   */
  public void updateVector(HibernateVectorDataTypeEntity entity) {
    runInTransaction(session -> session.merge(entity));
  }

  /**
   * Deletes the entity with the given identifier, if it exists. The shared
   * session factory stays open so the instance remains usable afterwards.
   *
   * @param id the entity identifier
   */
  public void deleteVector(long id) {
    runInTransaction(session -> {
      HibernateVectorDataTypeEntity entity = session.get(HibernateVectorDataTypeEntity.class, id);
      if (entity != null) {
        session.remove(entity);
      }
    });
  }

  /** Inserts {@link #SAMPLE_ROW_COUNT} random entities in a single transaction. */
  private void initializeDatabase() {
    runInTransaction(session -> IntStream.rangeClosed(1, SAMPLE_ROW_COUNT)
        .forEach(i -> session.persist(newRandomEntity())));
  }

  /** Runs the cosine, L1 and L2 searches with a fixed three-element vector. */
  private void executeVectorSearchExamples() {
    float[] searchVector = new float[] { 1.0f, 2.0f, 3.0f };
    cosineSimilaritySearch(searchVector);
    l1DistanceSearch(searchVector);
    l2DistanceSearch(searchVector);
  }

  /** Deletes every {@code HibernateVectorDataTypeEntity} row. */
  private void cleanDatabase() {
    runInTransaction(
        session -> session.createMutationQuery("DELETE FROM HibernateVectorDataTypeEntity").executeUpdate());
  }

  /** Closes the session factory unless it has already been closed. */
  private void closeSessionFactory() {
    if (!sessionFactory.isClosed()) {
      sessionFactory.close();
    }
  }

  /**
   * Shared implementation of the three public search methods: orders all
   * entities by the given distance function and prints id and embedding.
   *
   * @param distanceFunction HQL function name; always one of the constants
   *                         passed by this class, never external input
   * @param heading          text printed above the result list
   * @param searchVector     value bound to the {@code :vector} parameter
   */
  private void distanceSearch(String distanceFunction, String heading, float[] searchVector) {
    String hql = "SELECT e FROM HibernateVectorDataTypeEntity e "
        + "ORDER BY " + distanceFunction + "(e.embedding, :vector) ASC";
    try (Session session = sessionFactory.openSession()) {
      Query<HibernateVectorDataTypeEntity> query = session.createQuery(hql,
          HibernateVectorDataTypeEntity.class);
      query.setParameter("vector", searchVector);
      List<HibernateVectorDataTypeEntity> results = query.getResultList();
      System.out.println("\n" + heading + ":");
      results.forEach(e -> System.out.println("ID: " + e.getId() + " - " + Arrays.toString(e.getEmbedding())));
    }
  }

  /**
   * Opens a session, runs {@code work} inside a transaction and commits. If the
   * work or the commit throws, the transaction is rolled back (when still
   * active) and the original exception is rethrown.
   *
   * @param work the statements to execute with the open session
   */
  private void runInTransaction(java.util.function.Consumer<Session> work) {
    try (Session session = sessionFactory.openSession()) {
      Transaction tx = session.beginTransaction();
      try {
        work.accept(session);
        tx.commit();
      } catch (RuntimeException e) {
        try {
          if (tx.isActive()) {
            tx.rollback();
          }
        } catch (RuntimeException rollbackFailure) {
          // Keep the primary failure visible; attach the rollback problem to it.
          e.addSuppressed(rollbackFailure);
        }
        throw e;
      }
    }
  }

  /** Creates an unsaved entity whose two embeddings hold three random floats each. */
  private HibernateVectorDataTypeEntity newRandomEntity() {
    HibernateVectorDataTypeEntity entity = new HibernateVectorDataTypeEntity();
    float[] emb1 = { randFloat(), randFloat(), randFloat() };
    float[] emb2 = { randFloat(), randFloat(), randFloat() };
    entity.setEmbedding(emb1);
    entity.setEmbeddingTwo(emb2);
    return entity;
  }

  /**
   * Returns a pseudo-random value drawn uniformly from {@code [-50, 50)}.
   */
  private float randFloat() {
    // Math.random() is uniform over [0, 1), so this maps to [-50, 50).
    return (float) (Math.random() * 100 - 50);
  }

  /**
   * Reads the Hibernate settings from {@link #PROPERTIES_PATH}.
   *
   * <p>A read failure is reported on standard error but is not fatal: an empty
   * {@link Properties} object is returned and configuration continues.
   * NOTE(review): this presumably relies on Hibernate finding its settings
   * elsewhere (e.g. on the classpath) in that case - confirm.
   *
   * @return the loaded properties; empty when the file could not be read
   */
  private Properties loadProperties() {
    Properties properties = new Properties();
    try (FileInputStream inputStream = new FileInputStream(Paths.get(PROPERTIES_PATH).toFile())) {
      properties.load(inputStream);
    } catch (IOException e) {
      System.err.println("Could not read " + PROPERTIES_PATH + "; continuing without it");
      e.printStackTrace();
    }
    return properties;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment