1 change: 1 addition & 0 deletions Cargo.toml
@@ -24,6 +24,7 @@ itertools = "0.10.5"
ryu = "1.0"
fast-float = "0.2.0"
lasso = "0.7.2"
candle-core = "0.9.2"

[dependencies.flate2]
version = "1.1"
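Not part of the diff, but a quick way to sanity-check the new dependency: a minimal sketch, assuming candle-core's default (CPU-only) feature set, exercising the same Tensor::from_slice / to_vec1 round-trip the migrated code performs.

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Build a 1-D f32 tensor on the CPU, then read the values back out
    let t = Tensor::from_slice(&[1f32, 2.0, 3.0], 3usize, &Device::Cpu)?;
    assert_eq!(t.to_vec1::<f32>()?, vec![1f32, 2.0, 3.0]);
    Ok(())
}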
9 changes: 5 additions & 4 deletions src/algos/aggregator.rs
@@ -1,7 +1,7 @@
//! Aggregators take models, feature embeddings, and a feature set and convert them into
//! embeddings. They are constructed ad hoc, so they can be used in parallel as well as within
//! the Python interface.
use simple_grad::*;
use candle_core::{Device, Tensor};
use rand::prelude::*;
use rand_xorshift::XorShiftRng;

@@ -141,13 +141,14 @@ impl <'a> EmbeddingBuilder for AttentionAggregator<'a> {
out.fill(0f32);
let it = features.iter().map(|feat_id| {
let e = self.embs.get_embedding(*feat_id);
(Constant::new(e.to_vec()), 1f32)
(Tensor::from_slice(e, e.len(), &Device::Cpu).unwrap(), 1f32)
}).collect::<Vec<_>>();

// No-op RNG
let mut rng = XorShiftRng::seed_from_u64(0);
let v = attention_mean(it.iter(), &self.mha, &mut rng);
v.value().iter().zip(out.iter_mut()).for_each(|(vi, outi)| {
let result = attention_mean(it.iter(), &self.mha, &mut rng);
let v = result.to_vec1::<f32>().unwrap();
v.iter().zip(out.iter_mut()).for_each(|(vi, outi)| {
*outi = *vi;
});
}
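For readers unfamiliar with the old API: simple_grad's Constant::new wrapped a Vec<f32> into a graph node, while candle takes a plain Tensor, and reading a result back out goes through to_vec1::<f32>() instead of .value(). A hedged sketch of the input-building pattern in this hunk (weighted_inputs is an illustrative name, not from the PR; a slice of vectors stands in for self.embs):

use candle_core::{Device, Tensor};

// Pair each feature embedding with a unit weight, the (Tensor, f32)
// shape that attention_mean consumes in the hunk above
fn weighted_inputs(embs: &[Vec<f32>]) -> Vec<(Tensor, f32)> {
    embs.iter()
        .map(|e| (Tensor::from_slice(e, e.len(), &Device::Cpu).unwrap(), 1f32))
        .collect()
}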
64 changes: 42 additions & 22 deletions src/algos/emb_aligner.rs
@@ -1,50 +1,70 @@
use simple_grad::*;
use candle_core::{Device, Tensor, Var};

pub fn align_embedding(
embedding: &[f32],
t_embeddings: &[(&[f32], f32)],
alpha: f32,
eps: f32,
max_epochs: usize
max_epochs: usize,
) -> Vec<f32> {
use_shared_pool(false);
let device = Device::Cpu;

let embs: Vec<_> = t_embeddings.iter().map(|(e, _d)| {
Constant::new(e.to_vec())
}).collect();
let embs: Vec<_> = t_embeddings
.iter()
.map(|(e, _d)| Tensor::from_slice(e, e.len(), &device).unwrap())
.collect();

let mut new_emb = embedding.to_vec();
let mut last_err = std::f32::INFINITY;
let mut i = 0;
loop {
// Put the embedding in a variable for graph
let ne = Variable::pooled(new_emb.as_slice());
// Put the embedding in a variable for autograd
let ne = Var::from_slice(new_emb.as_slice(), new_emb.len(), &device).unwrap();

// Compute the distances between the current embedding and the neighbor embeddings
// we weight it based on position of the anchors
let buff: Vec<_> = embs.iter().zip(t_embeddings.iter()).enumerate().map(|(pos, (e, (_, d)))| {
let euc_dist = (e.clone() - &ne).pow(2f32).sum().sqrt();
(euc_dist - *d).pow(2.) / ((pos + 1) as f32).sqrt()
}).collect();
let buff: Vec<_> = embs
.iter()
.zip(t_embeddings.iter())
.enumerate()
.map(|(pos, (e, (_, d)))| {
let diff = ne.as_tensor().sub(e).unwrap();
let euc_dist = diff.powf(2.0).unwrap().sum_all().unwrap().sqrt().unwrap();
// Use rank-0 (scalar) tensors so the shape and dtype match `euc_dist`,
// which is itself rank-0 after `sum_all`; a shape-[1] f64 tensor here
// would make `sub` and `div` fail their shape/dtype checks
let d_tensor = Tensor::new(*d, &device).unwrap();
let diff_d = euc_dist.sub(&d_tensor).unwrap();
let divisor = Tensor::new(((pos + 1) as f32).sqrt(), &device).unwrap();
diff_d.powf(2.0).unwrap().div(&divisor).unwrap()
})
.collect();

let loss = buff.sum_all();
let loss = buff
.into_iter()
.reduce(|a, b| a.add(&b).unwrap())
.unwrap();

let mut graph = Graph::new();
graph.backward(&loss);
let grad_store = loss.backward().unwrap();

// Simple Backpropagation
let grad = graph.get_grad(&ne).unwrap();
new_emb.iter_mut().zip(grad.iter()).for_each(|(ei, gi)| {
*ei -= alpha * *gi;
});
let grad = grad_store.get(&ne).unwrap();
let grad_vec = grad.to_vec1::<f32>().unwrap();
new_emb
.iter_mut()
.zip(grad_vec.iter())
.for_each(|(ei, gi)| {
*ei -= alpha * *gi;
});

let cur_err = loss.value()[0];
let cur_err = loss.to_vec0::<f32>().unwrap();
if (last_err - cur_err).abs() / last_err < eps {
break
break;
}
last_err = cur_err;
i += 1;
if i >= max_epochs { break }
if i >= max_epochs {
break;
}
}
new_emb
}
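For reference, the loop above minimizes sum_i ((||x - e_i|| - d_i)^2 / sqrt(i + 1)): the squared gap between the current Euclidean distance to anchor e_i and its target distance d_i, down-weighted for later anchors, with plain gradient descent at step size alpha until the relative improvement drops below eps or max_epochs is reached. A hedged usage sketch (the anchors and hyperparameters are made up for illustration):

fn main() {
    // Two unit-vector anchors with target distances 0.5 and 1.0
    let a = [1.0f32, 0.0, 0.0];
    let b = [0.0f32, 1.0, 0.0];
    let anchors: Vec<(&[f32], f32)> = vec![(a.as_slice(), 0.5), (b.as_slice(), 1.0)];

    // Start from an initial guess; step size 0.1, stop on a 1e-5
    // relative improvement or after at most 100 epochs
    let start = [0.5f32, 0.5, 0.0];
    let aligned = align_embedding(&start, &anchors, 0.1, 1e-5, 100);
    println!("aligned embedding: {:?}", aligned);
}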