Removing code, adding linter and fixing lints
rockerBOO committed Oct 23, 2024
1 parent 63fb678 commit 4f163d6
Showing 6 changed files with 287 additions and 160 deletions.
Makefile (8 changes: 3 additions & 5 deletions)
@@ -1,18 +1,16 @@

build-wasm:
wasm-pack build --target no-modules --out-dir pkg crates/lora-inspector-wasm --release --weak-refs

test:
cargo test --workspace

build-wasm:
wasm-pack build --target no-modules --out-dir pkg crates/lora-inspector-wasm --release --weak-refs

build-dev-wasm:
wasm-pack build --target no-modules --out-dir pkg crates/lora-inspector-wasm --release --weak-refs

dev-wasm:
cd crates/lora-inspector-wasm && yarn vite

dev-wasm-2:
dev-wasm-cors:
cd crates/lora-inspector-wasm/ && python simple-cors-server.py


crates/inspector/src/weight.rs (49 changes: 29 additions & 20 deletions)
@@ -464,9 +464,9 @@ impl Weight for BufferedLoRAWeight {
// .collect()
// }

fn dora_scale(&self, base_name: &str) -> Result<Tensor, candle_core::Error> {
self.get(&format!("{base_name}.dora_scale"))
}
// fn dora_scale(&self, base_name: &str) -> Result<Tensor, candle_core::Error> {
// self.get(&format!("{base_name}.dora_scale"))
// }

fn dims(&self) -> HashSet<usize> {
self.buffered
@@ -480,14 +480,10 @@ impl Weight for BufferedLoRAWeight {
} else if k.contains("lokr_w1") {
self.get(k).map(|v| v.dims()[0]).ok()
} else if k.contains("b1.weight") {
// dbg!(self.get(k).map(|v| v.dims().to_vec()).unwrap());
// dbg!(self.get(k).map(|v| v.dims()[0]).ok());
// self.get(k).map(|v| v.dims().last().copied()).ok().flatten()
self.get(k).map(|v| v.dims()[0]).ok()
} else if k.contains("oft_diag") {
self.get(k).map(|v| v.dims().last().copied()).ok().flatten()
} else if k.contains("oft_blocks") {
// dbg!(self.get(k).map(|v| v.dims().to_vec()).unwrap());
self.get(k).map(|v| v.dims().last().copied()).ok().flatten()
} else {
None
@@ -538,16 +534,37 @@ pub trait WeightKey {
pub trait Weight {
fn get(&self, key: &str) -> Result<Tensor, candle_core::Error>;

/// Most common precision datatype
fn precision(&self) -> Option<DType>;

/// Scale LoRA weights by the alpha and combine the A/B weights
///
/// # Errors
///
/// This function will return an error if tensor operations fail.
fn scale_lora_weight(&self, base_name: &str) -> Result<Tensor, candle_core::Error>;
fn scale_glora_weights(&self, base_name: &str) -> Result<Tensor, candle_core::Error>;
/// Scale the weights by the alpha and combine with the LoHa/Hada/FedPara weights
///
/// # Errors
///
/// This function will return an error if the tensor operations fail.
fn scale_hada_weight(&self, base_name: &str) -> Result<Tensor, candle_core::Error>;

/// Scale the weights by the alpha and combine with the LoKr weights
///
/// # Errors
///
/// This function will return an error if the tensor operations fail.
fn scale_lokr_weight(&self, base_name: &str) -> Result<Tensor, candle_core::Error>;

/// Unique alphas in the tensors
fn alphas(&self) -> HashSet<Alpha>;
// fn dora_scales(&self) -> Vec<Vec<f32>>;
fn dora_scale(&self, key: &str) -> Result<Tensor, candle_core::Error>;

/// Unique dimensions in the tensors
fn dims(&self) -> HashSet<usize>;

/// All shapes dimensions by HashMap of tensor modules
fn shapes(&self) -> HashMap<String, Vec<usize>>;
}
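
The new doc comment on scale_lora_weight describes the usual LoRA reconstruction: the delta weight is the down/up (A/B) product, scaled by alpha (commonly alpha divided by the rank in LoRA tooling, which is an assumption here). A minimal sketch of that arithmetic with candle_core, as an illustration only and not the crate's actual implementation; the helper name and the shapes below are made up:

use candle_core::{Device, Result, Tensor};

// Illustration only: combine a LoRA A/B pair scaled by alpha / rank.
// `up` is (out, rank), `down` is (rank, in); the rank is read from `down`.
fn scale_lora(up: &Tensor, down: &Tensor, alpha: f64) -> Result<Tensor> {
    let rank = down.dims()[0] as f64;
    // delta_W = (alpha / rank) * up @ down
    up.matmul(down)?.affine(alpha / rank, 0.0)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical rank-4 pair for a 16x8 layer.
    let up = Tensor::randn(0f32, 1f32, (16, 4), &dev)?;
    let down = Tensor::randn(0f32, 1f32, (4, 8), &dev)?;
    let delta = scale_lora(&up, &down, 8.0)?;
    println!("{:?}", delta.shape()); // [16, 8]
    Ok(())
}
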

@@ -604,14 +621,6 @@ impl WeightKey for LoRAWeight {
self.keys_by_key("alpha")
}

// fn up_keys(&self) -> Vec<String> {
// self.keys_by_key("lora_up")
// }
//
// fn down_keys(&self) -> Vec<String> {
// self.keys_by_key("lora_down")
// }

fn base_names(&self) -> Vec<String> {
self.weight_keys()
.iter()
@@ -899,9 +908,9 @@ impl Weight for LoRAWeight {
})
}

fn dora_scale(&self, base_name: &str) -> Result<Tensor, candle_core::Error> {
self.get(&format!("{base_name}.dora_scale"))
}
// fn dora_scale(&self, base_name: &str) -> Result<Tensor, candle_core::Error> {
// self.get(&format!("{base_name}.dora_scale"))
// }

fn dims(&self) -> HashSet<usize> {
self.tensors
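
Taken together, the methods that keep their doc comments in this commit (precision, alphas, dims, shapes) form the read-only inspection surface of the Weight trait. A minimal usage sketch, assuming the trait is importable roughly as shown (the module path is a guess) and using only the signatures that appear in the diff above:

// Hypothetical module path; the crate's real layout may differ.
use inspector::weight::Weight;

// Print a small summary for anything implementing the Weight trait.
fn summarize<W: Weight>(w: &W) {
    // Most common precision datatype across the tensors.
    if let Some(dtype) = w.precision() {
        println!("precision: {dtype:?}");
    }
    // Unique alphas and dimensions found in the tensors.
    println!("alphas: {}", w.alphas().len());
    println!("dims: {:?}", w.dims());
    // Every tensor module's shape.
    for (module, shape) in w.shapes() {
        println!("{module}: {shape:?}");
    }
}
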
